diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d30cce47e082646606c64405d3bf8464aa3145a7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/clock.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text +OminiControl/assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/OminiControl/LICENSE b/OminiControl/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7b2de994465d5d0c1f37fc6aecd065f8849ed2d8 --- /dev/null +++ b/OminiControl/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2024] [Zhenxiong Tan] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/OminiControl/README.md b/OminiControl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6ab48029971c05775acc4146eec40ade2dad451b --- /dev/null +++ b/OminiControl/README.md @@ -0,0 +1,170 @@ +# OminiControl + + + +
+ +arXiv +HuggingFace +HuggingFace +GitHub +HuggingFace + +> **OminiControl: Minimal and Universal Control for Diffusion Transformer** +>
+> Zhenxiong Tan, +> [Songhua Liu](http://121.37.94.87/), +> [Xingyi Yang](https://adamdad.github.io/), +> Qiaochu Xue, +> and +> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) +>
+> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore +>
+
+
+## Features
+
+OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
+
+* **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
+
+* **Minimal Design 🚀**: Injects control signals while preserving the original model structure, adding only 0.1% additional parameters to the base model.
+
+## News
+- **2024-12-26**: ⭐️ Training code is released. You can now create your own OminiControl model by customizing any control task (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details.
+
+## Quick Start
+### Setup (Optional)
+1. **Environment setup**
+```bash
+conda create -n omini python=3.10
+conda activate omini
+```
+2. **Requirements installation**
+```bash
+pip install -r requirements.txt
+```
+### Usage example
+1. Subject-driven generation: `examples/subject.ipynb` (see the sketch below)
+2. In-painting: `examples/inpainting.ipynb`
+3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
+
+### Gradio app
+To run the Gradio app for subject-driven generation:
+```bash
+python -m src.gradio.gradio_app
+```
+
+### Guidelines for subject-driven generation
+1. Input images are automatically center-cropped and resized to 512x512 resolution.
+2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`, e.g.
+   1. *A close up view of this item. It is placed on a wooden table.*
+   2. *A young lady is wearing this shirt.*
+3. The model currently works primarily with objects rather than human subjects, due to the absence of human data in training.
+
+## Generated samples
+### Subject-driven generation
+HuggingFace
+
+**Demos** (Left: condition image; Right: generated image)
+
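The demos below boil down to a handful of calls. A minimal sketch distilled from `examples/subject.ipynb` added in this change, assuming a CUDA GPU, access to FLUX.1-schnell, and execution from the repository root; the output filename is illustrative:

```python
import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Base model plus the subject-driven LoRA weights, as in examples/subject.ipynb.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# 512x512 condition image; the prompt refers to the subject as "this item".
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
condition = Condition("subject", image, position_delta=(0, 32))

seed_everything(0)
result = generate(
    pipe,
    prompt=(
        "On Christmas evening, on a crowded sidewalk, this item sits on the road, "
        "covered in snow and wearing a Christmas hat."
    ),
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]
result.save("subject_result.jpg")  # output path is illustrative
```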
+*(demo images: see `assets/demo/`)*
+
+Text Prompts
+
+- Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
+- Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
+- Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
+- Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."*
+
+
+More results + +* Try on: + +* Scene variations: + +* Dreambooth dataset: + +* Oye-cartoon finetune: +
+ + +
+
+ +### Spatially aligned control +1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image) + - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.* +
+ + - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.* +
+ +2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring) +
+
+ Click to show
+
+ *(result images: see `assets/demo/`)*
+
+ + Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.* +
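The in-painting examples above follow the workflow in `examples/inpainting.ipynb` added in this change. A condensed sketch, assuming a CUDA GPU and access to FLUX.1-dev; the mask rectangle is the one used in the notebook and the output filename is illustrative:

```python
import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# FLUX.1-dev base with the experimental "fill" adapter, as in examples/inpainting.ipynb.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="experimental/fill.safetensors",
    adapter_name="fill",
)

# Black out the region to regenerate; the masked image itself is the "fill" condition.
image = Image.open("assets/monalisa.jpg").convert("RGB").resize((512, 512))
masked_image = image.copy()
masked_image.paste((0, 0, 0), (128, 100, 384, 220))
condition = Condition("fill", masked_image)

seed_everything()
result = generate(
    pipe,
    prompt="The Mona Lisa is wearing a white VR headset with 'Omini' written on it.",
    conditions=[condition],
).images[0]
result.save("fill_result.jpg")  # output path is illustrative
```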
+
+
+
+## Models
+
+**Subject-driven control:**
+| Model | Base model | Description | Resolution |
+| ----- | ---------- | ----------- | ---------- |
+| [`experimental`](https://huggingface.co./Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
+| [`omini`](https://huggingface.co./Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
+| [`omini`](https://huggingface.co./Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) |
+| [`oye-cartoon`](https://huggingface.co./saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on the [oye-cartoon](https://huggingface.co./datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) |
+
+**Spatially aligned control:**
+| Model | Base model | Description | Resolution |
+| ----- | ---------- | ----------- | ---------- |
+| [`experimental`](https://huggingface.co./Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
+| [`experimental`](https://huggingface.co./Yuanshi/OminiControl/tree/main/experimental) / `<task_name>_1024` | FLUX.1 | Supports higher resolution. (To be released) | (1024, 1024) |
+
+## Community Extensions
+- [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
+- [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
+
+## Limitations
+1. The model's subject-driven generation primarily works with objects rather than human subjects, due to the absence of human data in training.
+2. The subject-driven generation model may not work well with `FLUX.1-dev`.
+3. The released models currently support only 512x512 resolution.
+
+## Training
+Training instructions can be found in this [folder](./train).
+
+
+## To-do
+- [x] Release the training code.
+- [ ] Release the model for higher resolution (1024x1024).
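For reference, the spatially aligned adapters in the Models table above are loaded by weight name, one LoRA per task, as done in `examples/spatial.ipynb` from this change. A condensed sketch, assuming a CUDA GPU and access to FLUX.1-dev; the output filename is illustrative:

```python
import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# One LoRA per task; weight names follow the `experimental` files on the Hub.
for condition_type in ["canny", "depth", "coloring", "deblurring"]:
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name=f"experimental/{condition_type}.safetensors",
        adapter_name=condition_type,
    )

# Center-crop to a square and resize to 512x512, as in the notebook.
image = Image.open("assets/coffee.png").convert("RGB")
w, h, min_dim = image.size + (min(image.size),)
image = image.crop(
    ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
).resize((512, 512))

prompt = (
    "In a bright room. A cup of a coffee with some beans on the side. "
    "They are placed on a dark wooden table."
)

# Condition(...) derives the control signal from the raw image (e.g. Canny edges).
condition = Condition("canny", image)
seed_everything()
result = generate(pipe, prompt=prompt, conditions=[condition]).images[0]
result.save("canny_result.jpg")  # output path is illustrative
```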
+ +## Citation +``` +@article{tan2024ominicontrol, + title={Ominicontrol: Minimal and universal control for diffusion transformer}, + author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao}, + journal={arXiv preprint arXiv:2411.15098}, + volume={3}, + year={2024} +} +``` diff --git a/OminiControl/assets/book.jpg b/OminiControl/assets/book.jpg new file mode 100644 index 0000000000000000000000000000000000000000..da7069c305f50ecd042ba889ac1624d02206e98f Binary files /dev/null and b/OminiControl/assets/book.jpg differ diff --git a/OminiControl/assets/cartoon_boy.png b/OminiControl/assets/cartoon_boy.png new file mode 100644 index 0000000000000000000000000000000000000000..2ad7e896bf6f8f645752c2bb9545326536d4b73b --- /dev/null +++ b/OminiControl/assets/cartoon_boy.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4a82c0f9ed09b9468bded7d901beffaf29addc30ed5f72ad72451e1b6344b1c +size 428900 diff --git a/OminiControl/assets/clock.jpg b/OminiControl/assets/clock.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eeee9c0c7ee5281344c4010dcfb09cd40bf18e91 --- /dev/null +++ b/OminiControl/assets/clock.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41235973f26152ac92d32bfc166fb5f9f1e352c5e16807920238473316ec462b +size 288877 diff --git a/OminiControl/assets/coffee.png b/OminiControl/assets/coffee.png new file mode 100644 index 0000000000000000000000000000000000000000..36ff439ef3a00b597bb587237e79a5812514dc81 Binary files /dev/null and b/OminiControl/assets/coffee.png differ diff --git a/OminiControl/assets/demo/book_omini.jpg b/OminiControl/assets/demo/book_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9cadef5e92d5315ee66f07322c3608be0743868c Binary files /dev/null and b/OminiControl/assets/demo/book_omini.jpg differ diff --git a/OminiControl/assets/demo/clock_omini.jpg b/OminiControl/assets/demo/clock_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df0341e836eb39ea21c4ac38d143228efc6909f6 Binary files /dev/null and b/OminiControl/assets/demo/clock_omini.jpg differ diff --git a/OminiControl/assets/demo/demo_this_is_omini_control.jpg b/OminiControl/assets/demo/demo_this_is_omini_control.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5fa77c908af2db85f98b1d238914b1b92ad85b9e --- /dev/null +++ b/OminiControl/assets/demo/demo_this_is_omini_control.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:798b7c25be6be118dc0de97c444c840869afca633a0d48f99d940aec040a7518 +size 129406 diff --git a/OminiControl/assets/demo/dreambooth_res.jpg b/OminiControl/assets/demo/dreambooth_res.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4f369c03887576cb63229c2b535fd15f7e189023 --- /dev/null +++ b/OminiControl/assets/demo/dreambooth_res.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba36bd861989564dc679acf3b5e56f382f1a11b1596e6f611ea0bd7d81b89680 +size 1935753 diff --git a/OminiControl/assets/demo/man_omini.jpg b/OminiControl/assets/demo/man_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf8f54498d02e0b38914a80617ca420f81231e6c Binary files /dev/null and b/OminiControl/assets/demo/man_omini.jpg differ diff --git a/OminiControl/assets/demo/monalisa_omini.jpg b/OminiControl/assets/demo/monalisa_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de43c96b35ceafea510add2a9815c56f444a91b1 --- 
/dev/null +++ b/OminiControl/assets/demo/monalisa_omini.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5ca6c2bf44f19d216b2eb16dcc67d19f11d87220d3ee80f5e5e1ad98a5536dc +size 132717 diff --git a/OminiControl/assets/demo/oranges_omini.jpg b/OminiControl/assets/demo/oranges_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..30554406b373eb1547d6922e8c18eef4e8efd292 Binary files /dev/null and b/OminiControl/assets/demo/oranges_omini.jpg differ diff --git a/OminiControl/assets/demo/panda_omini.jpg b/OminiControl/assets/demo/panda_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e69743402e70ee47e124344f5b50fd8a2f8cd41d Binary files /dev/null and b/OminiControl/assets/demo/panda_omini.jpg differ diff --git a/OminiControl/assets/demo/penguin_omini.jpg b/OminiControl/assets/demo/penguin_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..15726e226d88e0216f7d8605baf2742b807782a5 Binary files /dev/null and b/OminiControl/assets/demo/penguin_omini.jpg differ diff --git a/OminiControl/assets/demo/rc_car_omini.jpg b/OminiControl/assets/demo/rc_car_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2cd6395e8aeb241c452b92f2655e2975b2b8a5e6 Binary files /dev/null and b/OminiControl/assets/demo/rc_car_omini.jpg differ diff --git a/OminiControl/assets/demo/room_corner_canny.jpg b/OminiControl/assets/demo/room_corner_canny.jpg new file mode 100644 index 0000000000000000000000000000000000000000..295943cec832f5948283377ad68ed486e00bdac5 Binary files /dev/null and b/OminiControl/assets/demo/room_corner_canny.jpg differ diff --git a/OminiControl/assets/demo/room_corner_coloring.jpg b/OminiControl/assets/demo/room_corner_coloring.jpg new file mode 100644 index 0000000000000000000000000000000000000000..087c5fb8912378850b3c1d745f591f7809f08c85 Binary files /dev/null and b/OminiControl/assets/demo/room_corner_coloring.jpg differ diff --git a/OminiControl/assets/demo/room_corner_deblurring.jpg b/OminiControl/assets/demo/room_corner_deblurring.jpg new file mode 100644 index 0000000000000000000000000000000000000000..29ef78bb776f4a1fcc0d7a77435a32d40ec2fef6 Binary files /dev/null and b/OminiControl/assets/demo/room_corner_deblurring.jpg differ diff --git a/OminiControl/assets/demo/room_corner_depth.jpg b/OminiControl/assets/demo/room_corner_depth.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d385433fc25795ad8c1a10ebaf518a3cfd1bbab0 Binary files /dev/null and b/OminiControl/assets/demo/room_corner_depth.jpg differ diff --git a/OminiControl/assets/demo/scene_variation.jpg b/OminiControl/assets/demo/scene_variation.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b84d8686cba903b5178a5b555a6fdb389aa89b1 --- /dev/null +++ b/OminiControl/assets/demo/scene_variation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39e4e16d2eeb58b3775b6d34c8b3e125d0d19cc36fa90b07c6c8d57624ad4333 +size 958333 diff --git a/OminiControl/assets/demo/shirt_omini.jpg b/OminiControl/assets/demo/shirt_omini.jpg new file mode 100644 index 0000000000000000000000000000000000000000..221880750404583c0db0c308ae99b82d7bcc5127 Binary files /dev/null and b/OminiControl/assets/demo/shirt_omini.jpg differ diff --git a/OminiControl/assets/demo/try_on.jpg b/OminiControl/assets/demo/try_on.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f34aec443f2932ebe5d59fc321b34f6b61232bbc --- /dev/null +++ 
b/OminiControl/assets/demo/try_on.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6adce5194329a83f0109b4375e00667c341879e64fb55831c70ea3f3b2f99f7e +size 774269 diff --git a/OminiControl/assets/monalisa.jpg b/OminiControl/assets/monalisa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dbfd24ae1941ae2f44179aab6d17e1e79e978ea9 --- /dev/null +++ b/OminiControl/assets/monalisa.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:188b8b6499e4541f9dfef2a9daf6f1eb920079c9208f587fd97566d6aa4a9719 +size 352759 diff --git a/OminiControl/assets/oranges.jpg b/OminiControl/assets/oranges.jpg new file mode 100644 index 0000000000000000000000000000000000000000..19e38651ebac7633d4b67efd1bf1f0a111e4bd6a Binary files /dev/null and b/OminiControl/assets/oranges.jpg differ diff --git a/OminiControl/assets/penguin.jpg b/OminiControl/assets/penguin.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9b98535482e4bfe6e1febb64d595ef167a404d63 Binary files /dev/null and b/OminiControl/assets/penguin.jpg differ diff --git a/OminiControl/assets/rc_car.jpg b/OminiControl/assets/rc_car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0f31728c3f1164180ca58a84c6b90a715e349a64 --- /dev/null +++ b/OminiControl/assets/rc_car.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae8aed11029fa3b084deb286c07a8cab5056840c9c123816fe2b504e94233e95 +size 254155 diff --git a/OminiControl/assets/room_corner.jpg b/OminiControl/assets/room_corner.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e16dea6be5f92137c5d82e8469a7ae475e25f8f2 --- /dev/null +++ b/OminiControl/assets/room_corner.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f97bd63df05f5f15ad5dd1a2ccef803e74e12caadd8fe145493fd6d5219045e7 +size 236101 diff --git a/OminiControl/assets/test_in.jpg b/OminiControl/assets/test_in.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4ecce80b5cbc8fc6754279731207f6c7bcc1b3ec Binary files /dev/null and b/OminiControl/assets/test_in.jpg differ diff --git a/OminiControl/assets/test_out.jpg b/OminiControl/assets/test_out.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9b98535482e4bfe6e1febb64d595ef167a404d63 Binary files /dev/null and b/OminiControl/assets/test_out.jpg differ diff --git a/OminiControl/assets/tshirt.jpg b/OminiControl/assets/tshirt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76c495a4da1f0daeffe660cbb4dba602e5c83233 --- /dev/null +++ b/OminiControl/assets/tshirt.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb1803315765302113a9e7a64dedd4ecba2672028cf093cbc33ef2edd2247c39 +size 301252 diff --git a/OminiControl/assets/vase.jpg b/OminiControl/assets/vase.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7acb40eca66987b29cafb3a8d4c0a9207dcc5bf8 Binary files /dev/null and b/OminiControl/assets/vase.jpg differ diff --git a/OminiControl/assets/vase_hq.jpg b/OminiControl/assets/vase_hq.jpg new file mode 100644 index 0000000000000000000000000000000000000000..378d2487450ec04b1981556bf6dadc344774eaad --- /dev/null +++ b/OminiControl/assets/vase_hq.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:279905e32116792f118802d23b0d96629d98ccbdac9e704e65eaf2e98c752679 +size 2901712 diff --git a/OminiControl/examples/inpainting.ipynb b/OminiControl/examples/inpainting.ipynb new file mode 100644 index 
0000000000000000000000000000000000000000..2a48abbf702bb2bf272e93810a220130dd529c13 --- /dev/null +++ b/OminiControl/examples/inpainting.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from src.flux.condition import Condition\n", + "from PIL import Image\n", + "\n", + "from src.flux.generate import generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"experimental/fill.safetensors\",\n", + " adapter_name=\"fill\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "masked_image = image.copy()\n", + "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n", + "\n", + "condition = Condition(\"fill\", masked_image)\n", + "\n", + "seed_everything()\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "w, h, min_dim = image.size + (min(image.size),)\n", + "image = image.crop(\n", + " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n", + ").resize((512, 512))\n", + "\n", + "\n", + "masked_image = image.copy()\n", + "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n", + "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n", + "\n", + "condition = Condition(\"fill\", masked_image)\n", + "\n", + "seed_everything()\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. 
The text 'for FLUX' appears at the bottom.\",\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/OminiControl/examples/spatial.ipynb b/OminiControl/examples/spatial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e9a918717884500f8176e731e59af2429ae2fdc0 --- /dev/null +++ b/OminiControl/examples/spatial.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from src.flux.condition import Condition\n", + "from PIL import Image\n", + "\n", + "from src.flux.generate import generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n", + " pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"experimental/{condition_type}.safetensors\",\n", + " adapter_name=condition_type,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n", + "\n", + "w, h, min_dim = image.size + (min(image.size),)\n", + "image = image.crop(\n", + " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n", + ").resize((512, 512))\n", + "\n", + "prompt = \"In a bright room. A cup of a coffee with some beans on the side. 
They are placed on a dark wooden table.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "condition = Condition(\"canny\", image)\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "condition = Condition(\"depth\", image)\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "condition = Condition(\"deblurring\", image)\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "condition = Condition(\"coloring\", image)\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/OminiControl/examples/subject.ipynb b/OminiControl/examples/subject.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..249d0773ebae95a02d38100ca767bb64115fe5dc --- /dev/null +++ b/OminiControl/examples/subject.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from src.flux.condition import Condition\n", + "from PIL import Image\n", + "\n", + "from src.flux.generate import generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = 
FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")\n", + "pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"omini/subject_512.safetensors\",\n", + " adapter_name=\"subject\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", + "\n", + "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n", + "\n", + "\n", + "seed_everything(0)\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", + "\n", + "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n", + "\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", + "\n", + "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", + "\n", + "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. 
In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", + "\n", + "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/OminiControl/examples/subject_1024.ipynb b/OminiControl/examples/subject_1024.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d0f9de9e7d86dc9aa455b0b162c93ca764f346b3 --- /dev/null +++ b/OminiControl/examples/subject_1024.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from src.flux.condition import Condition\n", + "from PIL import Image\n", + "\n", + "from src.flux.generate import generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")\n", + "pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"omini/subject_1024_beta.safetensors\",\n", + " adapter_name=\"subject\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image)\n", + "\n", + "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n", + "\n", + "\n", + "seed_everything(0)\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " 
num_inference_steps=8,\n", + " height=1024,\n", + " width=1024,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image)\n", + "\n", + "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n", + "\n", + "\n", + "seed_everything(0)\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=1024,\n", + " width=1024,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image)\n", + "\n", + "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=1024,\n", + " width=1024,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image)\n", + "\n", + "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n", + "\n", + "seed_everything(0)\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=1024,\n", + " width=1024,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(\"subject\", image)\n", + "\n", + "prompt = \"A very close up view of this item. It is placed on a wooden table. 
The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=1024,\n", + " width=1024,\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/OminiControl/requirements.txt b/OminiControl/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d0bc83f27098b2cf17e27de67797092ff6bbee1 --- /dev/null +++ b/OminiControl/requirements.txt @@ -0,0 +1,9 @@ +transformers +diffusers +peft +opencv-python +protobuf +sentencepiece +gradio +jupyter +torchao \ No newline at end of file diff --git a/OminiControl/src/flux/block.py b/OminiControl/src/flux/block.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9123abca755a4d66ac537260aaf24df2fd04e6 --- /dev/null +++ b/OminiControl/src/flux/block.py @@ -0,0 +1,339 @@ +import torch +from typing import List, Union, Optional, Dict, Any, Callable +from diffusers.models.attention_processor import Attention, F +from .lora_controller import enable_lora + + +def attn_forward( + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + condition_latents: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + cond_rotary_emb: Optional[torch.Tensor] = None, + model_config: Optional[Dict[str, Any]] = {}, +) -> torch.FloatTensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + with enable_lora( + (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False) + ): + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if condition_latents is not None: + cond_query = attn.to_q(condition_latents) + cond_key = attn.to_k(condition_latents) + cond_value = attn.to_v(condition_latents) + + cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + if attn.norm_q is not None: + cond_query = attn.norm_q(cond_query) + if attn.norm_k is not None: + cond_key = attn.norm_k(cond_key) + + if cond_rotary_emb is not None: + cond_query = apply_rotary_emb(cond_query, cond_rotary_emb) + cond_key = apply_rotary_emb(cond_key, cond_rotary_emb) + + if condition_latents is not None: + query = torch.cat([query, cond_query], dim=2) + key = torch.cat([key, cond_key], dim=2) + value = torch.cat([value, cond_value], dim=2) + + if not model_config.get("union_cond_attn", True): + # If we don't want to use the union condition attention, we need to mask the attention + # between the hidden states and the condition latents + attention_mask = torch.ones( + query.shape[2], key.shape[2], device=query.device, dtype=torch.bool + ) + condition_n = cond_query.shape[2] + attention_mask[-condition_n:, :-condition_n] = False + attention_mask[:-condition_n, -condition_n:] = False + elif model_config.get("independent_condition", False): + attention_mask = torch.ones( + query.shape[2], key.shape[2], device=query.device, dtype=torch.bool + ) + condition_n = cond_query.shape[2] + attention_mask[-condition_n:, :-condition_n] = False + if hasattr(attn, "c_factor"): + attention_mask = torch.zeros( + query.shape[2], key.shape[2], device=query.device, dtype=query.dtype + ) + condition_n = cond_query.shape[2] + bias = torch.log(attn.c_factor[0]) + attention_mask[-condition_n:, :-condition_n] = bias + attention_mask[:-condition_n, -condition_n:] = bias + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + if 
condition_latents is not None: + encoder_hidden_states, hidden_states, condition_latents = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[ + :, encoder_hidden_states.shape[1] : -condition_latents.shape[1] + ], + hidden_states[:, -condition_latents.shape[1] :], + ) + else: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)): + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if condition_latents is not None: + condition_latents = attn.to_out[0](condition_latents) + condition_latents = attn.to_out[1](condition_latents) + + return ( + (hidden_states, encoder_hidden_states, condition_latents) + if condition_latents is not None + else (hidden_states, encoder_hidden_states) + ) + elif condition_latents is not None: + # if there are condition_latents, we need to separate the hidden_states and the condition_latents + hidden_states, condition_latents = ( + hidden_states[:, : -condition_latents.shape[1]], + hidden_states[:, -condition_latents.shape[1] :], + ) + return hidden_states, condition_latents + else: + return hidden_states + + +def block_forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + condition_latents: torch.FloatTensor, + temb: torch.FloatTensor, + cond_temb: torch.FloatTensor, + cond_rotary_emb=None, + image_rotary_emb=None, + model_config: Optional[Dict[str, Any]] = {}, +): + use_cond = condition_latents is not None + with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) + ) + + if use_cond: + ( + norm_condition_latents, + cond_gate_msa, + cond_shift_mlp, + cond_scale_mlp, + cond_gate_mlp, + ) = self.norm1(condition_latents, emb=cond_temb) + + # Attention. + result = attn_forward( + self.attn, + model_config=model_config, + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + condition_latents=norm_condition_latents if use_cond else None, + image_rotary_emb=image_rotary_emb, + cond_rotary_emb=cond_rotary_emb if use_cond else None, + ) + attn_output, context_attn_output = result[:2] + cond_attn_output = result[2] if use_cond else None + + # Process attention outputs for the `hidden_states`. + # 1. hidden_states + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + # 2. encoder_hidden_states + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + # 3. condition_latents + if use_cond: + cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output + condition_latents = condition_latents + cond_attn_output + if model_config.get("add_cond_attn", False): + hidden_states += cond_attn_output + + # LayerNorm + MLP. + # 1. hidden_states + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + # 2. 
encoder_hidden_states + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + # 3. condition_latents + if use_cond: + norm_condition_latents = self.norm2(condition_latents) + norm_condition_latents = ( + norm_condition_latents * (1 + cond_scale_mlp[:, None]) + + cond_shift_mlp[:, None] + ) + + # Feed-forward. + with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)): + # 1. hidden_states + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + # 2. encoder_hidden_states + context_ff_output = self.ff_context(norm_encoder_hidden_states) + context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output + # 3. condition_latents + if use_cond: + cond_ff_output = self.ff(norm_condition_latents) + cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output + + # Process feed-forward outputs. + hidden_states = hidden_states + ff_output + encoder_hidden_states = encoder_hidden_states + context_ff_output + if use_cond: + condition_latents = condition_latents + cond_ff_output + + # Clip to avoid overflow. + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states, condition_latents if use_cond else None + + +def single_block_forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + condition_latents: torch.FloatTensor = None, + cond_temb: torch.FloatTensor = None, + cond_rotary_emb=None, + model_config: Optional[Dict[str, Any]] = {}, +): + + using_cond = condition_latents is not None + residual = hidden_states + with enable_lora( + ( + self.norm.linear, + self.proj_mlp, + ), + model_config.get("latent_lora", False), + ): + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + if using_cond: + residual_cond = condition_latents + norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb) + mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents)) + + attn_output = attn_forward( + self.attn, + model_config=model_config, + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **( + { + "condition_latents": norm_condition_latents, + "cond_rotary_emb": cond_rotary_emb if using_cond else None, + } + if using_cond + else {} + ), + ) + if using_cond: + attn_output, cond_attn_output = attn_output + + with enable_lora((self.proj_out,), model_config.get("latent_lora", False)): + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if using_cond: + condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) + cond_gate = cond_gate.unsqueeze(1) + condition_latents = cond_gate * self.proj_out(condition_latents) + condition_latents = residual_cond + condition_latents + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states if not using_cond else (hidden_states, condition_latents) diff --git a/OminiControl/src/flux/condition.py b/OminiControl/src/flux/condition.py new file mode 100644 index 0000000000000000000000000000000000000000..b2dcd36f817f1285f8e9060f4ea9583742e04e8e --- /dev/null +++ 
b/OminiControl/src/flux/condition.py @@ -0,0 +1,138 @@ +import torch +from typing import Optional, Union, List, Tuple +from diffusers.pipelines import FluxPipeline +from PIL import Image, ImageFilter +import numpy as np +import cv2 + +from .pipeline_tools import encode_images + +condition_dict = { + "depth": 0, + "canny": 1, + "subject": 4, + "coloring": 6, + "deblurring": 7, + "depth_pred": 8, + "fill": 9, + "sr": 10, + "cartoon": 11, +} + + +class Condition(object): + def __init__( + self, + condition_type: str, + raw_img: Union[Image.Image, torch.Tensor] = None, + condition: Union[Image.Image, torch.Tensor] = None, + mask=None, + position_delta=None, + position_scale=1.0, + ) -> None: + self.condition_type = condition_type + assert raw_img is not None or condition is not None + if raw_img is not None: + self.condition = self.get_condition(condition_type, raw_img) + else: + self.condition = condition + self.position_delta = position_delta + self.position_scale = position_scale + # TODO: Add mask support + assert mask is None, "Mask not supported yet" + + def get_condition( + self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] + ) -> Union[Image.Image, torch.Tensor]: + """ + Returns the condition image. + """ + if condition_type == "depth": + from transformers import pipeline + + depth_pipe = pipeline( + task="depth-estimation", + model="LiheYoung/depth-anything-small-hf", + device="cuda", + ) + source_image = raw_img.convert("RGB") + condition_img = depth_pipe(source_image)["depth"].convert("RGB") + return condition_img + elif condition_type == "canny": + img = np.array(raw_img) + edges = cv2.Canny(img, 100, 200) + edges = Image.fromarray(edges).convert("RGB") + return edges + elif condition_type == "subject": + return raw_img + elif condition_type == "coloring": + return raw_img.convert("L").convert("RGB") + elif condition_type == "deblurring": + condition_image = ( + raw_img.convert("RGB") + .filter(ImageFilter.GaussianBlur(10)) + .convert("RGB") + ) + return condition_image + elif condition_type == "fill": + return raw_img.convert("RGB") + elif condition_type == "cartoon": + return raw_img.convert("RGB") + return self.condition + + @property + def type_id(self) -> int: + """ + Returns the type id of the condition. + """ + return condition_dict[self.condition_type] + + @classmethod + def get_type_id(cls, condition_type: str) -> int: + """ + Returns the type id of the condition. + """ + return condition_dict[condition_type] + + def encode( + self, pipe: FluxPipeline, empty: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Encodes the condition into tokens, ids and type_id. 
+ """ + if self.condition_type in [ + "depth", + "canny", + "subject", + "coloring", + "deblurring", + "depth_pred", + "fill", + "sr", + "cartoon", + ]: + if empty: + # make the condition black + e_condition = Image.new("RGB", self.condition.size, (0, 0, 0)) + e_condition = e_condition.convert("RGB") + tokens, ids = encode_images(pipe, e_condition) + else: + tokens, ids = encode_images(pipe, self.condition) + tokens, ids = encode_images(pipe, self.condition) + else: + raise NotImplementedError( + f"Condition type {self.condition_type} not implemented" + ) + if self.position_delta is None and self.condition_type == "subject": + self.position_delta = [0, -self.condition.size[0] // 16] + if self.position_delta is not None: + ids[:, 1] += self.position_delta[0] + ids[:, 2] += self.position_delta[1] + if self.position_scale != 1.0: + scale_bias = (self.position_scale - 1.0) / 2 + ids[:, 1] *= self.position_scale + ids[:, 2] *= self.position_scale + ids[:, 1] += scale_bias + ids[:, 2] += scale_bias + type_id = torch.ones_like(ids[:, :1]) * self.type_id + return tokens, ids, type_id diff --git a/OminiControl/src/flux/generate.py b/OminiControl/src/flux/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..a380633d740a5dbad731d0e154f26f16e4b5bc13 --- /dev/null +++ b/OminiControl/src/flux/generate.py @@ -0,0 +1,321 @@ +import torch +import yaml, os +from diffusers.pipelines import FluxPipeline +from typing import List, Union, Optional, Dict, Any, Callable +from .transformer import tranformer_forward +from .condition import Condition + +from diffusers.pipelines.flux.pipeline_flux import ( + FluxPipelineOutput, + calculate_shift, + retrieve_timesteps, + np, +) + + +def get_config(config_path: str = None): + config_path = config_path or os.environ.get("XFL_CONFIG") + if not config_path: + return {} + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + + +def prepare_params( + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + **kwargs: dict, +): + return ( + prompt, + prompt_2, + height, + width, + num_inference_steps, + timesteps, + guidance_scale, + num_images_per_prompt, + generator, + latents, + prompt_embeds, + pooled_prompt_embeds, + output_type, + return_dict, + joint_attention_kwargs, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) + + +def seed_everything(seed: int = 42): + torch.backends.cudnn.deterministic = True + torch.manual_seed(seed) + np.random.seed(seed) + + +@torch.no_grad() +def generate( + pipeline: FluxPipeline, + conditions: List[Condition] = None, + config_path: str = None, + model_config: Optional[Dict[str, Any]] = {}, + condition_scale: float = 1.0, + default_lora: bool = False, + image_guidance_scale: float = 1.0, + 
**params: dict, +): + model_config = model_config or get_config(config_path).get("model", {}) + if condition_scale != 1: + for name, module in pipeline.transformer.named_modules(): + if not name.endswith(".attn"): + continue + module.c_factor = torch.ones(1, 1) * condition_scale + + self = pipeline + ( + prompt, + prompt_2, + height, + width, + num_inference_steps, + timesteps, + guidance_scale, + num_images_per_prompt, + generator, + latents, + prompt_embeds, + pooled_prompt_embeds, + output_type, + return_dict, + joint_attention_kwargs, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) = prepare_params(**params) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) + if self.joint_attention_kwargs is not None + else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 4.1. Prepare conditions + condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3)) + use_condition = conditions is not None or [] + if use_condition: + assert len(conditions) <= 1, "Only one condition is supported for now." + if not default_lora: + pipeline.set_adapters(conditions[0].condition_type) + for condition in conditions: + tokens, ids, type_id = condition.encode(self) + condition_latents.append(tokens) # [batch_size, token_n, token_dim] + condition_ids.append(ids) # [token_n, id_dim(3)] + condition_type_ids.append(type_id) # [token_n, 1] + condition_latents = torch.cat(condition_latents, dim=1) + condition_ids = torch.cat(condition_ids, dim=0) + condition_type_ids = torch.cat(condition_type_ids, dim=0) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + noise_pred = tranformer_forward( + self.transformer, + model_config=model_config, + # Inputs of the condition (new feature) + condition_latents=condition_latents if use_condition else None, + condition_ids=condition_ids if use_condition else None, + condition_type_ids=condition_type_ids if use_condition else None, + # Inputs to the original transformer + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if image_guidance_scale != 1.0: + uncondition_latents = condition.encode(self, empty=True)[0] + unc_pred = tranformer_forward( + self.transformer, + model_config=model_config, + # Inputs of the condition (new feature) + condition_latents=uncondition_latents if use_condition else None, + condition_ids=condition_ids if use_condition else None, + condition_type_ids=condition_type_ids if use_condition else None, + # Inputs to the original transformer + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=torch.ones_like(guidance), + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = ( + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if condition_scale != 1: + for name, module in pipeline.transformer.named_modules(): + if not name.endswith(".attn"): + continue + del module.c_factor + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/OminiControl/src/flux/lora_controller.py b/OminiControl/src/flux/lora_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..21b23eae2bdeaab171df4616e174dd6d96351620 --- /dev/null +++ b/OminiControl/src/flux/lora_controller.py @@ -0,0 +1,75 @@ +from peft.tuners.tuners_utils import BaseTunerLayer +from typing import List, Any, Optional, Type + + +class enable_lora: + def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None: + self.activated: bool = activated + if activated: + return + self.lora_modules: List[BaseTunerLayer] = [ + each for each in lora_modules if isinstance(each, BaseTunerLayer) + ] + self.scales = [ + { + active_adapter: lora_module.scaling[active_adapter] + for active_adapter in lora_module.active_adapters + } + for lora_module in self.lora_modules + ] + + def __enter__(self) -> None: + if self.activated: + return + + for lora_module in self.lora_modules: + if not isinstance(lora_module, BaseTunerLayer): + continue + lora_module.scale_layer(0) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self.activated: + return + for i, lora_module in enumerate(self.lora_modules): + if not isinstance(lora_module, BaseTunerLayer): + continue + for active_adapter in lora_module.active_adapters: + lora_module.scaling[active_adapter] = self.scales[i][active_adapter] + + +class set_lora_scale: + def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None: + self.lora_modules: List[BaseTunerLayer] = [ + each for each in lora_modules if isinstance(each, BaseTunerLayer) + ] + self.scales = [ + { + active_adapter: lora_module.scaling[active_adapter] + for active_adapter in lora_module.active_adapters + } + for lora_module in self.lora_modules + ] + self.scale = scale + + def __enter__(self) -> None: + for lora_module in self.lora_modules: + if not isinstance(lora_module, BaseTunerLayer): + continue + lora_module.scale_layer(self.scale) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + for i, lora_module in enumerate(self.lora_modules): + if not 
isinstance(lora_module, BaseTunerLayer): + continue + for active_adapter in lora_module.active_adapters: + lora_module.scaling[active_adapter] = self.scales[i][active_adapter] diff --git a/OminiControl/src/flux/pipeline_tools.py b/OminiControl/src/flux/pipeline_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..36174e4fe27dd12dbe629e3b92ca160b3600d889 --- /dev/null +++ b/OminiControl/src/flux/pipeline_tools.py @@ -0,0 +1,52 @@ +from diffusers.pipelines import FluxPipeline +from diffusers.utils import logging +from diffusers.pipelines.flux.pipeline_flux import logger +from torch import Tensor + + +def encode_images(pipeline: FluxPipeline, images: Tensor): + images = pipeline.image_processor.preprocess(images) + images = images.to(pipeline.device).to(pipeline.dtype) + images = pipeline.vae.encode(images).latent_dist.sample() + images = ( + images - pipeline.vae.config.shift_factor + ) * pipeline.vae.config.scaling_factor + images_tokens = pipeline._pack_latents(images, *images.shape) + images_ids = pipeline._prepare_latent_image_ids( + images.shape[0], + images.shape[2], + images.shape[3], + pipeline.device, + pipeline.dtype, + ) + if images_tokens.shape[1] != images_ids.shape[0]: + images_ids = pipeline._prepare_latent_image_ids( + images.shape[0], + images.shape[2] // 2, + images.shape[3] // 2, + pipeline.device, + pipeline.dtype, + ) + return images_tokens, images_ids + + +def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512): + # Turn off warnings (CLIP overflow) + logger.setLevel(logging.ERROR) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = pipeline.encode_prompt( + prompt=prompts, + prompt_2=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=pipeline.device, + num_images_per_prompt=1, + max_sequence_length=max_sequence_length, + lora_scale=None, + ) + # Turn on warnings + logger.setLevel(logging.WARNING) + return prompt_embeds, pooled_prompt_embeds, text_ids diff --git a/OminiControl/src/flux/transformer.py b/OminiControl/src/flux/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1a31e14bf9a72f1b42775419ad251ee5bb65e764 --- /dev/null +++ b/OminiControl/src/flux/transformer.py @@ -0,0 +1,252 @@ +import torch +from diffusers.pipelines import FluxPipeline +from typing import List, Union, Optional, Dict, Any, Callable +from .block import block_forward, single_block_forward +from .lora_controller import enable_lora +from accelerate.utils import is_torch_version +from diffusers.models.transformers.transformer_flux import ( + FluxTransformer2DModel, + Transformer2DModelOutput, + USE_PEFT_BACKEND, + scale_lora_layers, + unscale_lora_layers, + logger, +) +import numpy as np + + +def prepare_params( + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + **kwargs: dict, +): + return ( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + img_ids, + txt_ids, + guidance, + joint_attention_kwargs, + controlnet_block_samples, + controlnet_single_block_samples, + return_dict, + ) + + +def tranformer_forward( + transformer: FluxTransformer2DModel, + condition_latents: torch.Tensor, + condition_ids: 
torch.Tensor, + condition_type_ids: torch.Tensor, + model_config: Optional[Dict[str, Any]] = {}, + c_t=0, + **params: dict, +): + self = transformer + use_condition = condition_latents is not None + + ( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + img_ids, + txt_ids, + guidance, + joint_attention_kwargs, + controlnet_block_samples, + controlnet_single_block_samples, + return_dict, + ) = prepare_params(**params) + + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): + hidden_states = self.x_embedder(hidden_states) + condition_latents = self.x_embedder(condition_latents) if use_condition else None + + timestep = timestep.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + + cond_temb = ( + self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) + if guidance is None + else self.time_text_embed( + torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections + ) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + if use_condition: + # condition_ids[:, :1] = condition_type_ids + cond_rotary_emb = self.pos_embed(condition_ids) + + # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + encoder_hidden_states, hidden_states, condition_latents = ( + torch.utils.checkpoint.checkpoint( + block_forward, + self=block, + model_config=model_config, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + condition_latents=condition_latents if use_condition else None, + temb=temb, + cond_temb=cond_temb if use_condition else None, + cond_rotary_emb=cond_rotary_emb if use_condition else None, + image_rotary_emb=image_rotary_emb, + **ckpt_kwargs, + ) + ) + + else: + encoder_hidden_states, hidden_states, condition_latents = block_forward( + block, + model_config=model_config, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + condition_latents=condition_latents if use_condition else None, + temb=temb, + cond_temb=cond_temb if use_condition else None, + cond_rotary_emb=cond_rotary_emb if use_condition else None, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len( + controlnet_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states = ( + hidden_states + + controlnet_block_samples[index_block // interval_control] + ) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + result = torch.utils.checkpoint.checkpoint( + single_block_forward, + self=block, + model_config=model_config, + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + **( + { + "condition_latents": condition_latents, + "cond_temb": cond_temb, + "cond_rotary_emb": cond_rotary_emb, + } + if use_condition + else {} + ), + **ckpt_kwargs, + ) + + else: + result = single_block_forward( + block, + model_config=model_config, + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + **( + { + "condition_latents": condition_latents, + "cond_temb": cond_temb, + "cond_rotary_emb": cond_rotary_emb, + } + if use_condition + else {} + ), + ) + if use_condition: + hidden_states, condition_latents = result + else: + hidden_states = result + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len( + controlnet_single_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/OminiControl/src/gradio/gradio_app.py b/OminiControl/src/gradio/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..f8190be9c0cd1985e664761a40aad5569dc19dc9 --- /dev/null +++ b/OminiControl/src/gradio/gradio_app.py @@ -0,0 +1,115 @@ +import gradio as gr +import torch +from PIL import Image, ImageDraw, ImageFont +from diffusers.pipelines import FluxPipeline +from diffusers import FluxTransformer2DModel +import numpy as np + +from ..flux.condition import Condition +from ..flux.generate import seed_everything, generate + +pipe = None +use_int8 = False + + +def get_gpu_memory(): + return torch.cuda.get_device_properties(0).total_memory / 1024**3 + + +def init_pipeline(): + global pipe + if use_int8 or get_gpu_memory() < 33: + transformer_model = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-schell-int8wo-improved", + torch_dtype=torch.bfloat16, + use_safetensors=False, + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + transformer=transformer_model, + torch_dtype=torch.bfloat16, + ) + else: + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 + ) + pipe = pipe.to("cuda") + pipe.load_lora_weights( + "Yuanshi/OminiControl", + weight_name="omini/subject_512.safetensors", + adapter_name="subject", + ) + + +def process_image_and_text(image, text): + # center crop image + w, h, min_size = image.size[0], image.size[1], min(image.size) + image = image.crop( + ( + (w - min_size) // 2, + (h - min_size) // 2, + (w + min_size) // 2, + (h + min_size) // 2, + ) + ) + image = image.resize((512, 512)) + + condition = Condition("subject", image, position_delta=(0, 32)) + + if pipe is None: + init_pipeline() + + result_img = generate( + pipe, + prompt=text.strip(), + conditions=[condition], + num_inference_steps=8, + height=512, + width=512, + ).images[0] + + return result_img + + +def get_samples(): + sample_list = [ + { + "image": "assets/oranges.jpg", + "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'", + }, + { + "image": "assets/penguin.jpg", + "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'", + }, + { + "image": "assets/rc_car.jpg", + "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.", + }, + { + "image": "assets/clock.jpg", + "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.", + }, + { + "image": "assets/tshirt.jpg", + "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. 
She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.", + }, + ] + return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list] + + +demo = gr.Interface( + fn=process_image_and_text, + inputs=[ + gr.Image(type="pil"), + gr.Textbox(lines=2), + ], + outputs=gr.Image(type="pil"), + title="OminiControl / Subject driven generation", + examples=get_samples(), +) + +if __name__ == "__main__": + init_pipeline() + demo.launch( + debug=True, + ) diff --git a/OminiControl/src/train/callbacks.py b/OminiControl/src/train/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..1ccc7488ce1f1d4cf60707eb8ea5cfe9659c7132 --- /dev/null +++ b/OminiControl/src/train/callbacks.py @@ -0,0 +1,253 @@ +import lightning as L +from PIL import Image, ImageFilter, ImageDraw +import numpy as np +from transformers import pipeline +import cv2 +import torch +import os + +try: + import wandb +except ImportError: + wandb = None + +from ..flux.condition import Condition +from ..flux.generate import generate + + +class TrainingCallback(L.Callback): + def __init__(self, run_name, training_config: dict = {}): + self.run_name, self.training_config = run_name, training_config + + self.print_every_n_steps = training_config.get("print_every_n_steps", 10) + self.save_interval = training_config.get("save_interval", 1000) + self.sample_interval = training_config.get("sample_interval", 1000) + self.save_path = training_config.get("save_path", "./output") + + self.wandb_config = training_config.get("wandb", None) + self.use_wandb = ( + wandb is not None and os.environ.get("WANDB_API_KEY") is not None + ) + + self.total_steps = 0 + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + gradient_size = 0 + max_gradient_size = 0 + count = 0 + for _, param in pl_module.named_parameters(): + if param.grad is not None: + gradient_size += param.grad.norm(2).item() + max_gradient_size = max(max_gradient_size, param.grad.norm(2).item()) + count += 1 + if count > 0: + gradient_size /= count + + self.total_steps += 1 + + # Print training progress every n steps + if self.use_wandb: + report_dict = { + "steps": batch_idx, + "steps": self.total_steps, + "epoch": trainer.current_epoch, + "gradient_size": gradient_size, + } + loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches + report_dict["loss"] = loss_value + report_dict["t"] = pl_module.last_t + wandb.log(report_dict) + + if self.total_steps % self.print_every_n_steps == 0: + print( + f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}" + ) + + # Save LoRA weights at specified intervals + if self.total_steps % self.save_interval == 0: + print( + f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights" + ) + pl_module.save_lora( + f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}" + ) + + # Generate and save a sample image at specified intervals + if self.total_steps % self.sample_interval == 0: + print( + f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample" + ) + self.generate_a_sample( + trainer, + pl_module, + f"{self.save_path}/{self.run_name}/output", + f"lora_{self.total_steps}", + batch["condition_type"][ + 0 + ], # Use the condition type from the current batch + ) + + @torch.no_grad() + def generate_a_sample( + self, + trainer, + pl_module, + save_path, + file_name, + 
condition_type="super_resolution", + ): + # TODO: change this two variables to parameters + condition_size = trainer.training_config["dataset"]["condition_size"] + target_size = trainer.training_config["dataset"]["target_size"] + position_scale = trainer.training_config["dataset"].get("position_scale", 1.0) + + generator = torch.Generator(device=pl_module.device) + generator.manual_seed(42) + + test_list = [] + + if condition_type == "subject": + test_list.extend( + [ + ( + Image.open("assets/test_in.jpg"), + [0, -32], + "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.", + ), + ( + Image.open("assets/test_out.jpg"), + [0, -32], + "In a bright room. It is placed on a table.", + ), + ] + ) + elif condition_type == "canny": + condition_img = Image.open("assets/vase_hq.jpg").resize( + (condition_size, condition_size) + ) + condition_img = np.array(condition_img) + condition_img = cv2.Canny(condition_img, 100, 200) + condition_img = Image.fromarray(condition_img).convert("RGB") + test_list.append( + ( + condition_img, + [0, 0], + "A beautiful vase on a table.", + {"position_scale": position_scale} if position_scale != 1.0 else {}, + ) + ) + elif condition_type == "coloring": + condition_img = ( + Image.open("assets/vase_hq.jpg") + .resize((condition_size, condition_size)) + .convert("L") + .convert("RGB") + ) + test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) + elif condition_type == "depth": + if not hasattr(self, "deepth_pipe"): + self.deepth_pipe = pipeline( + task="depth-estimation", + model="LiheYoung/depth-anything-small-hf", + device="cpu", + ) + condition_img = ( + Image.open("assets/vase_hq.jpg") + .resize((condition_size, condition_size)) + .convert("RGB") + ) + condition_img = self.deepth_pipe(condition_img)["depth"].convert("RGB") + test_list.append( + ( + condition_img, + [0, 0], + "A beautiful vase on a table.", + {"position_scale": position_scale} if position_scale != 1.0 else {}, + ) + ) + elif condition_type == "depth_pred": + condition_img = ( + Image.open("assets/vase_hq.jpg") + .resize((condition_size, condition_size)) + .convert("RGB") + ) + test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) + elif condition_type == "deblurring": + blur_radius = 5 + image = Image.open("./assets/vase_hq.jpg") + condition_img = ( + image.convert("RGB") + .resize((condition_size, condition_size)) + .filter(ImageFilter.GaussianBlur(blur_radius)) + .convert("RGB") + ) + test_list.append( + ( + condition_img, + [0, 0], + "A beautiful vase on a table.", + {"position_scale": position_scale} if position_scale != 1.0 else {}, + ) + ) + elif condition_type == "fill": + condition_img = ( + Image.open("./assets/vase_hq.jpg") + .resize((condition_size, condition_size)) + .convert("RGB") + ) + mask = Image.new("L", condition_img.size, 0) + draw = ImageDraw.Draw(mask) + a = condition_img.size[0] // 4 + b = a * 3 + draw.rectangle([a, a, b, b], fill=255) + condition_img = Image.composite( + condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask + ) + test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) + elif condition_type == "sr": + condition_img = ( + Image.open("assets/vase_hq.jpg") + .resize((condition_size, condition_size)) + .convert("RGB") + ) + test_list.append((condition_img, [0, -16], "A beautiful vase on a table.")) + elif condition_type == "cartoon": + condition_img = ( + 
Image.open("assets/cartoon_boy.png") + .resize((condition_size, condition_size)) + .convert("RGB") + ) + test_list.append( + ( + condition_img, + [0, -16], + "A cartoon character in a white background. He is looking right, and running.", + ) + ) + else: + raise NotImplementedError + + if not os.path.exists(save_path): + os.makedirs(save_path) + for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list): + condition = Condition( + condition_type=condition_type, + condition=condition_img.resize( + (condition_size, condition_size) + ).convert("RGB"), + position_delta=position_delta, + **(others[0] if others else {}), + ) + res = generate( + pl_module.flux_pipe, + prompt=prompt, + conditions=[condition], + height=target_size, + width=target_size, + generator=generator, + model_config=pl_module.model_config, + default_lora=True, + ) + res.images[0].save( + os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg") + ) diff --git a/OminiControl/src/train/data.py b/OminiControl/src/train/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3c315146d387f8ee34db58b29867885d558bed --- /dev/null +++ b/OminiControl/src/train/data.py @@ -0,0 +1,323 @@ +from PIL import Image, ImageFilter, ImageDraw +import cv2 +import numpy as np +from torch.utils.data import Dataset +import torchvision.transforms as T +import random + + +class Subject200KDataset(Dataset): + def __init__( + self, + base_dataset, + condition_size: int = 512, + target_size: int = 512, + image_size: int = 512, + padding: int = 0, + condition_type: str = "subject", + drop_text_prob: float = 0.1, + drop_image_prob: float = 0.1, + return_pil_image: bool = False, + ): + self.base_dataset = base_dataset + self.condition_size = condition_size + self.target_size = target_size + self.image_size = image_size + self.padding = padding + self.condition_type = condition_type + self.drop_text_prob = drop_text_prob + self.drop_image_prob = drop_image_prob + self.return_pil_image = return_pil_image + + self.to_tensor = T.ToTensor() + + def __len__(self): + return len(self.base_dataset) * 2 + + def __getitem__(self, idx): + # If target is 0, left image is target, right image is condition + target = idx % 2 + item = self.base_dataset[idx // 2] + + # Crop the image to target and condition + image = item["image"] + left_img = image.crop( + ( + self.padding, + self.padding, + self.image_size + self.padding, + self.image_size + self.padding, + ) + ) + right_img = image.crop( + ( + self.image_size + self.padding * 2, + self.padding, + self.image_size * 2 + self.padding * 2, + self.image_size + self.padding, + ) + ) + + # Get the target and condition image + target_image, condition_img = ( + (left_img, right_img) if target == 0 else (right_img, left_img) + ) + + # Resize the image + condition_img = condition_img.resize( + (self.condition_size, self.condition_size) + ).convert("RGB") + target_image = target_image.resize( + (self.target_size, self.target_size) + ).convert("RGB") + + # Get the description + description = item["description"][ + "description_0" if target == 0 else "description_1" + ] + + # Randomly drop text or image + drop_text = random.random() < self.drop_text_prob + drop_image = random.random() < self.drop_image_prob + if drop_text: + description = "" + if drop_image: + condition_img = Image.new( + "RGB", (self.condition_size, self.condition_size), (0, 0, 0) + ) + + return { + "image": self.to_tensor(target_image), + "condition": self.to_tensor(condition_img), + "condition_type": 
self.condition_type, + "description": description, + # 16 is the downscale factor of the image + "position_delta": np.array([0, -self.condition_size // 16]), + **({"pil_image": image} if self.return_pil_image else {}), + } + + +class ImageConditionDataset(Dataset): + def __init__( + self, + base_dataset, + condition_size: int = 512, + target_size: int = 512, + condition_type: str = "canny", + drop_text_prob: float = 0.1, + drop_image_prob: float = 0.1, + return_pil_image: bool = False, + position_scale=1.0, + ): + self.base_dataset = base_dataset + self.condition_size = condition_size + self.target_size = target_size + self.condition_type = condition_type + self.drop_text_prob = drop_text_prob + self.drop_image_prob = drop_image_prob + self.return_pil_image = return_pil_image + self.position_scale = position_scale + + self.to_tensor = T.ToTensor() + + def __len__(self): + return len(self.base_dataset) + + @property + def depth_pipe(self): + if not hasattr(self, "_depth_pipe"): + from transformers import pipeline + + self._depth_pipe = pipeline( + task="depth-estimation", + model="LiheYoung/depth-anything-small-hf", + device="cpu", + ) + return self._depth_pipe + + def _get_canny_edge(self, img): + resize_ratio = self.condition_size / max(img.size) + img = img.resize( + (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio)) + ) + img_np = np.array(img) + img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(img_gray, 100, 200) + return Image.fromarray(edges).convert("RGB") + + def __getitem__(self, idx): + image = self.base_dataset[idx]["jpg"] + image = image.resize((self.target_size, self.target_size)).convert("RGB") + description = self.base_dataset[idx]["json"]["prompt"] + + enable_scale = random.random() < 1 + if not enable_scale: + condition_size = int(self.condition_size * self.position_scale) + position_scale = 1.0 + else: + condition_size = self.condition_size + position_scale = self.position_scale + + # Get the condition image + position_delta = np.array([0, 0]) + if self.condition_type == "canny": + condition_img = self._get_canny_edge(image) + elif self.condition_type == "coloring": + condition_img = ( + image.resize((condition_size, condition_size)) + .convert("L") + .convert("RGB") + ) + elif self.condition_type == "deblurring": + blur_radius = random.randint(1, 10) + condition_img = ( + image.convert("RGB") + .filter(ImageFilter.GaussianBlur(blur_radius)) + .resize((condition_size, condition_size)) + .convert("RGB") + ) + elif self.condition_type == "depth": + condition_img = self.depth_pipe(image)["depth"].convert("RGB") + condition_img = condition_img.resize((condition_size, condition_size)) + elif self.condition_type == "depth_pred": + condition_img = image + image = self.depth_pipe(condition_img)["depth"].convert("RGB") + description = f"[depth] {description}" + elif self.condition_type == "fill": + condition_img = image.resize((condition_size, condition_size)).convert( + "RGB" + ) + w, h = image.size + x1, x2 = sorted([random.randint(0, w), random.randint(0, w)]) + y1, y2 = sorted([random.randint(0, h), random.randint(0, h)]) + mask = Image.new("L", image.size, 0) + draw = ImageDraw.Draw(mask) + draw.rectangle([x1, y1, x2, y2], fill=255) + if random.random() > 0.5: + mask = Image.eval(mask, lambda a: 255 - a) + condition_img = Image.composite( + image, Image.new("RGB", image.size, (0, 0, 0)), mask + ) + elif self.condition_type == "sr": + condition_img = image.resize((condition_size, condition_size)).convert( + "RGB" + ) + position_delta = 
np.array([0, -condition_size // 16]) + + else: + raise ValueError(f"Condition type {self.condition_type} not implemented") + + # Randomly drop text or image + drop_text = random.random() < self.drop_text_prob + drop_image = random.random() < self.drop_image_prob + if drop_text: + description = "" + if drop_image: + condition_img = Image.new( + "RGB", (condition_size, condition_size), (0, 0, 0) + ) + + return { + "image": self.to_tensor(image), + "condition": self.to_tensor(condition_img), + "condition_type": self.condition_type, + "description": description, + "position_delta": position_delta, + **({"pil_image": [image, condition_img]} if self.return_pil_image else {}), + **({"position_scale": position_scale} if position_scale != 1.0 else {}), + } + + +class CartoonDataset(Dataset): + def __init__( + self, + base_dataset, + condition_size: int = 1024, + target_size: int = 1024, + image_size: int = 1024, + padding: int = 0, + condition_type: str = "cartoon", + drop_text_prob: float = 0.1, + drop_image_prob: float = 0.1, + return_pil_image: bool = False, + ): + self.base_dataset = base_dataset + self.condition_size = condition_size + self.target_size = target_size + self.image_size = image_size + self.padding = padding + self.condition_type = condition_type + self.drop_text_prob = drop_text_prob + self.drop_image_prob = drop_image_prob + self.return_pil_image = return_pil_image + + self.to_tensor = T.ToTensor() + + def __len__(self): + return len(self.base_dataset) + + def __getitem__(self, idx): + data = self.base_dataset[idx] + condition_img = data["condition"] + target_image = data["target"] + + # Tag + tag = data["tags"][0] + + target_description = data["target_description"] + + description = { + "lion": "lion like animal", + "bear": "bear like animal", + "gorilla": "gorilla like animal", + "dog": "dog like animal", + "elephant": "elephant like animal", + "eagle": "eagle like bird", + "tiger": "tiger like animal", + "owl": "owl like bird", + "woman": "woman", + "parrot": "parrot like bird", + "mouse": "mouse like animal", + "man": "man", + "pigeon": "pigeon like bird", + "girl": "girl", + "panda": "panda like animal", + "crocodile": "crocodile like animal", + "rabbit": "rabbit like animal", + "boy": "boy", + "monkey": "monkey like animal", + "cat": "cat like animal", + } + + # Resize the image + condition_img = condition_img.resize( + (self.condition_size, self.condition_size) + ).convert("RGB") + target_image = target_image.resize( + (self.target_size, self.target_size) + ).convert("RGB") + + # Process datum to create description + description = data.get( + "description", + f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. 
Character pose is {target_description['pose']}.", + ) + + # Randomly drop text or image + drop_text = random.random() < self.drop_text_prob + drop_image = random.random() < self.drop_image_prob + if drop_text: + description = "" + if drop_image: + condition_img = Image.new( + "RGB", (self.condition_size, self.condition_size), (0, 0, 0) + ) + + return { + "image": self.to_tensor(target_image), + "condition": self.to_tensor(condition_img), + "condition_type": self.condition_type, + "description": description, + # 16 is the downscale factor of the image + "position_delta": np.array([0, -16]), + } diff --git a/OminiControl/src/train/model.py b/OminiControl/src/train/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b143858d2e90cbb11e5033766febafd49ea7697a --- /dev/null +++ b/OminiControl/src/train/model.py @@ -0,0 +1,185 @@ +import lightning as L +from diffusers.pipelines import FluxPipeline +import torch +from peft import LoraConfig, get_peft_model_state_dict + +import prodigyopt + +from ..flux.transformer import tranformer_forward +from ..flux.condition import Condition +from ..flux.pipeline_tools import encode_images, prepare_text_input + + +class OminiModel(L.LightningModule): + def __init__( + self, + flux_pipe_id: str, + lora_path: str = None, + lora_config: dict = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + model_config: dict = {}, + optimizer_config: dict = None, + gradient_checkpointing: bool = False, + ): + # Initialize the LightningModule + super().__init__() + self.model_config = model_config + self.optimizer_config = optimizer_config + + # Load the Flux pipeline + self.flux_pipe: FluxPipeline = ( + FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device) + ) + self.transformer = self.flux_pipe.transformer + self.transformer.gradient_checkpointing = gradient_checkpointing + self.transformer.train() + + # Freeze the Flux pipeline + self.flux_pipe.text_encoder.requires_grad_(False).eval() + self.flux_pipe.text_encoder_2.requires_grad_(False).eval() + self.flux_pipe.vae.requires_grad_(False).eval() + + # Initialize LoRA layers + self.lora_layers = self.init_lora(lora_path, lora_config) + + self.to(device).to(dtype) + + def init_lora(self, lora_path: str, lora_config: dict): + assert lora_path or lora_config + if lora_path: + # TODO: Implement this + raise NotImplementedError + else: + self.transformer.add_adapter(LoraConfig(**lora_config)) + # TODO: Check if this is correct (p.requires_grad) + lora_layers = filter( + lambda p: p.requires_grad, self.transformer.parameters() + ) + return list(lora_layers) + + def save_lora(self, path: str): + FluxPipeline.save_lora_weights( + save_directory=path, + transformer_lora_layers=get_peft_model_state_dict(self.transformer), + safe_serialization=True, + ) + + def configure_optimizers(self): + # Freeze the transformer + self.transformer.requires_grad_(False) + opt_config = self.optimizer_config + + # Set the trainable parameters + self.trainable_params = self.lora_layers + + # Unfreeze trainable parameters + for p in self.trainable_params: + p.requires_grad_(True) + + # Initialize the optimizer + if opt_config["type"] == "AdamW": + optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"]) + elif opt_config["type"] == "Prodigy": + optimizer = prodigyopt.Prodigy( + self.trainable_params, + **opt_config["params"], + ) + elif opt_config["type"] == "SGD": + optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"]) + else: + raise NotImplementedError 
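+        # Whichever optimizer is built above only sees the LoRA parameters collected
+        # in self.lora_layers; the frozen base transformer weights are not updated.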
+ + return optimizer + + def training_step(self, batch, batch_idx): + step_loss = self.step(batch) + self.log_loss = ( + step_loss.item() + if not hasattr(self, "log_loss") + else self.log_loss * 0.95 + step_loss.item() * 0.05 + ) + return step_loss + + def step(self, batch): + imgs = batch["image"] + conditions = batch["condition"] + condition_types = batch["condition_type"] + prompts = batch["description"] + position_delta = batch["position_delta"][0] + position_scale = float(batch.get("position_scale", [1.0])[0]) + + # Prepare inputs + with torch.no_grad(): + # Prepare image input + x_0, img_ids = encode_images(self.flux_pipe, imgs) + + # Prepare text input + prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input( + self.flux_pipe, prompts + ) + + # Prepare t and x_t + t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device)) + x_1 = torch.randn_like(x_0).to(self.device) + t_ = t.unsqueeze(1).unsqueeze(1) + x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype) + + # Prepare conditions + condition_latents, condition_ids = encode_images(self.flux_pipe, conditions) + + # Add position delta + condition_ids[:, 1] += position_delta[0] + condition_ids[:, 2] += position_delta[1] + + if position_scale != 1.0: + scale_bias = (position_scale - 1.0) / 2 + condition_ids[:, 1] *= position_scale + condition_ids[:, 2] *= position_scale + condition_ids[:, 1] += scale_bias + condition_ids[:, 2] += scale_bias + + # Prepare condition type + condition_type_ids = torch.tensor( + [ + Condition.get_type_id(condition_type) + for condition_type in condition_types + ] + ).to(self.device) + condition_type_ids = ( + torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0] + ).unsqueeze(1) + + # Prepare guidance + guidance = ( + torch.ones_like(t).to(self.device) + if self.transformer.config.guidance_embeds + else None + ) + + # Forward pass + transformer_out = tranformer_forward( + self.transformer, + # Model config + model_config=self.model_config, + # Inputs of the condition (new feature) + condition_latents=condition_latents, + condition_ids=condition_ids, + condition_type_ids=condition_type_ids, + # Inputs to the original transformer + hidden_states=x_t, + timestep=t, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=img_ids, + joint_attention_kwargs=None, + return_dict=False, + ) + pred = transformer_out[0] + + # Compute loss + loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean") + self.last_t = t.mean().item() + return loss diff --git a/OminiControl/src/train/train.py b/OminiControl/src/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbdc2989c7ca4261f2198de92f1e243c90262fe --- /dev/null +++ b/OminiControl/src/train/train.py @@ -0,0 +1,178 @@ +from torch.utils.data import DataLoader +import torch +import lightning as L +import yaml +import os +import time + +from datasets import load_dataset + +from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset +from .model import OminiModel +from .callbacks import TrainingCallback + + +def get_rank(): + try: + rank = int(os.environ.get("LOCAL_RANK")) + except: + rank = 0 + return rank + + +def get_config(): + config_path = os.environ.get("XFL_CONFIG") + assert config_path is not None, "Please set the XFL_CONFIG environment variable" + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + + +def init_wandb(wandb_config, run_name): + import wandb + + try: + assert 
os.environ.get("WANDB_API_KEY") is not None + wandb.init( + project=wandb_config["project"], + name=run_name, + config={}, + ) + except Exception as e: + print("Failed to initialize WanDB:", e) + + +def main(): + # Initialize + is_main_process, rank = get_rank() == 0, get_rank() + torch.cuda.set_device(rank) + config = get_config() + training_config = config["train"] + run_name = time.strftime("%Y%m%d-%H%M%S") + + # Initialize WanDB + wandb_config = training_config.get("wandb", None) + if wandb_config is not None and is_main_process: + init_wandb(wandb_config, run_name) + + print("Rank:", rank) + if is_main_process: + print("Config:", config) + + # Initialize dataset and dataloader + if training_config["dataset"]["type"] == "subject": + dataset = load_dataset("Yuanshi/Subjects200K") + + # Define filter function + def filter_func(item): + if not item.get("quality_assessment"): + return False + return all( + item["quality_assessment"].get(key, 0) >= 5 + for key in ["compositeStructure", "objectConsistency", "imageQuality"] + ) + + # Filter dataset + if not os.path.exists("./cache/dataset"): + os.makedirs("./cache/dataset") + data_valid = dataset["train"].filter( + filter_func, + num_proc=16, + cache_file_name="./cache/dataset/data_valid.arrow", + ) + dataset = Subject200KDataset( + data_valid, + condition_size=training_config["dataset"]["condition_size"], + target_size=training_config["dataset"]["target_size"], + image_size=training_config["dataset"]["image_size"], + padding=training_config["dataset"]["padding"], + condition_type=training_config["condition_type"], + drop_text_prob=training_config["dataset"]["drop_text_prob"], + drop_image_prob=training_config["dataset"]["drop_image_prob"], + ) + elif training_config["dataset"]["type"] == "img": + # Load dataset text-to-image-2M + dataset = load_dataset( + "webdataset", + data_files={"train": training_config["dataset"]["urls"]}, + split="train", + cache_dir="cache/t2i2m", + num_proc=32, + ) + dataset = ImageConditionDataset( + dataset, + condition_size=training_config["dataset"]["condition_size"], + target_size=training_config["dataset"]["target_size"], + condition_type=training_config["condition_type"], + drop_text_prob=training_config["dataset"]["drop_text_prob"], + drop_image_prob=training_config["dataset"]["drop_image_prob"], + position_scale=training_config["dataset"].get("position_scale", 1.0), + ) + elif training_config["dataset"]["type"] == "cartoon": + dataset = load_dataset("saquiboye/oye-cartoon", split="train") + dataset = CartoonDataset( + dataset, + condition_size=training_config["dataset"]["condition_size"], + target_size=training_config["dataset"]["target_size"], + image_size=training_config["dataset"]["image_size"], + padding=training_config["dataset"]["padding"], + condition_type=training_config["condition_type"], + drop_text_prob=training_config["dataset"]["drop_text_prob"], + drop_image_prob=training_config["dataset"]["drop_image_prob"], + ) + else: + raise NotImplementedError + + print("Dataset length:", len(dataset)) + train_loader = DataLoader( + dataset, + batch_size=training_config["batch_size"], + shuffle=True, + num_workers=training_config["dataloader_workers"], + ) + + # Initialize model + trainable_model = OminiModel( + flux_pipe_id=config["flux_path"], + lora_config=training_config["lora_config"], + device=f"cuda", + dtype=getattr(torch, config["dtype"]), + optimizer_config=training_config["optimizer"], + model_config=config.get("model", {}), + gradient_checkpointing=training_config.get("gradient_checkpointing", 
+    )
+
+    # Callbacks for logging and saving checkpoints
+    training_callbacks = (
+        [TrainingCallback(run_name, training_config=training_config)]
+        if is_main_process
+        else []
+    )
+
+    # Initialize trainer
+    trainer = L.Trainer(
+        accumulate_grad_batches=training_config["accumulate_grad_batches"],
+        callbacks=training_callbacks,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        logger=False,
+        max_steps=training_config.get("max_steps", -1),
+        max_epochs=training_config.get("max_epochs", -1),
+        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
+    )
+
+    setattr(trainer, "training_config", training_config)
+
+    # Save config
+    save_path = training_config.get("save_path", "./output")
+    if is_main_process:
+        os.makedirs(f"{save_path}/{run_name}")
+        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
+            yaml.dump(config, f)
+
+    # Start training
+    trainer.fit(trainable_model, train_loader)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/OminiControl/train/README.md b/OminiControl/train/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ad489a53e8398568754e84a7ebb6bf9ff4780b7
--- /dev/null
+++ b/OminiControl/train/README.md
@@ -0,0 +1,138 @@
+# OminiControl Training 🛠️
+
+## Preparation
+
+### Setup
+1. **Environment**
+   ```bash
+   conda create -n omini python=3.10
+   conda activate omini
+   ```
+2. **Requirements**
+   ```bash
+   pip install -r train/requirements.txt
+   ```
+
+### Dataset
+1. Download the [Subjects200K](https://huggingface.co./datasets/Yuanshi/Subjects200K) dataset (**subject-driven generation**).
+   ```bash
+   bash train/script/data_download/data_download1.sh
+   ```
+2. Download the [text-to-image-2M](https://huggingface.co./datasets/jackyhate/text-to-image-2M) dataset (**spatial control task**).
+   ```bash
+   bash train/script/data_download/data_download2.sh
+   ```
+   **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional data shards; remember to update the config file so it points at the new training data.
+
+## Training
+
+### Start training
+**Config file path**: `./train/config`
+
+**Scripts path**: `./train/script`
+
+1. Subject-driven generation
+   ```bash
+   bash train/script/train_subject.sh
+   ```
+2. Spatial control task
+   ```bash
+   bash train/script/train_canny.sh
+   ```
+
+**Note**: Detailed WandB and GPU settings can be found in the script files and the config files.
+
+### Other spatial control tasks
+This repository supports 7 spatial control tasks:
+1. Canny edge to image (`canny`)
+2. Image colorization (`coloring`)
+3. Image deblurring (`deblurring`)
+4. Depth map to image (`depth`)
+5. Image to depth map (`depth_pred`)
+6. Image inpainting (`fill`)
+7. Super resolution (`sr`)
+
+You can switch between these tasks by changing the `condition_type` parameter in the config file `config/canny_512.yaml`, for example as sketched below.
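+
+For instance, a minimal sketch of deriving a new task config from `config/canny_512.yaml` by changing `condition_type`. This is not an official utility of the repository; the use of PyYAML and the output file name `depth_512.yaml` are assumptions for illustration.
+```python
+import yaml
+
+# Load the provided canny config and retarget it to another supported task.
+with open("train/config/canny_512.yaml") as f:
+    config = yaml.safe_load(f)
+
+config["train"]["condition_type"] = "depth"  # any of the 7 task names listed above
+
+# Save under a new, hypothetical file name and point the training script at it.
+with open("train/config/depth_512.yaml", "w") as f:
+    yaml.dump(config, f)
+```
+All other settings (dataset URLs, LoRA config, WandB project) carry over unchanged.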
+
+### Customize your own task
+You can customize your own task by constructing a new dataset and modifying the training code.
+
+**Instructions**
+
+1. **Dataset**:
+
+   Construct a new dataset that returns items in the following format (see `src/train/data.py`; a fuller sketch is given after these instructions):
+   ```python
+   class MyDataset(Dataset):
+       def __init__(self, ...):
+           ...
+       def __len__(self):
+           ...
+       def __getitem__(self, idx):
+           ...
+           return {
+               "image": image,
+               "condition": condition_img,
+               "condition_type": "your_condition_type",
+               "description": description,
+               "position_delta": position_delta
+           }
+   ```
+   **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks, set `position_delta` to `[0, -condition_width // 16]`.
+2. **Condition**:
+
+   Add a new condition type in the `Condition` class (`src/flux/condition.py`):
+   ```python
+   condition_dict = {
+       ...
+       "your_condition_type": your_condition_id_number,  # Add your condition type here
+   }
+   ...
+   if condition_type in [
+       ...
+       "your_condition_type",  # Add your condition type here
+   ]:
+       ...
+   ```
+3. **Test**:
+
+   Add a test case for your task (`src/train/callbacks.py`):
+   ```python
+   if self.condition_type == "your_condition_type":
+       condition_img = (
+           Image.open("images/vase.jpg")
+           .resize((condition_size, condition_size))
+           .convert("RGB")
+       )
+       ...
+       test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
+   ```
+4. **Import the relevant dataset in the training script**
+
+   Update the dataset selection in `src/train/train.py`:
+   ```python
+   from .data import (
+       ImageConditionDataset,
+       Subject200KDataset,
+       MyDataset
+   )
+   ...
+
+   # Initialize dataset and dataloader
+   if training_config["dataset"]["type"] == "your_condition_type":
+       ...
+   ```
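+
+As an illustration of step 1, here is a minimal sketch of a custom dataset for a deblurring-style task. Everything specific in it is assumed for illustration (the class name `MyBlurDataset`, the Gaussian-blur condition, the fixed caption, the `my_deblurring` type name, and the choice to return tensors so the default DataLoader collate can batch them); only the returned dictionary keys follow the format above.
+```python
+from PIL import Image, ImageFilter
+from torch.utils.data import Dataset
+from torchvision.transforms.functional import to_tensor
+
+
+class MyBlurDataset(Dataset):
+    """Hypothetical example: the condition is a blurred copy of the target image."""
+
+    def __init__(self, image_paths, condition_size=512, target_size=512):
+        self.image_paths = image_paths
+        self.condition_size = condition_size
+        self.target_size = target_size
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        image = Image.open(self.image_paths[idx]).convert("RGB")
+        image = image.resize((self.target_size, self.target_size))
+
+        # Derive the condition from the target (here: a simple Gaussian blur).
+        condition_img = image.filter(ImageFilter.GaussianBlur(radius=8))
+        condition_img = condition_img.resize((self.condition_size, self.condition_size))
+
+        return {
+            "image": to_tensor(image),
+            "condition": to_tensor(condition_img),
+            "condition_type": "my_deblurring",  # must match the type registered in src/flux/condition.py
+            "description": "A sharp, high-quality photo.",
+            "position_delta": [0, 0],  # spatially aligned task, per the note in step 1
+        }
+```
+Steps 2-4 then register the same condition type in `src/flux/condition.py`, add a test case in `src/train/callbacks.py`, and hook the dataset into `src/train/train.py`.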
+
+## Hardware requirements
+**Note**: Memory optimizations (such as dynamic T5 model loading) are not implemented yet.
+
+**Recommended**
+- Hardware: 2x NVIDIA H100 GPUs
+- Memory: ~80GB GPU memory
+
+**Minimal**
+- Hardware: 1x NVIDIA L20 GPU
+- Memory: ~48GB GPU memory
\ No newline at end of file
diff --git a/OminiControl/train/config/canny_512.yaml b/OminiControl/train/config/canny_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a0a4602870ecba8bf8dd322b6333aa40ee74cc0
--- /dev/null
+++ b/OminiControl/train/config/canny_512.yaml
@@ -0,0 +1,48 @@
+flux_path: "black-forest-labs/FLUX.1-dev"
+dtype: "bfloat16"
+
+model:
+  union_cond_attn: true
+  add_cond_attn: false
+  latent_lora: false
+
+train:
+  batch_size: 1
+  accumulate_grad_batches: 1
+  dataloader_workers: 5
+  save_interval: 1000
+  sample_interval: 100
+  max_steps: -1
+  gradient_checkpointing: true
+  save_path: "runs"
+
+  # Specify the type of condition to use.
+  # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill"]
+  condition_type: "canny"
+  dataset:
+    type: "img"
+    urls:
+      - "https://huggingface.co./datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
+      - "https://huggingface.co./datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
+    cache_name: "data_512_2M"
+    condition_size: 512
+    target_size: 512
+    drop_text_prob: 0.1
+    drop_image_prob: 0.1
+
+  wandb:
+    project: "OminiControl"
+
+  lora_config:
+    r: 4
+    lora_alpha: 4
+    init_lora_weights: "gaussian"
+    target_modules: "(.*x_embedder|.*(?