diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d30cce47e082646606c64405d3bf8464aa3145a7 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/clock.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/OminiControl/LICENSE b/OminiControl/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7b2de994465d5d0c1f37fc6aecd065f8849ed2d8
--- /dev/null
+++ b/OminiControl/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2024] [Zhenxiong Tan]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/OminiControl/README.md b/OminiControl/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ab48029971c05775acc4146eec40ade2dad451b
--- /dev/null
+++ b/OminiControl/README.md
@@ -0,0 +1,170 @@
+# OminiControl
+
+
+
+
+
+
+
+
+
+
+
+> **OminiControl: Minimal and Universal Control for Diffusion Transformer**
+>
+> Zhenxiong Tan,
+> [Songhua Liu](http://121.37.94.87/),
+> [Xingyi Yang](https://adamdad.github.io/),
+> Qiaochu Xue,
+> and
+> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
+>
+> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
+>
+
+
+## Features
+
+OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
+
+* **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
+
+* **Minimal Design 🚀**: Injects control signals while preserving the original model structure, introducing only 0.1% additional parameters to the base model.
+
+## News
+- **2024-12-26**: ⭐️ Training code is released. You can now build your own OminiControl model by customizing control tasks (3D, multi-view, pose-guided, try-on, etc.) on top of the FLUX model. Check the [training folder](./train) for more details.
+
+## Quick Start
+### Setup (Optional)
+1. **Environment setup**
+```bash
+conda create -n omini python=3.10
+conda activate omini
+```
+2. **Requirements installation**
+```bash
+pip install -r requirements.txt
+```
+### Usage example
+1. Subject-driven generation: `examples/subject.ipynb`
+2. In-painting: `examples/inpainting.ipynb`
+3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
+
+### Gradio app
+To run the Gradio app for subject-driven generation:
+```bash
+python -m src.gradio.gradio_app
+```
+
+### Guidelines for subject-driven generation
+1. Input images are automatically center-cropped and resized to 512x512 resolution.
+2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`, e.g.:
+ 1. *A close up view of this item. It is placed on a wooden table.*
+ 2. *A young lady is wearing this shirt.*
+3. The model currently works primarily with objects rather than human subjects, due to the absence of human data in training.
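+
+Putting these guidelines together, below is a minimal sketch of subject-driven generation, condensed from `examples/subject.ipynb` (it assumes a CUDA GPU and downloads the FLUX.1-schnell base model plus the `subject_512` LoRA weights; the output path is illustrative):
+
+```python
+import torch
+from PIL import Image
+from diffusers.pipelines import FluxPipeline
+
+from src.flux.condition import Condition
+from src.flux.generate import generate, seed_everything
+
+# Load the FLUX base model and attach the subject-driven OminiControl LoRA.
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.load_lora_weights(
+    "Yuanshi/OminiControl",
+    weight_name="omini/subject_512.safetensors",
+    adapter_name="subject",
+)
+
+# Condition image at 512x512; refer to it in the prompt as "this item".
+image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
+condition = Condition("subject", image, position_delta=(0, 32))
+
+seed_everything(0)
+result = generate(
+    pipe,
+    prompt="A close up view of this item. It is placed on a wooden table.",
+    conditions=[condition],
+    num_inference_steps=8,
+    height=512,
+    width=512,
+).images[0]
+result.save("subject_result.jpg")  # illustrative output path
+```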
+
+## Generated samples
+### Subject-driven generation
+
+
+**Demos** (Left: condition image; Right: generated image)
+
+
+
+
+**Text Prompts**
+
+- Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
+- Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
+- Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
+- Prompt4: *On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.*
+
+
+**More results**
+
+* Try on:
+
+* Scene variations:
+
+* Dreambooth dataset:
+
+* Oye-cartoon finetune:
+
+
+
+
+
+
+
+
+### Spatially aligned control
+1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image; see the code sketch after this list)
+ - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
+
+
+ - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
+
+
+2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
+
+
+
+ Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
+
+
+
+
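+A condensed sketch of the in-painting flow from `examples/inpainting.ipynb` (it assumes FLUX.1-dev and a CUDA device; the region to be regenerated is blacked out and the masked image is passed as a `fill` condition; the output path is illustrative):
+
+```python
+import torch
+from PIL import Image
+from diffusers.pipelines import FluxPipeline
+
+from src.flux.condition import Condition
+from src.flux.generate import generate, seed_everything
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.load_lora_weights(
+    "Yuanshi/OminiControl",
+    weight_name="experimental/fill.safetensors",
+    adapter_name="fill",
+)
+
+image = Image.open("assets/monalisa.jpg").convert("RGB").resize((512, 512))
+
+# Black out the region to be filled; the masked image itself is the condition.
+masked_image = image.copy()
+masked_image.paste((0, 0, 0), (128, 100, 384, 220))
+condition = Condition("fill", masked_image)
+
+seed_everything()
+result = generate(
+    pipe,
+    prompt="The Mona Lisa is wearing a white VR headset with 'Omini' written on it.",
+    conditions=[condition],
+).images[0]
+result.save("inpainting_result.jpg")  # illustrative output path
+```
+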
+
+## Models
+
+**Subject-driven control:**
+| Model | Base model | Description | Resolution |
+| ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- | ------------ |
+| [`experimental`](https://huggingface.co./Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
+| [`omini`](https://huggingface.co./Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
+| [`omini`](https://huggingface.co./Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) |
+| [`oye-cartoon`](https://huggingface.co./saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on [oye-cartoon](https://huggingface.co./datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) |
+
+**Spatially aligned control:**
+| Model | Base model | Description | Resolution |
+| --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ |
+| [`experimental`](https://huggingface.co./Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
+| [`experimental`](https://huggingface.co./Yuanshi/OminiControl/tree/main/experimental) / `<task_name>_1024` | FLUX.1 | Supports higher resolution. (To be released) | (1024, 1024) |
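+
+Each entry above corresponds to a `weight_name` inside the `Yuanshi/OminiControl` Hugging Face repository. A minimal sketch of selecting checkpoints from these tables, adapted from `examples/spatial.ipynb` (the unreleased 1024 checkpoints cannot be loaded yet):
+
+```python
+import torch
+from diffusers.pipelines import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Spatially aligned adapters are stored as experimental/<task>.safetensors.
+for condition_type in ["canny", "depth", "coloring", "deblurring"]:
+    pipe.load_lora_weights(
+        "Yuanshi/OminiControl",
+        weight_name=f"experimental/{condition_type}.safetensors",
+        adapter_name=condition_type,
+    )
+
+# Subject-driven adapters live under omini/, e.g. omini/subject_512.safetensors.
+```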
+
+## Community Extensions
+- [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
+- [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
+
+## Limitations
+1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training.
+2. The subject-driven generation model may not work well with `FLUX.1-dev`.
+3. The released models currently support only 512x512 resolution.
+
+## Training
+Training instructions can be found in this [folder](./train).
+
+
+## To-do
+- [x] Release the training code.
+- [ ] Release the model for higher resolution (1024x1024).
+
+## Citation
+```
+@article{tan2024ominicontrol,
+ title={Ominicontrol: Minimal and universal control for diffusion transformer},
+ author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
+ journal={arXiv preprint arXiv:2411.15098},
+ volume={3},
+ year={2024}
+}
+```
diff --git a/OminiControl/assets/book.jpg b/OminiControl/assets/book.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..da7069c305f50ecd042ba889ac1624d02206e98f
Binary files /dev/null and b/OminiControl/assets/book.jpg differ
diff --git a/OminiControl/assets/cartoon_boy.png b/OminiControl/assets/cartoon_boy.png
new file mode 100644
index 0000000000000000000000000000000000000000..2ad7e896bf6f8f645752c2bb9545326536d4b73b
--- /dev/null
+++ b/OminiControl/assets/cartoon_boy.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4a82c0f9ed09b9468bded7d901beffaf29addc30ed5f72ad72451e1b6344b1c
+size 428900
diff --git a/OminiControl/assets/clock.jpg b/OminiControl/assets/clock.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..eeee9c0c7ee5281344c4010dcfb09cd40bf18e91
--- /dev/null
+++ b/OminiControl/assets/clock.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41235973f26152ac92d32bfc166fb5f9f1e352c5e16807920238473316ec462b
+size 288877
diff --git a/OminiControl/assets/coffee.png b/OminiControl/assets/coffee.png
new file mode 100644
index 0000000000000000000000000000000000000000..36ff439ef3a00b597bb587237e79a5812514dc81
Binary files /dev/null and b/OminiControl/assets/coffee.png differ
diff --git a/OminiControl/assets/demo/book_omini.jpg b/OminiControl/assets/demo/book_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9cadef5e92d5315ee66f07322c3608be0743868c
Binary files /dev/null and b/OminiControl/assets/demo/book_omini.jpg differ
diff --git a/OminiControl/assets/demo/clock_omini.jpg b/OminiControl/assets/demo/clock_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..df0341e836eb39ea21c4ac38d143228efc6909f6
Binary files /dev/null and b/OminiControl/assets/demo/clock_omini.jpg differ
diff --git a/OminiControl/assets/demo/demo_this_is_omini_control.jpg b/OminiControl/assets/demo/demo_this_is_omini_control.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5fa77c908af2db85f98b1d238914b1b92ad85b9e
--- /dev/null
+++ b/OminiControl/assets/demo/demo_this_is_omini_control.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:798b7c25be6be118dc0de97c444c840869afca633a0d48f99d940aec040a7518
+size 129406
diff --git a/OminiControl/assets/demo/dreambooth_res.jpg b/OminiControl/assets/demo/dreambooth_res.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4f369c03887576cb63229c2b535fd15f7e189023
--- /dev/null
+++ b/OminiControl/assets/demo/dreambooth_res.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba36bd861989564dc679acf3b5e56f382f1a11b1596e6f611ea0bd7d81b89680
+size 1935753
diff --git a/OminiControl/assets/demo/man_omini.jpg b/OminiControl/assets/demo/man_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf8f54498d02e0b38914a80617ca420f81231e6c
Binary files /dev/null and b/OminiControl/assets/demo/man_omini.jpg differ
diff --git a/OminiControl/assets/demo/monalisa_omini.jpg b/OminiControl/assets/demo/monalisa_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de43c96b35ceafea510add2a9815c56f444a91b1
--- /dev/null
+++ b/OminiControl/assets/demo/monalisa_omini.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5ca6c2bf44f19d216b2eb16dcc67d19f11d87220d3ee80f5e5e1ad98a5536dc
+size 132717
diff --git a/OminiControl/assets/demo/oranges_omini.jpg b/OminiControl/assets/demo/oranges_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..30554406b373eb1547d6922e8c18eef4e8efd292
Binary files /dev/null and b/OminiControl/assets/demo/oranges_omini.jpg differ
diff --git a/OminiControl/assets/demo/panda_omini.jpg b/OminiControl/assets/demo/panda_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e69743402e70ee47e124344f5b50fd8a2f8cd41d
Binary files /dev/null and b/OminiControl/assets/demo/panda_omini.jpg differ
diff --git a/OminiControl/assets/demo/penguin_omini.jpg b/OminiControl/assets/demo/penguin_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..15726e226d88e0216f7d8605baf2742b807782a5
Binary files /dev/null and b/OminiControl/assets/demo/penguin_omini.jpg differ
diff --git a/OminiControl/assets/demo/rc_car_omini.jpg b/OminiControl/assets/demo/rc_car_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2cd6395e8aeb241c452b92f2655e2975b2b8a5e6
Binary files /dev/null and b/OminiControl/assets/demo/rc_car_omini.jpg differ
diff --git a/OminiControl/assets/demo/room_corner_canny.jpg b/OminiControl/assets/demo/room_corner_canny.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..295943cec832f5948283377ad68ed486e00bdac5
Binary files /dev/null and b/OminiControl/assets/demo/room_corner_canny.jpg differ
diff --git a/OminiControl/assets/demo/room_corner_coloring.jpg b/OminiControl/assets/demo/room_corner_coloring.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..087c5fb8912378850b3c1d745f591f7809f08c85
Binary files /dev/null and b/OminiControl/assets/demo/room_corner_coloring.jpg differ
diff --git a/OminiControl/assets/demo/room_corner_deblurring.jpg b/OminiControl/assets/demo/room_corner_deblurring.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..29ef78bb776f4a1fcc0d7a77435a32d40ec2fef6
Binary files /dev/null and b/OminiControl/assets/demo/room_corner_deblurring.jpg differ
diff --git a/OminiControl/assets/demo/room_corner_depth.jpg b/OminiControl/assets/demo/room_corner_depth.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d385433fc25795ad8c1a10ebaf518a3cfd1bbab0
Binary files /dev/null and b/OminiControl/assets/demo/room_corner_depth.jpg differ
diff --git a/OminiControl/assets/demo/scene_variation.jpg b/OminiControl/assets/demo/scene_variation.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2b84d8686cba903b5178a5b555a6fdb389aa89b1
--- /dev/null
+++ b/OminiControl/assets/demo/scene_variation.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39e4e16d2eeb58b3775b6d34c8b3e125d0d19cc36fa90b07c6c8d57624ad4333
+size 958333
diff --git a/OminiControl/assets/demo/shirt_omini.jpg b/OminiControl/assets/demo/shirt_omini.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..221880750404583c0db0c308ae99b82d7bcc5127
Binary files /dev/null and b/OminiControl/assets/demo/shirt_omini.jpg differ
diff --git a/OminiControl/assets/demo/try_on.jpg b/OminiControl/assets/demo/try_on.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f34aec443f2932ebe5d59fc321b34f6b61232bbc
--- /dev/null
+++ b/OminiControl/assets/demo/try_on.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6adce5194329a83f0109b4375e00667c341879e64fb55831c70ea3f3b2f99f7e
+size 774269
diff --git a/OminiControl/assets/monalisa.jpg b/OminiControl/assets/monalisa.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dbfd24ae1941ae2f44179aab6d17e1e79e978ea9
--- /dev/null
+++ b/OminiControl/assets/monalisa.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:188b8b6499e4541f9dfef2a9daf6f1eb920079c9208f587fd97566d6aa4a9719
+size 352759
diff --git a/OminiControl/assets/oranges.jpg b/OminiControl/assets/oranges.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..19e38651ebac7633d4b67efd1bf1f0a111e4bd6a
Binary files /dev/null and b/OminiControl/assets/oranges.jpg differ
diff --git a/OminiControl/assets/penguin.jpg b/OminiControl/assets/penguin.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9b98535482e4bfe6e1febb64d595ef167a404d63
Binary files /dev/null and b/OminiControl/assets/penguin.jpg differ
diff --git a/OminiControl/assets/rc_car.jpg b/OminiControl/assets/rc_car.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0f31728c3f1164180ca58a84c6b90a715e349a64
--- /dev/null
+++ b/OminiControl/assets/rc_car.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae8aed11029fa3b084deb286c07a8cab5056840c9c123816fe2b504e94233e95
+size 254155
diff --git a/OminiControl/assets/room_corner.jpg b/OminiControl/assets/room_corner.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e16dea6be5f92137c5d82e8469a7ae475e25f8f2
--- /dev/null
+++ b/OminiControl/assets/room_corner.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f97bd63df05f5f15ad5dd1a2ccef803e74e12caadd8fe145493fd6d5219045e7
+size 236101
diff --git a/OminiControl/assets/test_in.jpg b/OminiControl/assets/test_in.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4ecce80b5cbc8fc6754279731207f6c7bcc1b3ec
Binary files /dev/null and b/OminiControl/assets/test_in.jpg differ
diff --git a/OminiControl/assets/test_out.jpg b/OminiControl/assets/test_out.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9b98535482e4bfe6e1febb64d595ef167a404d63
Binary files /dev/null and b/OminiControl/assets/test_out.jpg differ
diff --git a/OminiControl/assets/tshirt.jpg b/OminiControl/assets/tshirt.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..76c495a4da1f0daeffe660cbb4dba602e5c83233
--- /dev/null
+++ b/OminiControl/assets/tshirt.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb1803315765302113a9e7a64dedd4ecba2672028cf093cbc33ef2edd2247c39
+size 301252
diff --git a/OminiControl/assets/vase.jpg b/OminiControl/assets/vase.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7acb40eca66987b29cafb3a8d4c0a9207dcc5bf8
Binary files /dev/null and b/OminiControl/assets/vase.jpg differ
diff --git a/OminiControl/assets/vase_hq.jpg b/OminiControl/assets/vase_hq.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..378d2487450ec04b1981556bf6dadc344774eaad
--- /dev/null
+++ b/OminiControl/assets/vase_hq.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:279905e32116792f118802d23b0d96629d98ccbdac9e704e65eaf2e98c752679
+size 2901712
diff --git a/OminiControl/examples/inpainting.ipynb b/OminiControl/examples/inpainting.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..2a48abbf702bb2bf272e93810a220130dd529c13
--- /dev/null
+++ b/OminiControl/examples/inpainting.ipynb
@@ -0,0 +1,143 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from src.flux.condition import Condition\n",
+ "from PIL import Image\n",
+ "\n",
+ "from src.flux.generate import generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"experimental/fill.safetensors\",\n",
+ " adapter_name=\"fill\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "masked_image = image.copy()\n",
+ "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
+ "\n",
+ "condition = Condition(\"fill\", masked_image)\n",
+ "\n",
+ "seed_everything()\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "w, h, min_dim = image.size + (min(image.size),)\n",
+ "image = image.crop(\n",
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
+ ").resize((512, 512))\n",
+ "\n",
+ "\n",
+ "masked_image = image.copy()\n",
+ "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
+ "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
+ "\n",
+ "condition = Condition(\"fill\", masked_image)\n",
+ "\n",
+ "seed_everything()\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/OminiControl/examples/spatial.ipynb b/OminiControl/examples/spatial.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e9a918717884500f8176e731e59af2429ae2fdc0
--- /dev/null
+++ b/OminiControl/examples/spatial.ipynb
@@ -0,0 +1,184 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from src.flux.condition import Condition\n",
+ "from PIL import Image\n",
+ "\n",
+ "from src.flux.generate import generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
+ " pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"experimental/{condition_type}.safetensors\",\n",
+ " adapter_name=condition_type,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
+ "\n",
+ "w, h, min_dim = image.size + (min(image.size),)\n",
+ "image = image.crop(\n",
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
+ ").resize((512, 512))\n",
+ "\n",
+ "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "condition = Condition(\"canny\", image)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "condition = Condition(\"depth\", image)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "condition = Condition(\"deblurring\", image)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "condition = Condition(\"coloring\", image)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/OminiControl/examples/subject.ipynb b/OminiControl/examples/subject.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..249d0773ebae95a02d38100ca767bb64115fe5dc
--- /dev/null
+++ b/OminiControl/examples/subject.ipynb
@@ -0,0 +1,214 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from src.flux.condition import Condition\n",
+ "from PIL import Image\n",
+ "\n",
+ "from src.flux.generate import generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")\n",
+ "pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"omini/subject_512.safetensors\",\n",
+ " adapter_name=\"subject\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
+ "\n",
+ "\n",
+ "seed_everything(0)\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
+ "\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/OminiControl/examples/subject_1024.ipynb b/OminiControl/examples/subject_1024.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d0f9de9e7d86dc9aa455b0b162c93ca764f346b3
--- /dev/null
+++ b/OminiControl/examples/subject_1024.ipynb
@@ -0,0 +1,221 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from src.flux.condition import Condition\n",
+ "from PIL import Image\n",
+ "\n",
+ "from src.flux.generate import generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")\n",
+ "pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"omini/subject_1024_beta.safetensors\",\n",
+ " adapter_name=\"subject\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image)\n",
+ "\n",
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
+ "\n",
+ "\n",
+ "seed_everything(0)\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=1024,\n",
+ " width=1024,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image)\n",
+ "\n",
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
+ "\n",
+ "\n",
+ "seed_everything(0)\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=1024,\n",
+ " width=1024,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image)\n",
+ "\n",
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=1024,\n",
+ " width=1024,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image)\n",
+ "\n",
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
+ "\n",
+ "seed_everything(0)\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=1024,\n",
+ " width=1024,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(\"subject\", image)\n",
+ "\n",
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=1024,\n",
+ " width=1024,\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.21"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/OminiControl/requirements.txt b/OminiControl/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6d0bc83f27098b2cf17e27de67797092ff6bbee1
--- /dev/null
+++ b/OminiControl/requirements.txt
@@ -0,0 +1,9 @@
+transformers
+diffusers
+peft
+opencv-python
+protobuf
+sentencepiece
+gradio
+jupyter
+torchao
\ No newline at end of file
diff --git a/OminiControl/src/flux/block.py b/OminiControl/src/flux/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c9123abca755a4d66ac537260aaf24df2fd04e6
--- /dev/null
+++ b/OminiControl/src/flux/block.py
@@ -0,0 +1,339 @@
+import torch
+from typing import List, Union, Optional, Dict, Any, Callable
+from diffusers.models.attention_processor import Attention, F
+from diffusers.models.embeddings import apply_rotary_emb
+from .lora_controller import enable_lora
+
+
+def attn_forward(
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ condition_latents: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ cond_rotary_emb: Optional[torch.Tensor] = None,
+ model_config: Optional[Dict[str, Any]] = {},
+) -> torch.FloatTensor:
+ batch_size, _, _ = (
+ hidden_states.shape
+ if encoder_hidden_states is None
+ else encoder_hidden_states.shape
+ )
+
+ with enable_lora(
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
+ ):
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(
+ encoder_hidden_states_query_proj
+ )
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(
+ encoder_hidden_states_key_proj
+ )
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+    if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
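+    # Project the condition latents with the shared to_q/to_k/to_v layers and apply
+    # their own rotary embeddings (if provided) so they can join the attention below.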
+ if condition_latents is not None:
+ cond_query = attn.to_q(condition_latents)
+ cond_key = attn.to_k(condition_latents)
+ cond_value = attn.to_v(condition_latents)
+
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
+ 1, 2
+ )
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
+ 1, 2
+ )
+ if attn.norm_q is not None:
+ cond_query = attn.norm_q(cond_query)
+ if attn.norm_k is not None:
+ cond_key = attn.norm_k(cond_key)
+
+ if cond_rotary_emb is not None:
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
+
+ if condition_latents is not None:
+ query = torch.cat([query, cond_query], dim=2)
+ key = torch.cat([key, cond_key], dim=2)
+ value = torch.cat([value, cond_value], dim=2)
+
+ if not model_config.get("union_cond_attn", True):
+ # If we don't want to use the union condition attention, we need to mask the attention
+ # between the hidden states and the condition latents
+ attention_mask = torch.ones(
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+ )
+ condition_n = cond_query.shape[2]
+ attention_mask[-condition_n:, :-condition_n] = False
+ attention_mask[:-condition_n, -condition_n:] = False
+ elif model_config.get("independent_condition", False):
+ attention_mask = torch.ones(
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+ )
+ condition_n = cond_query.shape[2]
+ attention_mask[-condition_n:, :-condition_n] = False
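+        # If a condition strength factor is attached to this attention module, bias
+        # the image<->condition attention logits by log(c_factor), which scales the
+        # corresponding unnormalized attention weights by c_factor.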
+ if hasattr(attn, "c_factor"):
+ attention_mask = torch.zeros(
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
+ )
+ condition_n = cond_query.shape[2]
+ bias = torch.log(attn.c_factor[0])
+ attention_mask[-condition_n:, :-condition_n] = bias
+ attention_mask[:-condition_n, -condition_n:] = bias
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
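+    # Split the joint attention output back into its text, image, and (optional)
+    # condition segments, in the same order they were concatenated above.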
+ if encoder_hidden_states is not None:
+ if condition_latents is not None:
+ encoder_hidden_states, hidden_states, condition_latents = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
+ ],
+ hidden_states[:, -condition_latents.shape[1] :],
+ )
+ else:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if condition_latents is not None:
+ condition_latents = attn.to_out[0](condition_latents)
+ condition_latents = attn.to_out[1](condition_latents)
+
+ return (
+ (hidden_states, encoder_hidden_states, condition_latents)
+ if condition_latents is not None
+ else (hidden_states, encoder_hidden_states)
+ )
+ elif condition_latents is not None:
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
+ hidden_states, condition_latents = (
+ hidden_states[:, : -condition_latents.shape[1]],
+ hidden_states[:, -condition_latents.shape[1] :],
+ )
+ return hidden_states, condition_latents
+ else:
+ return hidden_states
+
+
+def block_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ condition_latents: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ cond_temb: torch.FloatTensor,
+ cond_rotary_emb=None,
+ image_rotary_emb=None,
+ model_config: Optional[Dict[str, Any]] = {},
+):
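+ # Mirrors the double-stream transformer block forward pass, adding an optional condition stream that reuses the block's own norm, attention and MLP weights.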
+ use_cond = condition_latents is not None
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, emb=temb
+ )
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
+ self.norm1_context(encoder_hidden_states, emb=temb)
+ )
+
+ if use_cond:
+ (
+ norm_condition_latents,
+ cond_gate_msa,
+ cond_shift_mlp,
+ cond_scale_mlp,
+ cond_gate_mlp,
+ ) = self.norm1(condition_latents, emb=cond_temb)
+
+ # Attention.
+ result = attn_forward(
+ self.attn,
+ model_config=model_config,
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ condition_latents=norm_condition_latents if use_cond else None,
+ image_rotary_emb=image_rotary_emb,
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
+ )
+ attn_output, context_attn_output = result[:2]
+ cond_attn_output = result[2] if use_cond else None
+
+ # Process attention outputs for the `hidden_states`.
+ # 1. hidden_states
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+ # 2. encoder_hidden_states
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+ # 3. condition_latents
+ if use_cond:
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
+ condition_latents = condition_latents + cond_attn_output
+ if model_config.get("add_cond_attn", False):
+ hidden_states += cond_attn_output
+
+ # LayerNorm + MLP.
+ # 1. hidden_states
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = (
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ )
+ # 2. encoder_hidden_states
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = (
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ )
+ # 3. condition_latents
+ if use_cond:
+ norm_condition_latents = self.norm2(condition_latents)
+ norm_condition_latents = (
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
+ + cond_shift_mlp[:, None]
+ )
+
+ # Feed-forward.
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
+ # 1. hidden_states
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ # 2. encoder_hidden_states
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
+ # 3. condition_latents
+ if use_cond:
+ cond_ff_output = self.ff(norm_condition_latents)
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
+
+ # Process feed-forward outputs.
+ hidden_states = hidden_states + ff_output
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
+ if use_cond:
+ condition_latents = condition_latents + cond_ff_output
+
+ # Clip to avoid overflow.
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states, (condition_latents if use_cond else None)
+
+
+def single_block_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ image_rotary_emb=None,
+ condition_latents: torch.FloatTensor = None,
+ cond_temb: torch.FloatTensor = None,
+ cond_rotary_emb=None,
+ model_config: Optional[Dict[str, Any]] = {},
+):
+
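+ # Mirrors the single-stream transformer block forward pass; the condition stream shares the block's norm, MLP and projection weights.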
+ using_cond = condition_latents is not None
+ residual = hidden_states
+ with enable_lora(
+ (
+ self.norm.linear,
+ self.proj_mlp,
+ ),
+ model_config.get("latent_lora", False),
+ ):
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ if using_cond:
+ residual_cond = condition_latents
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
+
+ attn_output = attn_forward(
+ self.attn,
+ model_config=model_config,
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **(
+ {
+ "condition_latents": norm_condition_latents,
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
+ }
+ if using_cond
+ else {}
+ ),
+ )
+ if using_cond:
+ attn_output, cond_attn_output = attn_output
+
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if using_cond:
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
+ cond_gate = cond_gate.unsqueeze(1)
+ condition_latents = cond_gate * self.proj_out(condition_latents)
+ condition_latents = residual_cond + condition_latents
+
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
diff --git a/OminiControl/src/flux/condition.py b/OminiControl/src/flux/condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2dcd36f817f1285f8e9060f4ea9583742e04e8e
--- /dev/null
+++ b/OminiControl/src/flux/condition.py
@@ -0,0 +1,138 @@
+import torch
+from typing import Optional, Union, List, Tuple
+from diffusers.pipelines import FluxPipeline
+from PIL import Image, ImageFilter
+import numpy as np
+import cv2
+
+from .pipeline_tools import encode_images
+
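+# Maps each supported condition type to its integer type id (see Condition.type_id).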
+condition_dict = {
+ "depth": 0,
+ "canny": 1,
+ "subject": 4,
+ "coloring": 6,
+ "deblurring": 7,
+ "depth_pred": 8,
+ "fill": 9,
+ "sr": 10,
+ "cartoon": 11,
+}
+
+
+class Condition(object):
+ def __init__(
+ self,
+ condition_type: str,
+ raw_img: Union[Image.Image, torch.Tensor] = None,
+ condition: Union[Image.Image, torch.Tensor] = None,
+ mask=None,
+ position_delta=None,
+ position_scale=1.0,
+ ) -> None:
+ self.condition_type = condition_type
+ assert raw_img is not None or condition is not None
+ if raw_img is not None:
+ self.condition = self.get_condition(condition_type, raw_img)
+ else:
+ self.condition = condition
+ self.position_delta = position_delta
+ self.position_scale = position_scale
+ # TODO: Add mask support
+ assert mask is None, "Mask not supported yet"
+
+ def get_condition(
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
+ ) -> Union[Image.Image, torch.Tensor]:
+ """
+ Returns the condition image.
+ """
+ if condition_type == "depth":
+ from transformers import pipeline
+
+ depth_pipe = pipeline(
+ task="depth-estimation",
+ model="LiheYoung/depth-anything-small-hf",
+ device="cuda",
+ )
+ source_image = raw_img.convert("RGB")
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
+ return condition_img
+ elif condition_type == "canny":
+ img = np.array(raw_img)
+ edges = cv2.Canny(img, 100, 200)
+ edges = Image.fromarray(edges).convert("RGB")
+ return edges
+ elif condition_type == "subject":
+ return raw_img
+ elif condition_type == "coloring":
+ return raw_img.convert("L").convert("RGB")
+ elif condition_type == "deblurring":
+ condition_image = (
+ raw_img.convert("RGB")
+ .filter(ImageFilter.GaussianBlur(10))
+ .convert("RGB")
+ )
+ return condition_image
+ elif condition_type == "fill":
+ return raw_img.convert("RGB")
+ elif condition_type == "cartoon":
+ return raw_img.convert("RGB")
+ return self.condition
+
+ @property
+ def type_id(self) -> int:
+ """
+ Returns the type id of the condition.
+ """
+ return condition_dict[self.condition_type]
+
+ @classmethod
+ def get_type_id(cls, condition_type: str) -> int:
+ """
+ Returns the type id of the condition.
+ """
+ return condition_dict[condition_type]
+
+ def encode(
+ self, pipe: FluxPipeline, empty: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Encodes the condition into tokens, ids and type_id.
+ """
+ if self.condition_type in [
+ "depth",
+ "canny",
+ "subject",
+ "coloring",
+ "deblurring",
+ "depth_pred",
+ "fill",
+ "sr",
+ "cartoon",
+ ]:
+ if empty:
+ # make the condition black
+ e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
+ e_condition = e_condition.convert("RGB")
+ tokens, ids = encode_images(pipe, e_condition)
+ else:
+ tokens, ids = encode_images(pipe, self.condition)
+ else:
+ raise NotImplementedError(
+ f"Condition type {self.condition_type} not implemented"
+ )
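+ # Default subject offset: the condition token grid spans width // 16 positions (VAE stride 8 x 2x2 latent packing), so shifting the column ids by -width // 16 places the condition beside, not on top of, the target image.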
+ if self.position_delta is None and self.condition_type == "subject":
+ self.position_delta = [0, -self.condition.size[0] // 16]
+ if self.position_delta is not None:
+ ids[:, 1] += self.position_delta[0]
+ ids[:, 2] += self.position_delta[1]
+ if self.position_scale != 1.0:
+ scale_bias = (self.position_scale - 1.0) / 2
+ ids[:, 1] *= self.position_scale
+ ids[:, 2] *= self.position_scale
+ ids[:, 1] += scale_bias
+ ids[:, 2] += scale_bias
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
+ return tokens, ids, type_id
diff --git a/OminiControl/src/flux/generate.py b/OminiControl/src/flux/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..a380633d740a5dbad731d0e154f26f16e4b5bc13
--- /dev/null
+++ b/OminiControl/src/flux/generate.py
@@ -0,0 +1,321 @@
+import torch
+import yaml, os
+from diffusers.pipelines import FluxPipeline
+from typing import List, Union, Optional, Dict, Any, Callable
+from .transformer import tranformer_forward
+from .condition import Condition
+
+from diffusers.pipelines.flux.pipeline_flux import (
+ FluxPipelineOutput,
+ calculate_shift,
+ retrieve_timesteps,
+ np,
+)
+
+
+def get_config(config_path: str = None):
+ config_path = config_path or os.environ.get("XFL_CONFIG")
+ if not config_path:
+ return {}
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ return config
+
+
+def prepare_params(
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ **kwargs: dict,
+):
+ return (
+ prompt,
+ prompt_2,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ guidance_scale,
+ num_images_per_prompt,
+ generator,
+ latents,
+ prompt_embeds,
+ pooled_prompt_embeds,
+ output_type,
+ return_dict,
+ joint_attention_kwargs,
+ callback_on_step_end,
+ callback_on_step_end_tensor_inputs,
+ max_sequence_length,
+ )
+
+
+def seed_everything(seed: int = 42):
+ torch.backends.cudnn.deterministic = True
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+
+@torch.no_grad()
+def generate(
+ pipeline: FluxPipeline,
+ conditions: List[Condition] = None,
+ config_path: str = None,
+ model_config: Optional[Dict[str, Any]] = {},
+ condition_scale: float = 1.0,
+ default_lora: bool = False,
+ image_guidance_scale: float = 1.0,
+ **params: dict,
+):
+ model_config = model_config or get_config(config_path).get("model", {})
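+ # A non-unit condition_scale attaches a temporary `c_factor` to every attention module; attn_forward turns it into a log-bias on the image/text <-> condition attention (removed again after sampling).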
+ if condition_scale != 1:
+ for name, module in pipeline.transformer.named_modules():
+ if not name.endswith(".attn"):
+ continue
+ module.c_factor = torch.ones(1, 1) * condition_scale
+
+ self = pipeline
+ (
+ prompt,
+ prompt_2,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ guidance_scale,
+ num_images_per_prompt,
+ generator,
+ latents,
+ prompt_embeds,
+ pooled_prompt_embeds,
+ output_type,
+ return_dict,
+ joint_attention_kwargs,
+ callback_on_step_end,
+ callback_on_step_end_tensor_inputs,
+ max_sequence_length,
+ ) = prepare_params(**params)
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None)
+ if self.joint_attention_kwargs is not None
+ else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 4.1. Prepare conditions
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
+ use_condition = conditions is not None and len(conditions) > 0
+ if use_condition:
+ assert len(conditions) <= 1, "Only one condition is supported for now."
+ if not default_lora:
+ pipeline.set_adapters(conditions[0].condition_type)
+ for condition in conditions:
+ tokens, ids, type_id = condition.encode(self)
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
+ condition_ids.append(ids) # [token_n, id_dim(3)]
+ condition_type_ids.append(type_id) # [token_n, 1]
+ condition_latents = torch.cat(condition_latents, dim=1)
+ condition_ids = torch.cat(condition_ids, dim=0)
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+ noise_pred = tranformer_forward(
+ self.transformer,
+ model_config=model_config,
+ # Inputs of the condition (new feature)
+ condition_latents=condition_latents if use_condition else None,
+ condition_ids=condition_ids if use_condition else None,
+ condition_type_ids=condition_type_ids if use_condition else None,
+ # Inputs to the original transformer
+ hidden_states=latents,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
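+ # Image (condition) guidance: run a second pass with an empty, black condition and extrapolate between the two predictions.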
+ if image_guidance_scale != 1.0:
+ uncondition_latents = condition.encode(self, empty=True)[0]
+ unc_pred = tranformer_forward(
+ self.transformer,
+ model_config=model_config,
+ # Inputs of the condition (new feature)
+ condition_latents=uncondition_latents if use_condition else None,
+ condition_ids=condition_ids if use_condition else None,
+ condition_type_ids=condition_type_ids if use_condition else None,
+ # Inputs to the original transformer
+ hidden_states=latents,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
+ timestep=timestep / 1000,
+ guidance=torch.ones_like(guidance),
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (
+ latents / self.vae.config.scaling_factor
+ ) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if condition_scale != 1:
+ for name, module in pipeline.transformer.named_modules():
+ if not name.endswith(".attn"):
+ continue
+ del module.c_factor
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/OminiControl/src/flux/lora_controller.py b/OminiControl/src/flux/lora_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b23eae2bdeaab171df4616e174dd6d96351620
--- /dev/null
+++ b/OminiControl/src/flux/lora_controller.py
@@ -0,0 +1,75 @@
+from peft.tuners.tuners_utils import BaseTunerLayer
+from typing import List, Any, Optional, Type
+
+
+class enable_lora:
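+ # Context manager: when `activated` is False, temporarily zeroes the LoRA scaling of the given modules and restores the original scales on exit; when True it is a no-op.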
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
+ self.activated: bool = activated
+ if activated:
+ return
+ self.lora_modules: List[BaseTunerLayer] = [
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
+ ]
+ self.scales = [
+ {
+ active_adapter: lora_module.scaling[active_adapter]
+ for active_adapter in lora_module.active_adapters
+ }
+ for lora_module in self.lora_modules
+ ]
+
+ def __enter__(self) -> None:
+ if self.activated:
+ return
+
+ for lora_module in self.lora_modules:
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ lora_module.scale_layer(0)
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[Any],
+ ) -> None:
+ if self.activated:
+ return
+ for i, lora_module in enumerate(self.lora_modules):
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ for active_adapter in lora_module.active_adapters:
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
+
+
+class set_lora_scale:
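+ # Context manager: sets the LoRA scale of the given modules on entry and restores the original per-adapter scaling on exit.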
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
+ self.lora_modules: List[BaseTunerLayer] = [
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
+ ]
+ self.scales = [
+ {
+ active_adapter: lora_module.scaling[active_adapter]
+ for active_adapter in lora_module.active_adapters
+ }
+ for lora_module in self.lora_modules
+ ]
+ self.scale = scale
+
+ def __enter__(self) -> None:
+ for lora_module in self.lora_modules:
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ lora_module.scale_layer(self.scale)
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[Any],
+ ) -> None:
+ for i, lora_module in enumerate(self.lora_modules):
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ for active_adapter in lora_module.active_adapters:
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
diff --git a/OminiControl/src/flux/pipeline_tools.py b/OminiControl/src/flux/pipeline_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..36174e4fe27dd12dbe629e3b92ca160b3600d889
--- /dev/null
+++ b/OminiControl/src/flux/pipeline_tools.py
@@ -0,0 +1,52 @@
+from diffusers.pipelines import FluxPipeline
+from diffusers.utils import logging
+from diffusers.pipelines.flux.pipeline_flux import logger
+from torch import Tensor
+
+
+def encode_images(pipeline: FluxPipeline, images: Tensor):
+ images = pipeline.image_processor.preprocess(images)
+ images = images.to(pipeline.device).to(pipeline.dtype)
+ images = pipeline.vae.encode(images).latent_dist.sample()
+ images = (
+ images - pipeline.vae.config.shift_factor
+ ) * pipeline.vae.config.scaling_factor
+ images_tokens = pipeline._pack_latents(images, *images.shape)
+ images_ids = pipeline._prepare_latent_image_ids(
+ images.shape[0],
+ images.shape[2],
+ images.shape[3],
+ pipeline.device,
+ pipeline.dtype,
+ )
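+ # Fallback for diffusers versions that build latent ids at the packed (2x2) resolution: rebuild the ids when their count does not match the packed tokens.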
+ if images_tokens.shape[1] != images_ids.shape[0]:
+ images_ids = pipeline._prepare_latent_image_ids(
+ images.shape[0],
+ images.shape[2] // 2,
+ images.shape[3] // 2,
+ pipeline.device,
+ pipeline.dtype,
+ )
+ return images_tokens, images_ids
+
+
+def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
+ # Turn off warnings (CLIP overflow)
+ logger.setLevel(logging.ERROR)
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = pipeline.encode_prompt(
+ prompt=prompts,
+ prompt_2=None,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ device=pipeline.device,
+ num_images_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ lora_scale=None,
+ )
+ # Turn on warnings
+ logger.setLevel(logging.WARNING)
+ return prompt_embeds, pooled_prompt_embeds, text_ids
diff --git a/OminiControl/src/flux/transformer.py b/OminiControl/src/flux/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a31e14bf9a72f1b42775419ad251ee5bb65e764
--- /dev/null
+++ b/OminiControl/src/flux/transformer.py
@@ -0,0 +1,252 @@
+import torch
+from diffusers.pipelines import FluxPipeline
+from typing import List, Union, Optional, Dict, Any, Callable
+from .block import block_forward, single_block_forward
+from .lora_controller import enable_lora
+from accelerate.utils import is_torch_version
+from diffusers.models.transformers.transformer_flux import (
+ FluxTransformer2DModel,
+ Transformer2DModelOutput,
+ USE_PEFT_BACKEND,
+ scale_lora_layers,
+ unscale_lora_layers,
+ logger,
+)
+import numpy as np
+
+
+def prepare_params(
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ **kwargs: dict,
+):
+ return (
+ hidden_states,
+ encoder_hidden_states,
+ pooled_projections,
+ timestep,
+ img_ids,
+ txt_ids,
+ guidance,
+ joint_attention_kwargs,
+ controlnet_block_samples,
+ controlnet_single_block_samples,
+ return_dict,
+ )
+
+
+def tranformer_forward(
+ transformer: FluxTransformer2DModel,
+ condition_latents: torch.Tensor,
+ condition_ids: torch.Tensor,
+ condition_type_ids: torch.Tensor,
+ model_config: Optional[Dict[str, Any]] = {},
+ c_t=0,
+ **params: dict,
+):
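+ # Wraps FluxTransformer2DModel.forward, threading the extra condition stream (condition_latents / condition_ids, with its own timestep c_t) through every transformer block.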
+ self = transformer
+ use_condition = condition_latents is not None
+
+ (
+ hidden_states,
+ encoder_hidden_states,
+ pooled_projections,
+ timestep,
+ img_ids,
+ txt_ids,
+ guidance,
+ joint_attention_kwargs,
+ controlnet_block_samples,
+ controlnet_single_block_samples,
+ return_dict,
+ ) = prepare_params(**params)
+
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if (
+ joint_attention_kwargs is not None
+ and joint_attention_kwargs.get("scale", None) is not None
+ ):
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
+ hidden_states = self.x_embedder(hidden_states)
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+
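+ # The condition stream gets its own timestep embedding at t = c_t (0 by default, i.e. a clean condition image), with guidance fixed to 1.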
+ cond_temb = (
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
+ )
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+ if use_condition:
+ # condition_ids[:, :1] = condition_type_ids
+ cond_rotary_emb = self.pos_embed(condition_ids)
+
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = (
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ )
+ encoder_hidden_states, hidden_states, condition_latents = (
+ torch.utils.checkpoint.checkpoint(
+ block_forward,
+ self=block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ condition_latents=condition_latents if use_condition else None,
+ temb=temb,
+ cond_temb=cond_temb if use_condition else None,
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
+ image_rotary_emb=image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ )
+
+ else:
+ encoder_hidden_states, hidden_states, condition_latents = block_forward(
+ block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ condition_latents=condition_latents if use_condition else None,
+ temb=temb,
+ cond_temb=cond_temb if use_condition else None,
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(
+ controlnet_block_samples
+ )
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = (
+ hidden_states
+ + controlnet_block_samples[index_block // interval_control]
+ )
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = (
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ )
+ result = torch.utils.checkpoint.checkpoint(
+ single_block_forward,
+ self=block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ **(
+ {
+ "condition_latents": condition_latents,
+ "cond_temb": cond_temb,
+ "cond_rotary_emb": cond_rotary_emb,
+ }
+ if use_condition
+ else {}
+ ),
+ **ckpt_kwargs,
+ )
+
+ else:
+ result = single_block_forward(
+ block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ **(
+ {
+ "condition_latents": condition_latents,
+ "cond_temb": cond_temb,
+ "cond_rotary_emb": cond_rotary_emb,
+ }
+ if use_condition
+ else {}
+ ),
+ )
+ if use_condition:
+ hidden_states, condition_latents = result
+ else:
+ hidden_states = result
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(
+ controlnet_single_block_samples
+ )
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/OminiControl/src/gradio/gradio_app.py b/OminiControl/src/gradio/gradio_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8190be9c0cd1985e664761a40aad5569dc19dc9
--- /dev/null
+++ b/OminiControl/src/gradio/gradio_app.py
@@ -0,0 +1,115 @@
+import gradio as gr
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from diffusers.pipelines import FluxPipeline
+from diffusers import FluxTransformer2DModel
+import numpy as np
+
+from ..flux.condition import Condition
+from ..flux.generate import seed_everything, generate
+
+pipe = None
+use_int8 = False
+
+
+def get_gpu_memory():
+ return torch.cuda.get_device_properties(0).total_memory / 1024**3
+
+
+def init_pipeline():
+ global pipe
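+ # Fall back to the int8-quantized transformer when requested or when the GPU has less than 33 GB of memory.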
+ if use_int8 or get_gpu_memory() < 33:
+ transformer_model = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/flux.1-schell-int8wo-improved",
+ torch_dtype=torch.bfloat16,
+ use_safetensors=False,
+ )
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ transformer=transformer_model,
+ torch_dtype=torch.bfloat16,
+ )
+ else:
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+ )
+ pipe = pipe.to("cuda")
+ pipe.load_lora_weights(
+ "Yuanshi/OminiControl",
+ weight_name="omini/subject_512.safetensors",
+ adapter_name="subject",
+ )
+
+
+def process_image_and_text(image, text):
+ # center crop image
+ w, h, min_size = image.size[0], image.size[1], min(image.size)
+ image = image.crop(
+ (
+ (w - min_size) // 2,
+ (h - min_size) // 2,
+ (w + min_size) // 2,
+ (h + min_size) // 2,
+ )
+ )
+ image = image.resize((512, 512))
+
+ condition = Condition("subject", image, position_delta=(0, 32))
+
+ if pipe is None:
+ init_pipeline()
+
+ result_img = generate(
+ pipe,
+ prompt=text.strip(),
+ conditions=[condition],
+ num_inference_steps=8,
+ height=512,
+ width=512,
+ ).images[0]
+
+ return result_img
+
+
+def get_samples():
+ sample_list = [
+ {
+ "image": "assets/oranges.jpg",
+ "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
+ },
+ {
+ "image": "assets/penguin.jpg",
+ "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
+ },
+ {
+ "image": "assets/rc_car.jpg",
+ "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
+ },
+ {
+ "image": "assets/clock.jpg",
+ "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
+ },
+ {
+ "image": "assets/tshirt.jpg",
+ "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.",
+ },
+ ]
+ return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
+
+
+demo = gr.Interface(
+ fn=process_image_and_text,
+ inputs=[
+ gr.Image(type="pil"),
+ gr.Textbox(lines=2),
+ ],
+ outputs=gr.Image(type="pil"),
+ title="OminiControl / Subject driven generation",
+ examples=get_samples(),
+)
+
+if __name__ == "__main__":
+ init_pipeline()
+ demo.launch(
+ debug=True,
+ )
diff --git a/OminiControl/src/train/callbacks.py b/OminiControl/src/train/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ccc7488ce1f1d4cf60707eb8ea5cfe9659c7132
--- /dev/null
+++ b/OminiControl/src/train/callbacks.py
@@ -0,0 +1,253 @@
+import lightning as L
+from PIL import Image, ImageFilter, ImageDraw
+import numpy as np
+from transformers import pipeline
+import cv2
+import torch
+import os
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from ..flux.condition import Condition
+from ..flux.generate import generate
+
+
+class TrainingCallback(L.Callback):
+ def __init__(self, run_name, training_config: dict = {}):
+ self.run_name, self.training_config = run_name, training_config
+
+ self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
+ self.save_interval = training_config.get("save_interval", 1000)
+ self.sample_interval = training_config.get("sample_interval", 1000)
+ self.save_path = training_config.get("save_path", "./output")
+
+ self.wandb_config = training_config.get("wandb", None)
+ self.use_wandb = (
+ wandb is not None and os.environ.get("WANDB_API_KEY") is not None
+ )
+
+ self.total_steps = 0
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
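+ # Track the mean and max L2 gradient norm over all parameters for logging.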
+ gradient_size = 0
+ max_gradient_size = 0
+ count = 0
+ for _, param in pl_module.named_parameters():
+ if param.grad is not None:
+ gradient_size += param.grad.norm(2).item()
+ max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
+ count += 1
+ if count > 0:
+ gradient_size /= count
+
+ self.total_steps += 1
+
+ # Print training progress every n steps
+ if self.use_wandb:
+ report_dict = {
+ "steps": batch_idx,
+ "steps": self.total_steps,
+ "epoch": trainer.current_epoch,
+ "gradient_size": gradient_size,
+ }
+ loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
+ report_dict["loss"] = loss_value
+ report_dict["t"] = pl_module.last_t
+ wandb.log(report_dict)
+
+ if self.total_steps % self.print_every_n_steps == 0:
+ print(
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
+ )
+
+ # Save LoRA weights at specified intervals
+ if self.total_steps % self.save_interval == 0:
+ print(
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
+ )
+ pl_module.save_lora(
+ f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
+ )
+
+ # Generate and save a sample image at specified intervals
+ if self.total_steps % self.sample_interval == 0:
+ print(
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
+ )
+ self.generate_a_sample(
+ trainer,
+ pl_module,
+ f"{self.save_path}/{self.run_name}/output",
+ f"lora_{self.total_steps}",
+ batch["condition_type"][
+ 0
+ ], # Use the condition type from the current batch
+ )
+
+ @torch.no_grad()
+ def generate_a_sample(
+ self,
+ trainer,
+ pl_module,
+ save_path,
+ file_name,
+ condition_type="super_resolution",
+ ):
+ # TODO: change these variables to parameters
+ condition_size = trainer.training_config["dataset"]["condition_size"]
+ target_size = trainer.training_config["dataset"]["target_size"]
+ position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)
+
+ generator = torch.Generator(device=pl_module.device)
+ generator.manual_seed(42)
+
+ test_list = []
+
+ if condition_type == "subject":
+ test_list.extend(
+ [
+ (
+ Image.open("assets/test_in.jpg"),
+ [0, -32],
+ "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
+ ),
+ (
+ Image.open("assets/test_out.jpg"),
+ [0, -32],
+ "In a bright room. It is placed on a table.",
+ ),
+ ]
+ )
+ elif condition_type == "canny":
+ condition_img = Image.open("assets/vase_hq.jpg").resize(
+ (condition_size, condition_size)
+ )
+ condition_img = np.array(condition_img)
+ condition_img = cv2.Canny(condition_img, 100, 200)
+ condition_img = Image.fromarray(condition_img).convert("RGB")
+ test_list.append(
+ (
+ condition_img,
+ [0, 0],
+ "A beautiful vase on a table.",
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
+ )
+ )
+ elif condition_type == "coloring":
+ condition_img = (
+ Image.open("assets/vase_hq.jpg")
+ .resize((condition_size, condition_size))
+ .convert("L")
+ .convert("RGB")
+ )
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
+ elif condition_type == "depth":
+ if not hasattr(self, "deepth_pipe"):
+ self.deepth_pipe = pipeline(
+ task="depth-estimation",
+ model="LiheYoung/depth-anything-small-hf",
+ device="cpu",
+ )
+ condition_img = (
+ Image.open("assets/vase_hq.jpg")
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ condition_img = self.depth_pipe(condition_img)["depth"].convert("RGB")
+ test_list.append(
+ (
+ condition_img,
+ [0, 0],
+ "A beautiful vase on a table.",
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
+ )
+ )
+ elif condition_type == "depth_pred":
+ condition_img = (
+ Image.open("assets/vase_hq.jpg")
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
+ elif condition_type == "deblurring":
+ blur_radius = 5
+ image = Image.open("./assets/vase_hq.jpg")
+ condition_img = (
+ image.convert("RGB")
+ .resize((condition_size, condition_size))
+ .filter(ImageFilter.GaussianBlur(blur_radius))
+ .convert("RGB")
+ )
+ test_list.append(
+ (
+ condition_img,
+ [0, 0],
+ "A beautiful vase on a table.",
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
+ )
+ )
+ elif condition_type == "fill":
+ condition_img = (
+ Image.open("./assets/vase_hq.jpg")
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ mask = Image.new("L", condition_img.size, 0)
+ draw = ImageDraw.Draw(mask)
+ a = condition_img.size[0] // 4
+ b = a * 3
+ draw.rectangle([a, a, b, b], fill=255)
+ condition_img = Image.composite(
+ condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
+ )
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
+ elif condition_type == "sr":
+ condition_img = (
+ Image.open("assets/vase_hq.jpg")
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
+ elif condition_type == "cartoon":
+ condition_img = (
+ Image.open("assets/cartoon_boy.png")
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ test_list.append(
+ (
+ condition_img,
+ [0, -16],
+ "A cartoon character in a white background. He is looking right, and running.",
+ )
+ )
+ else:
+ raise NotImplementedError
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list):
+ condition = Condition(
+ condition_type=condition_type,
+ condition=condition_img.resize(
+ (condition_size, condition_size)
+ ).convert("RGB"),
+ position_delta=position_delta,
+ **(others[0] if others else {}),
+ )
+ res = generate(
+ pl_module.flux_pipe,
+ prompt=prompt,
+ conditions=[condition],
+ height=target_size,
+ width=target_size,
+ generator=generator,
+ model_config=pl_module.model_config,
+ default_lora=True,
+ )
+ res.images[0].save(
+ os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
+ )
diff --git a/OminiControl/src/train/data.py b/OminiControl/src/train/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee3c315146d387f8ee34db58b29867885d558bed
--- /dev/null
+++ b/OminiControl/src/train/data.py
@@ -0,0 +1,323 @@
+from PIL import Image, ImageFilter, ImageDraw
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+import torchvision.transforms as T
+import random
+
+
+class Subject200KDataset(Dataset):
+ def __init__(
+ self,
+ base_dataset,
+ condition_size: int = 512,
+ target_size: int = 512,
+ image_size: int = 512,
+ padding: int = 0,
+ condition_type: str = "subject",
+ drop_text_prob: float = 0.1,
+ drop_image_prob: float = 0.1,
+ return_pil_image: bool = False,
+ ):
+ self.base_dataset = base_dataset
+ self.condition_size = condition_size
+ self.target_size = target_size
+ self.image_size = image_size
+ self.padding = padding
+ self.condition_type = condition_type
+ self.drop_text_prob = drop_text_prob
+ self.drop_image_prob = drop_image_prob
+ self.return_pil_image = return_pil_image
+
+ self.to_tensor = T.ToTensor()
+
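+ # Each base item yields two training pairs: left image as target / right as condition, and vice versa.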
+ def __len__(self):
+ return len(self.base_dataset) * 2
+
+ def __getitem__(self, idx):
+ # If target is 0, left image is target, right image is condition
+ target = idx % 2
+ item = self.base_dataset[idx // 2]
+
+ # Crop the image to target and condition
+ image = item["image"]
+ left_img = image.crop(
+ (
+ self.padding,
+ self.padding,
+ self.image_size + self.padding,
+ self.image_size + self.padding,
+ )
+ )
+ right_img = image.crop(
+ (
+ self.image_size + self.padding * 2,
+ self.padding,
+ self.image_size * 2 + self.padding * 2,
+ self.image_size + self.padding,
+ )
+ )
+
+ # Get the target and condition image
+ target_image, condition_img = (
+ (left_img, right_img) if target == 0 else (right_img, left_img)
+ )
+
+ # Resize the image
+ condition_img = condition_img.resize(
+ (self.condition_size, self.condition_size)
+ ).convert("RGB")
+ target_image = target_image.resize(
+ (self.target_size, self.target_size)
+ ).convert("RGB")
+
+ # Get the description
+ description = item["description"][
+ "description_0" if target == 0 else "description_1"
+ ]
+
+ # Randomly drop text or image
+ drop_text = random.random() < self.drop_text_prob
+ drop_image = random.random() < self.drop_image_prob
+ if drop_text:
+ description = ""
+ if drop_image:
+ condition_img = Image.new(
+ "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
+ )
+
+ return {
+ "image": self.to_tensor(target_image),
+ "condition": self.to_tensor(condition_img),
+ "condition_type": self.condition_type,
+ "description": description,
+ # 16 is the downscale factor of the image
+ "position_delta": np.array([0, -self.condition_size // 16]),
+ **({"pil_image": image} if self.return_pil_image else {}),
+ }
+
+
+class ImageConditionDataset(Dataset):
+ def __init__(
+ self,
+ base_dataset,
+ condition_size: int = 512,
+ target_size: int = 512,
+ condition_type: str = "canny",
+ drop_text_prob: float = 0.1,
+ drop_image_prob: float = 0.1,
+ return_pil_image: bool = False,
+ position_scale=1.0,
+ ):
+ self.base_dataset = base_dataset
+ self.condition_size = condition_size
+ self.target_size = target_size
+ self.condition_type = condition_type
+ self.drop_text_prob = drop_text_prob
+ self.drop_image_prob = drop_image_prob
+ self.return_pil_image = return_pil_image
+ self.position_scale = position_scale
+
+ self.to_tensor = T.ToTensor()
+
+ def __len__(self):
+ return len(self.base_dataset)
+
+ @property
+ def depth_pipe(self):
+ if not hasattr(self, "_depth_pipe"):
+ from transformers import pipeline
+
+ self._depth_pipe = pipeline(
+ task="depth-estimation",
+ model="LiheYoung/depth-anything-small-hf",
+ device="cpu",
+ )
+ return self._depth_pipe
+
+ def _get_canny_edge(self, img):
+ resize_ratio = self.condition_size / max(img.size)
+ img = img.resize(
+ (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
+ )
+ img_np = np.array(img)
+ img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
+ edges = cv2.Canny(img_gray, 100, 200)
+ return Image.fromarray(edges).convert("RGB")
+
+ def __getitem__(self, idx):
+ image = self.base_dataset[idx]["jpg"]
+ image = image.resize((self.target_size, self.target_size)).convert("RGB")
+ description = self.base_dataset[idx]["json"]["prompt"]
+
+ enable_scale = random.random() < 1
+ if not enable_scale:
+ condition_size = int(self.condition_size * self.position_scale)
+ position_scale = 1.0
+ else:
+ condition_size = self.condition_size
+ position_scale = self.position_scale
+
+ # Get the condition image
+ position_delta = np.array([0, 0])
+ if self.condition_type == "canny":
+ condition_img = self._get_canny_edge(image)
+ elif self.condition_type == "coloring":
+ condition_img = (
+ image.resize((condition_size, condition_size))
+ .convert("L")
+ .convert("RGB")
+ )
+ elif self.condition_type == "deblurring":
+ blur_radius = random.randint(1, 10)
+ condition_img = (
+ image.convert("RGB")
+ .filter(ImageFilter.GaussianBlur(blur_radius))
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ elif self.condition_type == "depth":
+ condition_img = self.depth_pipe(image)["depth"].convert("RGB")
+ condition_img = condition_img.resize((condition_size, condition_size))
+ elif self.condition_type == "depth_pred":
+ condition_img = image
+ image = self.depth_pipe(condition_img)["depth"].convert("RGB")
+ description = f"[depth] {description}"
+ elif self.condition_type == "fill":
+ condition_img = image.resize((condition_size, condition_size)).convert(
+ "RGB"
+ )
+ w, h = image.size
+ x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
+ y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
+ mask = Image.new("L", image.size, 0)
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle([x1, y1, x2, y2], fill=255)
+ if random.random() > 0.5:
+ mask = Image.eval(mask, lambda a: 255 - a)
+ condition_img = Image.composite(
+ image, Image.new("RGB", image.size, (0, 0, 0)), mask
+ )
+ elif self.condition_type == "sr":
+ condition_img = image.resize((condition_size, condition_size)).convert(
+ "RGB"
+ )
+ position_delta = np.array([0, -condition_size // 16])
+
+ else:
+ raise ValueError(f"Condition type {self.condition_type} not implemented")
+
+ # Randomly drop text or image
+ drop_text = random.random() < self.drop_text_prob
+ drop_image = random.random() < self.drop_image_prob
+ if drop_text:
+ description = ""
+ if drop_image:
+ condition_img = Image.new(
+ "RGB", (condition_size, condition_size), (0, 0, 0)
+ )
+
+ return {
+ "image": self.to_tensor(image),
+ "condition": self.to_tensor(condition_img),
+ "condition_type": self.condition_type,
+ "description": description,
+ "position_delta": position_delta,
+ **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
+ **({"position_scale": position_scale} if position_scale != 1.0 else {}),
+ }
+
+
+class CartoonDataset(Dataset):
+ def __init__(
+ self,
+ base_dataset,
+ condition_size: int = 1024,
+ target_size: int = 1024,
+ image_size: int = 1024,
+ padding: int = 0,
+ condition_type: str = "cartoon",
+ drop_text_prob: float = 0.1,
+ drop_image_prob: float = 0.1,
+ return_pil_image: bool = False,
+ ):
+ self.base_dataset = base_dataset
+ self.condition_size = condition_size
+ self.target_size = target_size
+ self.image_size = image_size
+ self.padding = padding
+ self.condition_type = condition_type
+ self.drop_text_prob = drop_text_prob
+ self.drop_image_prob = drop_image_prob
+ self.return_pil_image = return_pil_image
+
+ self.to_tensor = T.ToTensor()
+
+ def __len__(self):
+ return len(self.base_dataset)
+
+ def __getitem__(self, idx):
+ data = self.base_dataset[idx]
+ condition_img = data["condition"]
+ target_image = data["target"]
+
+ # Tag
+ tag = data["tags"][0]
+
+ target_description = data["target_description"]
+
+ description = {
+ "lion": "lion like animal",
+ "bear": "bear like animal",
+ "gorilla": "gorilla like animal",
+ "dog": "dog like animal",
+ "elephant": "elephant like animal",
+ "eagle": "eagle like bird",
+ "tiger": "tiger like animal",
+ "owl": "owl like bird",
+ "woman": "woman",
+ "parrot": "parrot like bird",
+ "mouse": "mouse like animal",
+ "man": "man",
+ "pigeon": "pigeon like bird",
+ "girl": "girl",
+ "panda": "panda like animal",
+ "crocodile": "crocodile like animal",
+ "rabbit": "rabbit like animal",
+ "boy": "boy",
+ "monkey": "monkey like animal",
+ "cat": "cat like animal",
+ }
+
+ # Resize the image
+ condition_img = condition_img.resize(
+ (self.condition_size, self.condition_size)
+ ).convert("RGB")
+ target_image = target_image.resize(
+ (self.target_size, self.target_size)
+ ).convert("RGB")
+
+ # Process datum to create description
+ description = data.get(
+ "description",
+ f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.",
+ )
+
+ # Randomly drop text or image
+ drop_text = random.random() < self.drop_text_prob
+ drop_image = random.random() < self.drop_image_prob
+ if drop_text:
+ description = ""
+ if drop_image:
+ condition_img = Image.new(
+ "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
+ )
+
+ return {
+ "image": self.to_tensor(target_image),
+ "condition": self.to_tensor(condition_img),
+ "condition_type": self.condition_type,
+ "description": description,
+ # 16 is the downscale factor of the image
+ "position_delta": np.array([0, -16]),
+ }
diff --git a/OminiControl/src/train/model.py b/OminiControl/src/train/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b143858d2e90cbb11e5033766febafd49ea7697a
--- /dev/null
+++ b/OminiControl/src/train/model.py
@@ -0,0 +1,185 @@
+import lightning as L
+from diffusers.pipelines import FluxPipeline
+import torch
+from peft import LoraConfig, get_peft_model_state_dict
+
+import prodigyopt
+
+from ..flux.transformer import tranformer_forward
+from ..flux.condition import Condition
+from ..flux.pipeline_tools import encode_images, prepare_text_input
+
+
+class OminiModel(L.LightningModule):
+ def __init__(
+ self,
+ flux_pipe_id: str,
+ lora_path: str = None,
+ lora_config: dict = None,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.bfloat16,
+ model_config: dict = {},
+ optimizer_config: dict = None,
+ gradient_checkpointing: bool = False,
+ ):
+ # Initialize the LightningModule
+ super().__init__()
+ self.model_config = model_config
+ self.optimizer_config = optimizer_config
+
+ # Load the Flux pipeline
+ self.flux_pipe: FluxPipeline = (
+ FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
+ )
+ self.transformer = self.flux_pipe.transformer
+ self.transformer.gradient_checkpointing = gradient_checkpointing
+ self.transformer.train()
+
+ # Freeze the Flux pipeline
+ self.flux_pipe.text_encoder.requires_grad_(False).eval()
+ self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
+ self.flux_pipe.vae.requires_grad_(False).eval()
+
+ # Initialize LoRA layers
+ self.lora_layers = self.init_lora(lora_path, lora_config)
+
+ self.to(device).to(dtype)
+
+ def init_lora(self, lora_path: str, lora_config: dict):
+ assert lora_path or lora_config
+ if lora_path:
+ # TODO: Implement this
+ raise NotImplementedError
+ else:
+ self.transformer.add_adapter(LoraConfig(**lora_config))
+ # TODO: Check if this is correct (p.requires_grad)
+ lora_layers = filter(
+ lambda p: p.requires_grad, self.transformer.parameters()
+ )
+ return list(lora_layers)
+
+ def save_lora(self, path: str):
+ FluxPipeline.save_lora_weights(
+ save_directory=path,
+ transformer_lora_layers=get_peft_model_state_dict(self.transformer),
+ safe_serialization=True,
+ )
+
+ def configure_optimizers(self):
+ # Freeze the transformer
+ self.transformer.requires_grad_(False)
+ opt_config = self.optimizer_config
+
+ # Set the trainable parameters
+ self.trainable_params = self.lora_layers
+
+ # Unfreeze trainable parameters
+ for p in self.trainable_params:
+ p.requires_grad_(True)
+
+ # Initialize the optimizer
+ if opt_config["type"] == "AdamW":
+ optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
+ elif opt_config["type"] == "Prodigy":
+ optimizer = prodigyopt.Prodigy(
+ self.trainable_params,
+ **opt_config["params"],
+ )
+ elif opt_config["type"] == "SGD":
+ optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
+ else:
+ raise NotImplementedError
+
+ return optimizer
+
+ def training_step(self, batch, batch_idx):
+ step_loss = self.step(batch)
+ self.log_loss = (
+ step_loss.item()
+ if not hasattr(self, "log_loss")
+ else self.log_loss * 0.95 + step_loss.item() * 0.05
+ )
+ return step_loss
+
+ def step(self, batch):
+ imgs = batch["image"]
+ conditions = batch["condition"]
+ condition_types = batch["condition_type"]
+ prompts = batch["description"]
+ position_delta = batch["position_delta"][0]
+ position_scale = float(batch.get("position_scale", [1.0])[0])
+
+ # Prepare inputs
+ with torch.no_grad():
+ # Prepare image input
+ x_0, img_ids = encode_images(self.flux_pipe, imgs)
+
+ # Prepare text input
+ prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
+ self.flux_pipe, prompts
+ )
+
+ # Prepare t and x_t
+ t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
+ x_1 = torch.randn_like(x_0).to(self.device)
+ t_ = t.unsqueeze(1).unsqueeze(1)
+ x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
+
+ # Prepare conditions
+ condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
+
+ # Add position delta
+ condition_ids[:, 1] += position_delta[0]
+ condition_ids[:, 2] += position_delta[1]
+
+ if position_scale != 1.0:
+ scale_bias = (position_scale - 1.0) / 2
+ condition_ids[:, 1] *= position_scale
+ condition_ids[:, 2] *= position_scale
+ condition_ids[:, 1] += scale_bias
+ condition_ids[:, 2] += scale_bias
+
+ # Prepare condition type
+ condition_type_ids = torch.tensor(
+ [
+ Condition.get_type_id(condition_type)
+ for condition_type in condition_types
+ ]
+ ).to(self.device)
+ condition_type_ids = (
+ torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
+ ).unsqueeze(1)
+
+ # Prepare guidance
+ guidance = (
+ torch.ones_like(t).to(self.device)
+ if self.transformer.config.guidance_embeds
+ else None
+ )
+
+ # Forward pass
+ transformer_out = tranformer_forward(
+ self.transformer,
+ # Model config
+ model_config=self.model_config,
+ # Inputs of the condition (new feature)
+ condition_latents=condition_latents,
+ condition_ids=condition_ids,
+ condition_type_ids=condition_type_ids,
+ # Inputs to the original transformer
+ hidden_states=x_t,
+ timestep=t,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=None,
+ return_dict=False,
+ )
+ pred = transformer_out[0]
+
+ # Compute loss
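+ # Flow-matching objective: the model regresses the constant velocity (x_1 - x_0)
+ # of the straight path from x_0 to x_1.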
+ loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
+ self.last_t = t.mean().item()
+ return loss
diff --git a/OminiControl/src/train/train.py b/OminiControl/src/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dbdc2989c7ca4261f2198de92f1e243c90262fe
--- /dev/null
+++ b/OminiControl/src/train/train.py
@@ -0,0 +1,178 @@
+from torch.utils.data import DataLoader
+import torch
+import lightning as L
+import yaml
+import os
+import time
+
+from datasets import load_dataset
+
+from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
+from .model import OminiModel
+from .callbacks import TrainingCallback
+
+
+def get_rank():
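+ # LOCAL_RANK is set by the distributed launcher (e.g. torchrun); fall back to 0 for single-process runs.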
+ try:
+ rank = int(os.environ.get("LOCAL_RANK"))
+ except (TypeError, ValueError):
+ rank = 0
+ return rank
+
+
+def get_config():
+ config_path = os.environ.get("XFL_CONFIG")
+ assert config_path is not None, "Please set the XFL_CONFIG environment variable"
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ return config
+
+
+def init_wandb(wandb_config, run_name):
+ import wandb
+
+ try:
+ assert os.environ.get("WANDB_API_KEY") is not None
+ wandb.init(
+ project=wandb_config["project"],
+ name=run_name,
+ config={},
+ )
+ except Exception as e:
+ print("Failed to initialize WanDB:", e)
+
+
+def main():
+ # Initialize
+ is_main_process, rank = get_rank() == 0, get_rank()
+ torch.cuda.set_device(rank)
+ config = get_config()
+ training_config = config["train"]
+ run_name = time.strftime("%Y%m%d-%H%M%S")
+
+ # Initialize wandb
+ wandb_config = training_config.get("wandb", None)
+ if wandb_config is not None and is_main_process:
+ init_wandb(wandb_config, run_name)
+
+ print("Rank:", rank)
+ if is_main_process:
+ print("Config:", config)
+
+ # Initialize dataset and dataloader
+ if training_config["dataset"]["type"] == "subject":
+ dataset = load_dataset("Yuanshi/Subjects200K")
+
+ # Define filter function
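+ # Keep only samples whose quality-assessment scores are all at least 5.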
+ def filter_func(item):
+ if not item.get("quality_assessment"):
+ return False
+ return all(
+ item["quality_assessment"].get(key, 0) >= 5
+ for key in ["compositeStructure", "objectConsistency", "imageQuality"]
+ )
+
+ # Filter dataset
+ if not os.path.exists("./cache/dataset"):
+ os.makedirs("./cache/dataset")
+ data_valid = dataset["train"].filter(
+ filter_func,
+ num_proc=16,
+ cache_file_name="./cache/dataset/data_valid.arrow",
+ )
+ dataset = Subject200KDataset(
+ data_valid,
+ condition_size=training_config["dataset"]["condition_size"],
+ target_size=training_config["dataset"]["target_size"],
+ image_size=training_config["dataset"]["image_size"],
+ padding=training_config["dataset"]["padding"],
+ condition_type=training_config["condition_type"],
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
+ )
+ elif training_config["dataset"]["type"] == "img":
+ # Load dataset text-to-image-2M
+ dataset = load_dataset(
+ "webdataset",
+ data_files={"train": training_config["dataset"]["urls"]},
+ split="train",
+ cache_dir="cache/t2i2m",
+ num_proc=32,
+ )
+ dataset = ImageConditionDataset(
+ dataset,
+ condition_size=training_config["dataset"]["condition_size"],
+ target_size=training_config["dataset"]["target_size"],
+ condition_type=training_config["condition_type"],
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
+ position_scale=training_config["dataset"].get("position_scale", 1.0),
+ )
+ elif training_config["dataset"]["type"] == "cartoon":
+ dataset = load_dataset("saquiboye/oye-cartoon", split="train")
+ dataset = CartoonDataset(
+ dataset,
+ condition_size=training_config["dataset"]["condition_size"],
+ target_size=training_config["dataset"]["target_size"],
+ image_size=training_config["dataset"]["image_size"],
+ padding=training_config["dataset"]["padding"],
+ condition_type=training_config["condition_type"],
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
+ )
+ else:
+ raise NotImplementedError
+
+ print("Dataset length:", len(dataset))
+ train_loader = DataLoader(
+ dataset,
+ batch_size=training_config["batch_size"],
+ shuffle=True,
+ num_workers=training_config["dataloader_workers"],
+ )
+
+ # Initialize model
+ trainable_model = OminiModel(
+ flux_pipe_id=config["flux_path"],
+ lora_config=training_config["lora_config"],
+ device=f"cuda",
+ dtype=getattr(torch, config["dtype"]),
+ optimizer_config=training_config["optimizer"],
+ model_config=config.get("model", {}),
+ gradient_checkpointing=training_config.get("gradient_checkpointing", False),
+ )
+
+ # Callbacks for logging and saving checkpoints
+ training_callbacks = (
+ [TrainingCallback(run_name, training_config=training_config)]
+ if is_main_process
+ else []
+ )
+
+ # Initialize trainer
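+ # Built-in checkpointing and logging are disabled; checkpoint saving and sample logging
+ # are handled by TrainingCallback (see save_interval / sample_interval in the config).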
+ trainer = L.Trainer(
+ accumulate_grad_batches=training_config["accumulate_grad_batches"],
+ callbacks=training_callbacks,
+ enable_checkpointing=False,
+ enable_progress_bar=False,
+ logger=False,
+ max_steps=training_config.get("max_steps", -1),
+ max_epochs=training_config.get("max_epochs", -1),
+ gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
+ )
+
+ setattr(trainer, "training_config", training_config)
+
+ # Save config
+ save_path = training_config.get("save_path", "./output")
+ if is_main_process:
+ os.makedirs(f"{save_path}/{run_name}")
+ with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
+ yaml.dump(config, f)
+
+ # Start training
+ trainer.fit(trainable_model, train_loader)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/OminiControl/train/README.md b/OminiControl/train/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ad489a53e8398568754e84a7ebb6bf9ff4780b7
--- /dev/null
+++ b/OminiControl/train/README.md
@@ -0,0 +1,138 @@
+# OminiControl Training 🛠️
+
+## Preparation
+
+### Setup
+1. **Environment**
+ ```bash
+ conda create -n omini python=3.10
+ conda activate omini
+ ```
+2. **Requirements**
+ ```bash
+ pip install -r train/requirements.txt
+ ```
+
+### Dataset
+1. Download the [Subjects200K](https://huggingface.co./datasets/Yuanshi/Subjects200K) dataset (**subject-driven generation**).
+ ```bash
+ bash train/script/data_download/data_download1.sh
+ ```
+2. Download the [text-to-image-2M](https://huggingface.co./datasets/jackyhate/text-to-image-2M) dataset (**spatial control tasks**).
+ ```bash
+ bash train/script/data_download/data_download2.sh
+ ```
+ **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional datasets. Remember to update the config file to specify the training data accordingly.
+
+## Training
+
+### Start training
+**Config file path**: `./train/config`
+
+**Scripts path**: `./train/script`
+
+1. Subject-driven generation
+ ```bash
+ bash train/script/train_subject.sh
+ ```
+2. Spatial control task
+ ```bash
+ bash train/script/train_canny.sh
+ ```
+
+**Note**: Detailed wandb settings and GPU settings can be found in the script files and the config files.
+
+### Other spatial control tasks
+This repository supports the following spatial control tasks:
+1. Canny edge to image (`canny`)
+2. Image colorization (`coloring`)
+3. Image deblurring (`deblurring`)
+4. Depth map to image (`depth`)
+5. Image to depth map (`depth_pred`)
+6. Image inpainting (`fill`)
+7. Super resolution (`sr`)
+
+You can switch between these tasks by modifying the `condition_type` parameter in the config file `config/canny_512.yaml`.
+
+### Customize your own task
+You can customize your own task by constructing a new dataset and modifying the training code.
+
+
+**Instructions:**
+
+1. **Dataset**:
+
+ Construct a new dataset that returns items in the following format. (`src/train/data.py`)
+ ```python
+ class MyDataset(Dataset):
+ def __init__(self, ...):
+ ...
+ def __len__(self):
+ ...
+ def __getitem__(self, idx):
+ ...
+ return {
+ "image": image,
+ "condition": condition_img,
+ "condition_type": "your_condition_type",
+ "description": description,
+ "position_delta": position_delta
+ }
+ ```
+ **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial tasks (e.g. subject-driven generation), set `position_delta` to `[0, -condition_width // 16]` (e.g. `[0, -32]` for a 512-pixel-wide condition image), which shifts the condition tokens outside the target image's positional grid.
+2. **Condition**:
+
+ Add a new condition type in the `Condition` class. (`src/flux/condition.py`)
+ ```python
+ condition_dict = {
+ ...
+ "your_condition_type": your_condition_id_number, # Add your condition type here
+ }
+ ...
+ if condition_type in [
+ ...
+ "your_condition_type", # Add your condition type here
+ ]:
+ ...
+ ```
+3. **Test**:
+
+ Add a new test function for your task. (`src/train/callbacks.py`)
+ ```python
+ if self.condition_type == "your_condition_type":
+ condition_img = (
+ Image.open("images/vase.jpg")
+ .resize((condition_size, condition_size))
+ .convert("RGB")
+ )
+ ...
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
+ ```
+
+4. **Import the new dataset in the training script**:
+
+ Update the dataset import and initialization in the following section. (`src/train/train.py`)
+ ```python
+ from .data import (
+ ImageConditionDataset,
+ Subject200KDataset,
+ MyDataset
+ )
+ ...
+
+ # Initialize dataset and dataloader
+ if training_config["dataset"]["type"] == "your_condition_type":
+ ...
+ ```
+
+
+
+## Hardware requirements
+**Note**: Memory optimization (like dynamic T5 model loading) is pending implementation.
+
+**Recommended**
+- Hardware: 2x NVIDIA H100 GPUs
+- Memory: ~80GB GPU memory
+
+**Minimal**
+- Hardware: 1x NVIDIA L20 GPU
+- Memory: ~48GB GPU memory
\ No newline at end of file
diff --git a/OminiControl/train/config/canny_512.yaml b/OminiControl/train/config/canny_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a0a4602870ecba8bf8dd322b6333aa40ee74cc0
--- /dev/null
+++ b/OminiControl/train/config/canny_512.yaml
@@ -0,0 +1,48 @@
+flux_path: "black-forest-labs/FLUX.1-dev"
+dtype: "bfloat16"
+
+model:
+ union_cond_attn: true
+ add_cond_attn: false
+ latent_lora: false
+
+train:
+ batch_size: 1
+ accumulate_grad_batches: 1
+ dataloader_workers: 5
+ save_interval: 1000
+ sample_interval: 100
+ max_steps: -1
+ gradient_checkpointing: true
+ save_path: "runs"
+
+ # Specify the type of condition to use.
+ # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"]
+ condition_type: "canny"
+ dataset:
+ type: "img"
+ urls:
+ - "https://huggingface.co./datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
+ - "https://huggingface.co./datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
+ cache_name: "data_512_2M"
+ condition_size: 512
+ target_size: 512
+ drop_text_prob: 0.1
+ drop_image_prob: 0.1
+
+ wandb:
+ project: "OminiControl"
+
+ lora_config:
+ r: 4
+ lora_alpha: 4
+ init_lora_weights: "gaussian"
+ target_modules: "(.*x_embedder|.*(?