diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..39e7ae7fd0fdd2d8e5bc370225bb1f3eb8648ac8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d6251bc9e40929ce2c770c116f19c14239062e6c --- /dev/null +++ b/.gitignore @@ -0,0 +1,147 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug +.idea/ +cloud_tools/ + +output +pretrained_models +results +develop +gradio_results +demo \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..988595470b32fc999737a70e15da19f2eb8696da --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "extern/LGM"] + path = extern/LGM + url = https://github.com/3DTopia/LGM.git \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..76e97eaebd7d224468b9e1b85498c2e614007fe8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +S-Lab License 1.0 +Copyright 2024 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. \ No newline at end of file diff --git a/README.md b/README.md index e00eb281e048617924277896796fe3d21764d77d..76298c9749bb3f6fdb2444b4df0fe99088ecf826 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- title: 3DEnhancer -emoji: 📚 -colorFrom: green -colorTo: pink +emoji: 🔆 +colorFrom: red +colorTo: green sdk: gradio -sdk_version: 5.20.1 +sdk_version: 4.44.1 app_file: app.py pinned: false license: other diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e9846e56d1143f3d3944bcaff25303dab44769b4 --- /dev/null +++ b/app.py @@ -0,0 +1,588 @@ +import warnings +warnings.filterwarnings('ignore') + +import spaces + +import os +import tyro +import imageio +import numpy as np +import tqdm +import cv2 +import torch +import torch.nn.functional as F +from torchvision import transforms as T +import torchvision.transforms.functional as TF +from safetensors.torch import load_file +import kiui +from kiui.op import recenter +from kiui.cam import orbit_camera +import rembg +import gradio as gr +from gradio_imageslider import ImageSlider + +import sys +sys.path.insert(0, "src") +from src.enhancer import Enhancer +from src.utils.camera import get_c2ws + +# import LGM +sys.path.insert(0, "extern/LGM") +from core.options import AllConfigs +from core.models import LGM +from mvdream.pipeline_mvdream import MVDreamPipeline + + +# download checkpoints +from huggingface_hub import hf_hub_download +hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors", local_dir='pretrained_models/LGM') +hf_hub_download(repo_id="Luo-Yihang/3DEnhancer", filename="model.safetensors", local_dir='pretrained_models/3DEnhancer') + + +### Title and Description ### +#### Description #### +title = r"""

3DEnhancer: Consistent Multi-View Diffusion for 3D Enhancement

""" + +important_link = r""" +
+[arxiv] +  [Project Page] +  [Code] +
+""" + +authors = r""" +
+ Yihang Luo +   Shangchen Zhou +  Yushi Lan +  Xingang Pan +  Chen Change Loy +
+""" + +affiliation = r""" +
+ S-Lab, NTU Singapore +
+""" + +description = r""" +Official Gradio demo for 3DEnhancer: Consistent Multi-View Diffusion for 3D Enhancement.
+🔥 3DEnhancer employs a multi-view diffusion model to enhance multi-view images, thus improving 3D models. Our contributions include a robust data augmentation pipeline, and the view-consistent blocks that integrate multi-view row attention and near-view epipolar aggregation modules to promote view consistency.
+""" + +article = r""" +
If 3DEnhancer is helpful, please help to ⭐ the Github Repo. Thanks! +[![GitHub Stars](https://img.shields.io/github/stars/Luo-Yihang/3DEnhancer)](https://github.com/Luo-Yihang/3DEnhancer) +--- +📝 **License** +
+This project is licensed under S-Lab License 1.0, +Redistribution and use for non-commercial purposes should follow this license. +
+📝 **Citation** +
+If our work is useful for your research, please consider citing: +```bibtex +@article{luo20243denhancer, + title={3DEnhancer: Consistent Multi-View Diffusion for 3D Enhancement}, + author={Yihang Luo and Shangchen Zhou and Yushi Lan and Xingang Pan and Chen Change Loy}, + booktitle={arXiv preprint arXiv:2412.18565} + year={2024}, +} +``` +📧 **Contact** +
+If you have any questions, please feel free to reach me out at luo_yihang@outlook.com. +""" + + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +BASE_SAVE_PATH = 'gradio_results' +GRADIO_VIDEO_PATH = f'{BASE_SAVE_PATH}/gradio_output.mp4' +GRADIO_PLY_PATH = f'{BASE_SAVE_PATH}/gradio_output.ply' +GRADIO_ENHANCED_VIDEO_PATH = f'{BASE_SAVE_PATH}/gradio_enhanced_output.mp4' +GRADIO_ENHANCED_PLY_PATH = f'{BASE_SAVE_PATH}/gradio_enhanced_output.ply' +DEFAULT_NEG_PROMPT = "ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate" +DEFAULT_SEED = 0 +os.makedirs(BASE_SAVE_PATH, exist_ok=True) + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# load dreams +pipe_text = MVDreamPipeline.from_pretrained( + 'ashawkey/mvdream-sd2.1-diffusers', # remote weights + torch_dtype=torch.float16, + trust_remote_code=True +) +pipe_text = pipe_text.to(device) + +pipe_image = MVDreamPipeline.from_pretrained( + "ashawkey/imagedream-ipmv-diffusers", # remote weights + torch_dtype=torch.float16, + trust_remote_code=True +) +pipe_image = pipe_image.to(device) + +# load lgm +lgm_opt = tyro.cli(AllConfigs, args=["big"]) + +tan_half_fov = np.tan(0.5 * np.deg2rad(lgm_opt.fovy)) +proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) +proj_matrix[0, 0] = 1 / tan_half_fov +proj_matrix[1, 1] = 1 / tan_half_fov +proj_matrix[2, 2] = (lgm_opt.zfar + lgm_opt.znear) / (lgm_opt.zfar - lgm_opt.znear) +proj_matrix[3, 2] = - (lgm_opt.zfar * lgm_opt.znear) / (lgm_opt.zfar - lgm_opt.znear) +proj_matrix[2, 3] = 1 + +lgm_model = LGM(lgm_opt) +lgm_model = lgm_model.half().to(device) +ckpt = load_file("pretrained_models/LGM/model_fp16_fixrot.safetensors", device='cpu') +lgm_model.load_state_dict(ckpt, strict=False) +lgm_model.eval() + +# load 3denhancer +enhancer = Enhancer( + model_path = "pretrained_models/3DEnhancer/model.safetensors", + config_path = "src/configs/config.py", +) + +# load rembg +bg_remover = rembg.new_session() + +@torch.no_grad() +@spaces.GPU +def gen_mv(ref_image, ref_text): + kiui.seed_everything(DEFAULT_SEED) + + # text-conditioned + if ref_image is None: + mv_image_uint8 = pipe_text(ref_text, negative_prompt=DEFAULT_NEG_PROMPT, num_inference_steps=30, guidance_scale=7.5, elevation=0) + mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8) + # bg removal + mv_image = [] + for i in range(4): + image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4] + # to white bg + image = image.astype(np.float32) / 255 + image = recenter(image, image[..., 0] > 0, border_ratio=0.2) + image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:]) + mv_image.append(image) + # image-conditioned (may also input text, but no text usually works too) + else: + ref_image = np.array(ref_image) # uint8 + # bg removal + carved_image = rembg.remove(ref_image, session=bg_remover) # [H, W, 4] + mask = carved_image[..., -1] > 0 + image = recenter(carved_image, mask, border_ratio=0.2) + image = image.astype(np.float32) / 255.0 + image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) + mv_image = pipe_image(ref_text, image, negative_prompt=DEFAULT_NEG_PROMPT, num_inference_steps=30, guidance_scale=5.0, elevation=0) + + # mv_image, a list of 4 np_arrays in shape (256, 256, 3) in range (0.0, 1.0) + mv_image_512 = [] + for i in range(len(mv_image)): + mv_image_512.append(cv2.resize(mv_image[i], (512, 512), interpolation=cv2.INTER_LINEAR)) + + return mv_image_512[0], mv_image_512[1], mv_image_512[2], mv_image_512[3], ref_text, 120 + + +@torch.no_grad() +@spaces.GPU +def gen_3d(image_0, image_1, image_2, image_3, elevation, output_video_path, output_ply_path): + kiui.seed_everything(DEFAULT_SEED) + + mv_image = [image_0, image_1, image_2, image_3] + for i in range(len(mv_image)): + if type(mv_image[i]) is tuple: + mv_image[i] = mv_image[i][1] + mv_image[i] = np.array(mv_image[i]).astype(np.float32) / 255.0 + mv_image[i] = cv2.resize(mv_image[i], (256, 256), interpolation=cv2.INTER_AREA) + + # generate gaussians + input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32 + input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] + input_image = F.interpolate(input_image, size=(lgm_opt.input_size, lgm_opt.input_size), mode='bilinear', align_corners=False) + input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + rays_embeddings = lgm_model.prepare_default_rays(device, elevation=elevation) + input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W] + + with torch.no_grad(): + with torch.autocast(device_type='cuda', dtype=torch.float16): + # generate gaussians + gaussians = lgm_model.forward_gaussians(input_image) + lgm_model.gs.save_ply(gaussians, output_ply_path) + + # render 360 video + images = [] + elevation = 0 + if lgm_opt.fancy_video: + azimuth = np.arange(0, 720, 4, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=lgm_opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + scale = min(azi / 360, 1) + + image = lgm_model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + else: + azimuth = np.arange(0, 360, 2, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=lgm_opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + image = lgm_model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + + images = np.concatenate(images, axis=0) + imageio.mimwrite(output_video_path, images, fps=30) + + return output_video_path, output_ply_path + + +@torch.no_grad() +@spaces.GPU +def enhance(image_0, image_1, image_2, image_3, prompt, elevation, noise_level, cfg_scale, steps, seed, color_shift): + kiui.seed_everything(seed) + + mv_image = [image_0, image_1, image_2, image_3] + img_tensor_list = [] + for image in mv_image: + img_tensor_list.append(T.ToTensor()(image)) + + img_tensors = torch.stack(img_tensor_list) + + color_shift = None if color_shift=="disabled" else color_shift + output_img_tensors = enhancer.inference( + mv_imgs=img_tensors, + c2ws=get_c2ws(elevations=[elevation]*4, amuziths=[0,90,180,270]), + prompt=prompt, + noise_level=noise_level, + cfg_scale=cfg_scale, + sample_steps=steps, + color_shift=color_shift, + ) + + mv_image_512 = output_img_tensors.permute(0,2,3,1).cpu().numpy() + + # return to the image slider component + return (image_0, mv_image_512[0]), (image_1, mv_image_512[1]), (image_2, mv_image_512[2]), (image_3, mv_image_512[3]) + + +def check_video(input_video): + if input_video: + return gr.update(interactive=True) + return gr.update(interactive=False) + + +i2mv_examples = [ + ["assets/examples/i2mv/cake.png", "cake"], + ["assets/examples/i2mv/skull.png", "skull"], + ["assets/examples/i2mv/sea_turtle.png", "sea turtle"], + ["assets/examples/i2mv/house2.png", "house"], + ["assets/examples/i2mv/cup.png", "cup"], + ["assets/examples/i2mv/mannequin.png", "mannequin"], + ["assets/examples/i2mv/boy.jpg", "boy"], + ["assets/examples/i2mv/dragontoy.jpg", "dragon toy"], + ["assets/examples/i2mv/gso_rabbit.jpg", "rabbit car"], + ["assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png", "standing Mario"], +] + +t2mv_examples = [ + "teddy bear", + "hamburger", + "oldman's head sculpture", + "headphone", + "mech suit", + "wooden barrel", + "scary zombie" +] + +mv_examples = [ + [ + "assets/examples/mv_lq_prerendered/vase.mp4", + "assets/examples/mv_lq/vase/00.png", + "assets/examples/mv_lq/vase/01.png", + "assets/examples/mv_lq/vase/02.png", + "assets/examples/mv_lq/vase/03.png", + "vase", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/tower.mp4", + "assets/examples/mv_lq/tower/00.png", + "assets/examples/mv_lq/tower/01.png", + "assets/examples/mv_lq/tower/02.png", + "assets/examples/mv_lq/tower/03.png", + "brick tower", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/truck.mp4", + "assets/examples/mv_lq/truck/00.png", + "assets/examples/mv_lq/truck/01.png", + "assets/examples/mv_lq/truck/02.png", + "assets/examples/mv_lq/truck/03.png", + "truck", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/gascan.mp4", + "assets/examples/mv_lq/gascan/00.png", + "assets/examples/mv_lq/gascan/01.png", + "assets/examples/mv_lq/gascan/02.png", + "assets/examples/mv_lq/gascan/03.png", + "gas can", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/fish.mp4", + "assets/examples/mv_lq/fish/00.png", + "assets/examples/mv_lq/fish/01.png", + "assets/examples/mv_lq/fish/02.png", + "assets/examples/mv_lq/fish/03.png", + "sea fish with eyes", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/tshirt.mp4", + "assets/examples/mv_lq/tshirt/00.png", + "assets/examples/mv_lq/tshirt/01.png", + "assets/examples/mv_lq/tshirt/02.png", + "assets/examples/mv_lq/tshirt/03.png", + "t-shirt", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/turtle.mp4", + "assets/examples/mv_lq/turtle/00.png", + "assets/examples/mv_lq/turtle/01.png", + "assets/examples/mv_lq/turtle/02.png", + "assets/examples/mv_lq/turtle/03.png", + "sea turtle", + 200 + ], + [ + "assets/examples/mv_lq_prerendered/cake.mp4", + "assets/examples/mv_lq/cake/00.png", + "assets/examples/mv_lq/cake/01.png", + "assets/examples/mv_lq/cake/02.png", + "assets/examples/mv_lq/cake/03.png", + "cake", + 120 + ], + [ + "assets/examples/mv_lq_prerendered/lamp.mp4", + "assets/examples/mv_lq/lamp/00.png", + "assets/examples/mv_lq/lamp/01.png", + "assets/examples/mv_lq/lamp/02.png", + "assets/examples/mv_lq/lamp/03.png", + "lamp", + 0 + ], + [ + "assets/examples/mv_lq_prerendered/oldman.mp4", + "assets/examples/mv_lq/oldman/00.png", + "assets/examples/mv_lq/oldman/00.png", + "assets/examples/mv_lq/oldman/00.png", + "assets/examples/mv_lq/oldman/00.png", + "old man sculpture", + 120 + ], + [ + "assets/examples/mv_lq_prerendered/mario.mp4", + "assets/examples/mv_lq/mario/00.png", + "assets/examples/mv_lq/mario/01.png", + "assets/examples/mv_lq/mario/02.png", + "assets/examples/mv_lq/mario/03.png", + "standing mario", + 120 + ], + [ + "assets/examples/mv_lq_prerendered/house.mp4", + "assets/examples/mv_lq/house/00.png", + "assets/examples/mv_lq/house/01.png", + "assets/examples/mv_lq/house/02.png", + "assets/examples/mv_lq/house/03.png", + "house", + 120 + ], +] + + +# gradio UI +demo = gr.Blocks().queue() +with demo: + gr.Markdown(title) + gr.Markdown(authors) + gr.Markdown(affiliation) + gr.Markdown(important_link) + gr.Markdown(description) + + original_video_path = gr.State(GRADIO_VIDEO_PATH) + original_ply_path = gr.State(GRADIO_PLY_PATH) + enhanced_video_path = gr.State(GRADIO_ENHANCED_VIDEO_PATH) + enhanced_ply_path = gr.State(GRADIO_ENHANCED_PLY_PATH) + + with gr.Column(variant='panel'): + with gr.Accordion("Generate Multi Views (LGM)", open=False): + gr.Markdown("*Don't have multi-view images on hand? Generate them here using a single image, text, or a combination of both.*") + with gr.Row(): + with gr.Column(): + ref_image = gr.Image(label="Reference Image", type='pil', height=400, interactive=True) + ref_text = gr.Textbox(label="Prompt", value="", interactive=True) + with gr.Column(): + gr.Examples( + examples=i2mv_examples, + inputs=[ref_image, ref_text], + examples_per_page=3, + label='Image-to-Multiviews Examples', + ) + + gr.Examples( + examples=t2mv_examples, + inputs=[ref_text], + outputs=[ref_image, ref_text], + cache_examples=False, + run_on_click=True, + fn=lambda x: (None, x), + label='Text-to-Multiviews Examples', + ) + + with gr.Row(): + gr.Column() # Empty column for spacing + button_gen_mv = gr.Button("Generate Multi Views", scale=1) + gr.Column() # Empty column for spacing + + with gr.Column(): + gr.Markdown("Let's enhance!") + with gr.Row(): + with gr.Column(scale=2): + with gr.Tab("Multi Views"): + gr.Markdown("*Upload your multi-view images and enhance them with 3DEnhancer. You can also generate 3D model using LGM.*") + with gr.Row(): + input_image_0 = gr.Image(label="[Input] view-0", type='pil', height=320) + input_image_1 = gr.Image(label="[Input] view-1", type='pil', height=320) + input_image_2 = gr.Image(label="[Input] view-2", type='pil', height=320) + input_image_3 = gr.Image(label="[Input] view-3", type='pil', height=320) + gr.Markdown("---") + gr.Markdown("Enhanced Output") + with gr.Row(): + enhanced_image_0 = ImageSlider(label="[Enhanced] view-0", type='pil', height=350, interactive=False) + enhanced_image_1 = ImageSlider(label="[Enhanced] view-1", type='pil', height=350, interactive=False) + enhanced_image_2 = ImageSlider(label="[Enhanced] view-2", type='pil', height=350, interactive=False) + enhanced_image_3 = ImageSlider(label="[Enhanced] view-3", type='pil', height=350, interactive=False) + with gr.Tab("Generated 3D"): + gr.Markdown("Coarse Input") + with gr.Column(): + with gr.Row(): + gr.Column() # Empty column for spacing + with gr.Column(): + input_3d_video = gr.Video(label="[Input] Rendered Video", height=300, scale=1, interactive=False) + with gr.Row(): + button_gen_3d = gr.Button("Render 3D") + button_download_3d = gr.DownloadButton("Download Ply", interactive=False) + # button_download_3d = gr.File(label="Download Ply", interactive=False, height=50) + gr.Column() # Empty column for spacing + gr.Markdown("---") + gr.Markdown("Enhanced Output") + with gr.Row(): + gr.Column() # Empty column for spacing + with gr.Column(): + enhanced_3d_video = gr.Video(label="[Enhanced] Rendered Video", height=300, scale=1, interactive=False) + with gr.Row(): + enhanced_button_gen_3d = gr.Button("Render 3D") + enhanced_button_download_3d = gr.DownloadButton("Download Ply", interactive=False) + gr.Column() # Empty column for spacing + + with gr.Column(): + with gr.Row(): + enhancer_text = gr.Textbox(label="Prompt", value="", scale=1) + enhancer_noise_level = gr.Slider(label="enhancer noise level", minimum=0, maximum=300, step=1, value=0, interactive=True) + with gr.Accordion("Addvanced Setting", open=False): + with gr.Column(): + with gr.Row(): + with gr.Column(): + elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0) + cfg_scale = gr.Slider(label="cfg scale", minimum=0, maximum=10, step=0.1, value=4.5) + with gr.Column(): + seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0) + steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=20) + with gr.Row(): + color_shift = gr.Radio(label="color shift", value="disabled", choices=["disabled", "adain", "wavelet"]) + with gr.Row(): + gr.Column() # Empty column for spacing + button_enhance = gr.Button("Enhance", scale=1, variant="primary") + gr.Column() # Empty column for spacing + + gr.Examples( + examples=mv_examples, + inputs=[input_3d_video, input_image_0, input_image_1, input_image_2, input_image_3, enhancer_text, enhancer_noise_level], + examples_per_page=3, + label='Multiviews Examples', + ) + + gr.Markdown("*Don't have multi-view images on hand but want to generate your own multi-viwes? Generate them in the `Generate Multi Views (LGM)` secction above.*") + + gr.Markdown(article) + + button_gen_mv.click( + gen_mv, + inputs=[ref_image, ref_text], + outputs=[input_image_0, input_image_1, input_image_2, input_image_3, enhancer_text, enhancer_noise_level] + ) + + button_gen_3d.click( + gen_3d, + inputs=[input_image_0, input_image_1, input_image_2, input_image_3, elevation, original_video_path, original_ply_path], + outputs=[input_3d_video, button_download_3d] + ).success( + lambda: gr.Button(interactive=True), + outputs=[button_download_3d], + ) + + enhanced_button_gen_3d.click( + gen_3d, + inputs=[enhanced_image_0, enhanced_image_1, enhanced_image_2, enhanced_image_3, elevation, original_video_path, original_ply_path], + outputs=[enhanced_3d_video, enhanced_button_download_3d] + ).success( + lambda: gr.Button(interactive=True), + outputs=[enhanced_button_download_3d], + ) + + button_enhance.click( + enhance, + inputs=[input_image_0, input_image_1, input_image_2, input_image_3, enhancer_text, elevation, enhancer_noise_level, cfg_scale, steps, seed, color_shift], + outputs=[enhanced_image_0, enhanced_image_1, enhanced_image_2, enhanced_image_3] + ).success( + gen_3d, + inputs=[input_image_0, input_image_1, input_image_2, input_image_3, elevation, original_video_path, original_ply_path], + outputs=[input_3d_video, button_download_3d] + ).success( + lambda: gr.Button(interactive=True), + outputs=[button_download_3d], + ).success( + gen_3d, + inputs=[enhanced_image_0, enhanced_image_1, enhanced_image_2, enhanced_image_3, elevation, enhanced_video_path, enhanced_ply_path], + outputs=[enhanced_3d_video, enhanced_button_download_3d] + ).success( + lambda: gr.Button(interactive=True), + outputs=[enhanced_button_download_3d], + ) + +demo.launch() \ No newline at end of file diff --git a/assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png b/assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png new file mode 100644 index 0000000000000000000000000000000000000000..9c967ce54eb0e6d66368f13b7688e910183ad7b4 Binary files /dev/null and b/assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png differ diff --git a/assets/examples/i2mv/boy.jpg b/assets/examples/i2mv/boy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f44fdbe15a13cb47ecd01517876fca51e486ac1d Binary files /dev/null and b/assets/examples/i2mv/boy.jpg differ diff --git a/assets/examples/i2mv/cake.png b/assets/examples/i2mv/cake.png new file mode 100644 index 0000000000000000000000000000000000000000..342b59c1b98dba762fa26bde554ba7f7be566fe3 Binary files /dev/null and b/assets/examples/i2mv/cake.png differ diff --git a/assets/examples/i2mv/cup.png b/assets/examples/i2mv/cup.png new file mode 100644 index 0000000000000000000000000000000000000000..2cb485ed2e50dedc37c72c7ff2cc59cabdd6c29c Binary files /dev/null and b/assets/examples/i2mv/cup.png differ diff --git a/assets/examples/i2mv/dragontoy.jpg b/assets/examples/i2mv/dragontoy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..39ce417c23734e91a667eeb852fca63e1f6249b3 Binary files /dev/null and b/assets/examples/i2mv/dragontoy.jpg differ diff --git a/assets/examples/i2mv/gso_rabbit.jpg b/assets/examples/i2mv/gso_rabbit.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1287e368cd8e847262df107cba2ab0fb864e0aa Binary files /dev/null and b/assets/examples/i2mv/gso_rabbit.jpg differ diff --git a/assets/examples/i2mv/house2.png b/assets/examples/i2mv/house2.png new file mode 100644 index 0000000000000000000000000000000000000000..cb1d1884af6e5aecdcf5f3df9c04c0bc7c3167e9 Binary files /dev/null and b/assets/examples/i2mv/house2.png differ diff --git a/assets/examples/i2mv/mannequin.png b/assets/examples/i2mv/mannequin.png new file mode 100644 index 0000000000000000000000000000000000000000..b166e130290ac7724bbd5363cc8b0e2cacf00769 Binary files /dev/null and b/assets/examples/i2mv/mannequin.png differ diff --git a/assets/examples/i2mv/sea_turtle.png b/assets/examples/i2mv/sea_turtle.png new file mode 100644 index 0000000000000000000000000000000000000000..8b8bf5fccfedb4cff60348334de8276487d22b62 Binary files /dev/null and b/assets/examples/i2mv/sea_turtle.png differ diff --git a/assets/examples/i2mv/skull.png b/assets/examples/i2mv/skull.png new file mode 100644 index 0000000000000000000000000000000000000000..82ab2ffdece45dc6fe21a5af9bf57bf9ec784d61 Binary files /dev/null and b/assets/examples/i2mv/skull.png differ diff --git a/assets/examples/mv_lq/cake/00.png b/assets/examples/mv_lq/cake/00.png new file mode 100644 index 0000000000000000000000000000000000000000..50c64ddf7de3c21ddefb017f30d017dad1277f77 Binary files /dev/null and b/assets/examples/mv_lq/cake/00.png differ diff --git a/assets/examples/mv_lq/cake/01.png b/assets/examples/mv_lq/cake/01.png new file mode 100644 index 0000000000000000000000000000000000000000..291e373df3cac7726e385b1709557388d44d04b3 Binary files /dev/null and b/assets/examples/mv_lq/cake/01.png differ diff --git a/assets/examples/mv_lq/cake/02.png b/assets/examples/mv_lq/cake/02.png new file mode 100644 index 0000000000000000000000000000000000000000..ddcfaac336bb0c3e740d606a4dc97f34f6628f1d Binary files /dev/null and b/assets/examples/mv_lq/cake/02.png differ diff --git a/assets/examples/mv_lq/cake/03.png b/assets/examples/mv_lq/cake/03.png new file mode 100644 index 0000000000000000000000000000000000000000..0368c8a3a9009eb8f9d06adb07f16e03ca8a6252 Binary files /dev/null and b/assets/examples/mv_lq/cake/03.png differ diff --git a/assets/examples/mv_lq/fish/00.png b/assets/examples/mv_lq/fish/00.png new file mode 100644 index 0000000000000000000000000000000000000000..f1050105b40b4e293083536fc8f8b1c26acffb7f Binary files /dev/null and b/assets/examples/mv_lq/fish/00.png differ diff --git a/assets/examples/mv_lq/fish/01.png b/assets/examples/mv_lq/fish/01.png new file mode 100644 index 0000000000000000000000000000000000000000..7c2523a4ffd59801e4a076abd175ca1847c86960 Binary files /dev/null and b/assets/examples/mv_lq/fish/01.png differ diff --git a/assets/examples/mv_lq/fish/02.png b/assets/examples/mv_lq/fish/02.png new file mode 100644 index 0000000000000000000000000000000000000000..ec91396a75e818131fd555a41bd6cc6aba6d2f37 Binary files /dev/null and b/assets/examples/mv_lq/fish/02.png differ diff --git a/assets/examples/mv_lq/fish/03.png b/assets/examples/mv_lq/fish/03.png new file mode 100644 index 0000000000000000000000000000000000000000..a56154e857cdcf3fb52f207c5ff1cbd9343b297c Binary files /dev/null and b/assets/examples/mv_lq/fish/03.png differ diff --git a/assets/examples/mv_lq/gascan/00.png b/assets/examples/mv_lq/gascan/00.png new file mode 100644 index 0000000000000000000000000000000000000000..0759a3dd0cf7010c842146f9f682ca3c167a46b3 Binary files /dev/null and b/assets/examples/mv_lq/gascan/00.png differ diff --git a/assets/examples/mv_lq/gascan/01.png b/assets/examples/mv_lq/gascan/01.png new file mode 100644 index 0000000000000000000000000000000000000000..e7d80b68db1179052198ba5c4b75077a85b66673 Binary files /dev/null and b/assets/examples/mv_lq/gascan/01.png differ diff --git a/assets/examples/mv_lq/gascan/02.png b/assets/examples/mv_lq/gascan/02.png new file mode 100644 index 0000000000000000000000000000000000000000..0026e8f3a3181f911c96d0aac6e0561f648e879b Binary files /dev/null and b/assets/examples/mv_lq/gascan/02.png differ diff --git a/assets/examples/mv_lq/gascan/03.png b/assets/examples/mv_lq/gascan/03.png new file mode 100644 index 0000000000000000000000000000000000000000..cb3a2c5dc4520e0c3f855afd4187e00147a1a9bb Binary files /dev/null and b/assets/examples/mv_lq/gascan/03.png differ diff --git a/assets/examples/mv_lq/house/00.png b/assets/examples/mv_lq/house/00.png new file mode 100644 index 0000000000000000000000000000000000000000..5e539b50d3d15562835f466d72bb473460df69ca Binary files /dev/null and b/assets/examples/mv_lq/house/00.png differ diff --git a/assets/examples/mv_lq/house/01.png b/assets/examples/mv_lq/house/01.png new file mode 100644 index 0000000000000000000000000000000000000000..70cf9d6715695b867f057131bc6b17757c06f0c6 Binary files /dev/null and b/assets/examples/mv_lq/house/01.png differ diff --git a/assets/examples/mv_lq/house/02.png b/assets/examples/mv_lq/house/02.png new file mode 100644 index 0000000000000000000000000000000000000000..d12b18cb6b93ff407675105d66733cb7fcbbbd15 Binary files /dev/null and b/assets/examples/mv_lq/house/02.png differ diff --git a/assets/examples/mv_lq/house/03.png b/assets/examples/mv_lq/house/03.png new file mode 100644 index 0000000000000000000000000000000000000000..c3a2bcf2c7ce91e72541dd82c8e8c9265f55180c Binary files /dev/null and b/assets/examples/mv_lq/house/03.png differ diff --git a/assets/examples/mv_lq/lamp/00.png b/assets/examples/mv_lq/lamp/00.png new file mode 100644 index 0000000000000000000000000000000000000000..f6d8d5d4c88ed2920a31fc94cca6498a6629bbfa Binary files /dev/null and b/assets/examples/mv_lq/lamp/00.png differ diff --git a/assets/examples/mv_lq/lamp/01.png b/assets/examples/mv_lq/lamp/01.png new file mode 100644 index 0000000000000000000000000000000000000000..a08b84173292438120806a0d6cd9ce3f04a5d03e Binary files /dev/null and b/assets/examples/mv_lq/lamp/01.png differ diff --git a/assets/examples/mv_lq/lamp/02.png b/assets/examples/mv_lq/lamp/02.png new file mode 100644 index 0000000000000000000000000000000000000000..5509f467fdca846ababea49155475ebaa0a6d1ea Binary files /dev/null and b/assets/examples/mv_lq/lamp/02.png differ diff --git a/assets/examples/mv_lq/lamp/03.png b/assets/examples/mv_lq/lamp/03.png new file mode 100644 index 0000000000000000000000000000000000000000..4487feebb24a838d06b4bdbca2160b3b58f8fe9e Binary files /dev/null and b/assets/examples/mv_lq/lamp/03.png differ diff --git a/assets/examples/mv_lq/mario/00.png b/assets/examples/mv_lq/mario/00.png new file mode 100644 index 0000000000000000000000000000000000000000..67988afc0adfabec80c2cc6bc35fea9ed183a8f8 Binary files /dev/null and b/assets/examples/mv_lq/mario/00.png differ diff --git a/assets/examples/mv_lq/mario/01.png b/assets/examples/mv_lq/mario/01.png new file mode 100644 index 0000000000000000000000000000000000000000..77ba23f07c4b2d7b7e02d4151e8e55e4c1e215af Binary files /dev/null and b/assets/examples/mv_lq/mario/01.png differ diff --git a/assets/examples/mv_lq/mario/02.png b/assets/examples/mv_lq/mario/02.png new file mode 100644 index 0000000000000000000000000000000000000000..5607502a9c549c13190e7c2ec4974d0f223924d3 Binary files /dev/null and b/assets/examples/mv_lq/mario/02.png differ diff --git a/assets/examples/mv_lq/mario/03.png b/assets/examples/mv_lq/mario/03.png new file mode 100644 index 0000000000000000000000000000000000000000..499e5d597e2a7bdc551a306ad0c988532fcff836 Binary files /dev/null and b/assets/examples/mv_lq/mario/03.png differ diff --git a/assets/examples/mv_lq/oldman/00.png b/assets/examples/mv_lq/oldman/00.png new file mode 100644 index 0000000000000000000000000000000000000000..dcd0be5cfe5a71d64edbb6beecf47e72a7f1cc46 Binary files /dev/null and b/assets/examples/mv_lq/oldman/00.png differ diff --git a/assets/examples/mv_lq/oldman/01.png b/assets/examples/mv_lq/oldman/01.png new file mode 100644 index 0000000000000000000000000000000000000000..43faae16580f473bddfb4ba31159dc8c53866020 Binary files /dev/null and b/assets/examples/mv_lq/oldman/01.png differ diff --git a/assets/examples/mv_lq/oldman/02.png b/assets/examples/mv_lq/oldman/02.png new file mode 100644 index 0000000000000000000000000000000000000000..bee66d7a12b0eaa7001fcfc6e8211929b41d55e1 Binary files /dev/null and b/assets/examples/mv_lq/oldman/02.png differ diff --git a/assets/examples/mv_lq/oldman/03.png b/assets/examples/mv_lq/oldman/03.png new file mode 100644 index 0000000000000000000000000000000000000000..3a74089d109c69c463f844ae0f5e78d799eef50c Binary files /dev/null and b/assets/examples/mv_lq/oldman/03.png differ diff --git a/assets/examples/mv_lq/tower/00.png b/assets/examples/mv_lq/tower/00.png new file mode 100644 index 0000000000000000000000000000000000000000..fe321333083af076f86e4c85c9637aa922543fa8 Binary files /dev/null and b/assets/examples/mv_lq/tower/00.png differ diff --git a/assets/examples/mv_lq/tower/01.png b/assets/examples/mv_lq/tower/01.png new file mode 100644 index 0000000000000000000000000000000000000000..18fa83d389d8c74b439f0b022b3d8117afedae93 Binary files /dev/null and b/assets/examples/mv_lq/tower/01.png differ diff --git a/assets/examples/mv_lq/tower/02.png b/assets/examples/mv_lq/tower/02.png new file mode 100644 index 0000000000000000000000000000000000000000..e7a9e8f9e6d41919734440a07d68942954b5fe91 Binary files /dev/null and b/assets/examples/mv_lq/tower/02.png differ diff --git a/assets/examples/mv_lq/tower/03.png b/assets/examples/mv_lq/tower/03.png new file mode 100644 index 0000000000000000000000000000000000000000..e8b6d90052deed848ec8352d36ba85802d04bd53 Binary files /dev/null and b/assets/examples/mv_lq/tower/03.png differ diff --git a/assets/examples/mv_lq/truck/00.png b/assets/examples/mv_lq/truck/00.png new file mode 100644 index 0000000000000000000000000000000000000000..2bf4539db7b8f29b6cabde5dcca08fdd02ffe7a2 Binary files /dev/null and b/assets/examples/mv_lq/truck/00.png differ diff --git a/assets/examples/mv_lq/truck/01.png b/assets/examples/mv_lq/truck/01.png new file mode 100644 index 0000000000000000000000000000000000000000..a2195d2968f4f5088b124d090545eadf96643eac Binary files /dev/null and b/assets/examples/mv_lq/truck/01.png differ diff --git a/assets/examples/mv_lq/truck/02.png b/assets/examples/mv_lq/truck/02.png new file mode 100644 index 0000000000000000000000000000000000000000..6bff8481611f7d46f8691bbbea19107cfa269afe Binary files /dev/null and b/assets/examples/mv_lq/truck/02.png differ diff --git a/assets/examples/mv_lq/truck/03.png b/assets/examples/mv_lq/truck/03.png new file mode 100644 index 0000000000000000000000000000000000000000..d988dc104ce3c18288da2a3b6affd1e389bcf108 Binary files /dev/null and b/assets/examples/mv_lq/truck/03.png differ diff --git a/assets/examples/mv_lq/tshirt/00.png b/assets/examples/mv_lq/tshirt/00.png new file mode 100644 index 0000000000000000000000000000000000000000..6e048e6902164742fd0b312bc4cc7ffbebbf570d Binary files /dev/null and b/assets/examples/mv_lq/tshirt/00.png differ diff --git a/assets/examples/mv_lq/tshirt/01.png b/assets/examples/mv_lq/tshirt/01.png new file mode 100644 index 0000000000000000000000000000000000000000..bde6e23c0d5815f26134c505028bd62879a6daaa Binary files /dev/null and b/assets/examples/mv_lq/tshirt/01.png differ diff --git a/assets/examples/mv_lq/tshirt/02.png b/assets/examples/mv_lq/tshirt/02.png new file mode 100644 index 0000000000000000000000000000000000000000..f05ab64f52e6337cff82a3ec26031e816f6533dc Binary files /dev/null and b/assets/examples/mv_lq/tshirt/02.png differ diff --git a/assets/examples/mv_lq/tshirt/03.png b/assets/examples/mv_lq/tshirt/03.png new file mode 100644 index 0000000000000000000000000000000000000000..2831ea5253a604c398fa7dcf4606fa37bf9a7e34 Binary files /dev/null and b/assets/examples/mv_lq/tshirt/03.png differ diff --git a/assets/examples/mv_lq/turtle/00.png b/assets/examples/mv_lq/turtle/00.png new file mode 100644 index 0000000000000000000000000000000000000000..a19958e69ae0b7abe662f35c970e53c319e9267e Binary files /dev/null and b/assets/examples/mv_lq/turtle/00.png differ diff --git a/assets/examples/mv_lq/turtle/01.png b/assets/examples/mv_lq/turtle/01.png new file mode 100644 index 0000000000000000000000000000000000000000..918cb36da1c65a812cecdbae9c1eb43bced7a903 Binary files /dev/null and b/assets/examples/mv_lq/turtle/01.png differ diff --git a/assets/examples/mv_lq/turtle/02.png b/assets/examples/mv_lq/turtle/02.png new file mode 100644 index 0000000000000000000000000000000000000000..5602b4fe385351784e417536415e3280cbd118a5 Binary files /dev/null and b/assets/examples/mv_lq/turtle/02.png differ diff --git a/assets/examples/mv_lq/turtle/03.png b/assets/examples/mv_lq/turtle/03.png new file mode 100644 index 0000000000000000000000000000000000000000..9367bf0c7a674cb6b5cf7c02d337fce22b4146db Binary files /dev/null and b/assets/examples/mv_lq/turtle/03.png differ diff --git a/assets/examples/mv_lq/vase/00.png b/assets/examples/mv_lq/vase/00.png new file mode 100644 index 0000000000000000000000000000000000000000..097cced5f3255809b3d3bd8a1fd74827cc3fdc5a Binary files /dev/null and b/assets/examples/mv_lq/vase/00.png differ diff --git a/assets/examples/mv_lq/vase/01.png b/assets/examples/mv_lq/vase/01.png new file mode 100644 index 0000000000000000000000000000000000000000..d3355823a1a21d17af81ab1865fe9ba894805ebd Binary files /dev/null and b/assets/examples/mv_lq/vase/01.png differ diff --git a/assets/examples/mv_lq/vase/02.png b/assets/examples/mv_lq/vase/02.png new file mode 100644 index 0000000000000000000000000000000000000000..aaffce22910383ba23c467017db0133776959885 Binary files /dev/null and b/assets/examples/mv_lq/vase/02.png differ diff --git a/assets/examples/mv_lq/vase/03.png b/assets/examples/mv_lq/vase/03.png new file mode 100644 index 0000000000000000000000000000000000000000..335f814e2265bc2db2a73ce2e17cf639a73b269e Binary files /dev/null and b/assets/examples/mv_lq/vase/03.png differ diff --git a/assets/examples/mv_lq_prerendered/cake.mp4 b/assets/examples/mv_lq_prerendered/cake.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..52eda65c01ccb96fb7c24d91ab22616a350fa392 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/cake.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/fish.mp4 b/assets/examples/mv_lq_prerendered/fish.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..061539898c03ff9941f8e68e19cd06bb4ab2625e Binary files /dev/null and b/assets/examples/mv_lq_prerendered/fish.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/gascan.mp4 b/assets/examples/mv_lq_prerendered/gascan.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f3cf609288cce27ac9c84e7cb1d61b79e0a289c8 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/gascan.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/house.mp4 b/assets/examples/mv_lq_prerendered/house.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..279f9cb9419c3ad82b027c901b1da1a5656d80fe Binary files /dev/null and b/assets/examples/mv_lq_prerendered/house.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/lamp.mp4 b/assets/examples/mv_lq_prerendered/lamp.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e1653295cfb288866b8159089497eefab7957051 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/lamp.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/mario.mp4 b/assets/examples/mv_lq_prerendered/mario.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fd2b339e93d0a086dc89b230cc0ea1e90d0a67d6 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/mario.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/oldman.mp4 b/assets/examples/mv_lq_prerendered/oldman.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..92652e83e54ed24b994c37d78a1b0a34a3779edb Binary files /dev/null and b/assets/examples/mv_lq_prerendered/oldman.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/tower.mp4 b/assets/examples/mv_lq_prerendered/tower.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5388f7746256b44af2dc8c2960d1e1716f01e912 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/tower.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/truck.mp4 b/assets/examples/mv_lq_prerendered/truck.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e3f62ccfdfe2643744b7a364e85a2a7726efefb8 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/truck.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/tshirt.mp4 b/assets/examples/mv_lq_prerendered/tshirt.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..512bb95180b61e5fbc6a1362c330de55f2e6b41b Binary files /dev/null and b/assets/examples/mv_lq_prerendered/tshirt.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/turtle.mp4 b/assets/examples/mv_lq_prerendered/turtle.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2d5e450ef1f2e4ffdd38da8315f518ad88db5c52 Binary files /dev/null and b/assets/examples/mv_lq_prerendered/turtle.mp4 differ diff --git a/assets/examples/mv_lq_prerendered/vase.mp4 b/assets/examples/mv_lq_prerendered/vase.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0dc17005156ad568ba9afa867533a7cd2099c82c Binary files /dev/null and b/assets/examples/mv_lq_prerendered/vase.mp4 differ diff --git a/assets/method_overview.png b/assets/method_overview.png new file mode 100644 index 0000000000000000000000000000000000000000..a00127715709e23f10d7c5617d05c27b414c6cd6 Binary files /dev/null and b/assets/method_overview.png differ diff --git a/extern/LGM/LICENSE b/extern/LGM/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b6c1f57e9e1b16bfe27f5dc9e110a17006e527ec --- /dev/null +++ b/extern/LGM/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 3D Topia + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/extern/LGM/acc_configs/gpu1.yaml b/extern/LGM/acc_configs/gpu1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bbbabf2f3e134d3d5ac703495dbbd0a1d0a9876b --- /dev/null +++ b/extern/LGM/acc_configs/gpu1.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/extern/LGM/acc_configs/gpu4.yaml b/extern/LGM/acc_configs/gpu4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f4c53873d842bb32fb028d180c9ff243a5948f8 --- /dev/null +++ b/extern/LGM/acc_configs/gpu4.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/extern/LGM/acc_configs/gpu6.yaml b/extern/LGM/acc_configs/gpu6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32ef41c51c8cd36df6e43908edc7d321c36f26ad --- /dev/null +++ b/extern/LGM/acc_configs/gpu6.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 6 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/extern/LGM/acc_configs/gpu8.yaml b/extern/LGM/acc_configs/gpu8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46b5c3522722bd605e9f82ddcf1fd1e5b8eaf9a0 --- /dev/null +++ b/extern/LGM/acc_configs/gpu8.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/extern/LGM/app.py b/extern/LGM/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c12a8516f2cd55a5956b73213675d29b4eb3fadc --- /dev/null +++ b/extern/LGM/app.py @@ -0,0 +1,249 @@ +import os +import tyro +import imageio +import numpy as np +import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from safetensors.torch import load_file +import rembg +import gradio as gr + +import kiui +from kiui.op import recenter +from kiui.cam import orbit_camera + +from core.options import AllConfigs, Options +from core.models import LGM +from mvdream.pipeline_mvdream import MVDreamPipeline + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +GRADIO_VIDEO_PATH = 'gradio_output.mp4' +GRADIO_PLY_PATH = 'gradio_output.ply' + +opt = tyro.cli(AllConfigs) + +# model +model = LGM(opt) + +# resume pretrained checkpoint +if opt.resume is not None: + if opt.resume.endswith('safetensors'): + ckpt = load_file(opt.resume, device='cpu') + else: + ckpt = torch.load(opt.resume, map_location='cpu') + model.load_state_dict(ckpt, strict=False) + print(f'[INFO] Loaded checkpoint from {opt.resume}') +else: + print(f'[WARN] model randomly initialized, are you sure?') + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = model.half().to(device) +model.eval() + +tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) +proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) +proj_matrix[0, 0] = 1 / tan_half_fov +proj_matrix[1, 1] = 1 / tan_half_fov +proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) +proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) +proj_matrix[2, 3] = 1 + +# load dreams +pipe_text = MVDreamPipeline.from_pretrained( + 'ashawkey/mvdream-sd2.1-diffusers', # remote weights + torch_dtype=torch.float16, + trust_remote_code=True, + # local_files_only=True, +) +pipe_text = pipe_text.to(device) + +pipe_image = MVDreamPipeline.from_pretrained( + "ashawkey/imagedream-ipmv-diffusers", # remote weights + torch_dtype=torch.float16, + trust_remote_code=True, + # local_files_only=True, +) +pipe_image = pipe_image.to(device) + +# load rembg +bg_remover = rembg.new_session() + +# process function +def process(input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42): + + # seed + kiui.seed_everything(input_seed) + + os.makedirs(opt.workspace, exist_ok=True) + output_video_path = os.path.join(opt.workspace, GRADIO_VIDEO_PATH) + output_ply_path = os.path.join(opt.workspace, GRADIO_PLY_PATH) + + # text-conditioned + if input_image is None: + mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation) + mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8) + # bg removal + mv_image = [] + for i in range(4): + image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4] + # to white bg + image = image.astype(np.float32) / 255 + image = recenter(image, image[..., 0] > 0, border_ratio=0.2) + image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:]) + mv_image.append(image) + # image-conditioned (may also input text, but no text usually works too) + else: + input_image = np.array(input_image) # uint8 + # bg removal + carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4] + mask = carved_image[..., -1] > 0 + image = recenter(carved_image, mask, border_ratio=0.2) + image = image.astype(np.float32) / 255.0 + image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) + mv_image = pipe_image(prompt, image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation) + + mv_image_grid = np.concatenate([ + np.concatenate([mv_image[1], mv_image[2]], axis=1), + np.concatenate([mv_image[3], mv_image[0]], axis=1), + ], axis=0) + + # generate gaussians + input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32 + input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] + input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False) + input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + rays_embeddings = model.prepare_default_rays(device, elevation=input_elevation) + input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W] + + with torch.no_grad(): + with torch.autocast(device_type='cuda', dtype=torch.float16): + # generate gaussians + gaussians = model.forward_gaussians(input_image) + + # save gaussians + model.gs.save_ply(gaussians, output_ply_path) + + # render 360 video + images = [] + elevation = 0 + if opt.fancy_video: + azimuth = np.arange(0, 720, 4, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + scale = min(azi / 360, 1) + + image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + else: + azimuth = np.arange(0, 360, 2, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + + images = np.concatenate(images, axis=0) + imageio.mimwrite(output_video_path, images, fps=30) + + return mv_image_grid, output_video_path, output_ply_path + +# gradio UI + +_TITLE = '''LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation''' + +_DESCRIPTION = ''' +
+ + +
+ +* Input can be only text, only image, or both image and text. +* If you find the output unsatisfying, try using different seeds! +''' + +block = gr.Blocks(title=_TITLE).queue() +with block: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + # input image + input_image = gr.Image(label="image", type='pil') + # input prompt + input_text = gr.Textbox(label="prompt") + # negative prompt + input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate') + # elevation + input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0) + # inference steps + input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30) + # random seed + input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0) + # gen button + button_gen = gr.Button("Generate") + + + with gr.Column(scale=1): + with gr.Tab("Video"): + # final video results + output_video = gr.Video(label="video") + # ply file + output_file = gr.File(label="ply") + with gr.Tab("Multi-view Image"): + # multi-view results + output_image = gr.Image(interactive=False, show_label=False) + + button_gen.click(process, inputs=[input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed], outputs=[output_image, output_video, output_file]) + + gr.Examples( + examples=[ + "data_test/anya_rgba.png", + "data_test/bird_rgba.png", + "data_test/catstatue_rgba.png", + ], + inputs=[input_image], + outputs=[output_image, output_video, output_file], + fn=lambda x: process(input_image=x, prompt=''), + cache_examples=False, + label='Image-to-3D Examples' + ) + + gr.Examples( + examples=[ + "a motorbike", + "a hamburger", + "a furry red fox head", + ], + inputs=[input_text], + outputs=[output_image, output_video, output_file], + fn=lambda x: process(input_image=None, prompt=x), + cache_examples=False, + label='Text-to-3D Examples' + ) + +block.launch(server_name="0.0.0.0", share=False) \ No newline at end of file diff --git a/extern/LGM/convert.py b/extern/LGM/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..e898e3413e6a9d95d258005dc2c5c1bfadd94268 --- /dev/null +++ b/extern/LGM/convert.py @@ -0,0 +1,462 @@ + +import os +import tyro +import tqdm +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from core.options import AllConfigs, Options +from core.gs import GaussianRenderer + +import mcubes +import nerfacc +import nvdiffrast.torch as dr + +import kiui +from kiui.mesh import Mesh +from kiui.mesh_utils import clean_mesh, decimate_mesh +from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency +from kiui.op import uv_padding, safe_normalize, inverse_sigmoid +from kiui.cam import orbit_camera, get_perspective +from kiui.nn import MLP, trunc_exp +from kiui.gridencoder import GridEncoder + +def get_rays(pose, h, w, fovy, opengl=True): + + x, y = torch.meshgrid( + torch.arange(w, device=pose.device), + torch.arange(h, device=pose.device), + indexing="xy", + ) + x = x.flatten() + y = y.flatten() + + cx = w * 0.5 + cy = h * 0.5 + focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) + + camera_dirs = F.pad( + torch.stack( + [ + (x - cx + 0.5) / focal, + (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), + ], + dim=-1, + ), + (0, 1), + value=(-1.0 if opengl else 1.0), + ) # [hw, 3] + + rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] + rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] + + rays_d = safe_normalize(rays_d) + + return rays_o, rays_d + +# Triple renderer of gaussians, gaussian, and diso mesh. +# gaussian --> nerf --> mesh +class Converter(nn.Module): + def __init__(self, opt: Options): + super().__init__() + + self.opt = opt + self.device = torch.device("cuda") + + # gs renderer + self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[2, 3] = 1 + + self.gs_renderer = GaussianRenderer(opt) + + self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device) + + # nerf renderer + if not self.opt.force_cuda_rast: + self.glctx = dr.RasterizeGLContext() + else: + self.glctx = dr.RasterizeCudaContext() + + self.step = 0 + self.render_step_size = 5e-3 + self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device) + self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1) + + self.encoder_density = GridEncoder(num_levels=12) # VMEncoder(output_dim=16, mode='sum') + self.encoder = GridEncoder(num_levels=12) + self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False) + self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False) + + # mesh renderer + self.proj = torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device) + self.v = self.f = None + self.vt = self.ft = None + self.deform = None + self.albedo = None + + + @torch.no_grad() + def render_gs(self, pose): + + cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device) + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + out = self.gs_renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0)) + image = out['image'].squeeze(1).squeeze(0) # [C, H, W] + alpha = out['alpha'].squeeze(2).squeeze(1).squeeze(0) # [H, W] + + return image, alpha + + def get_density(self, xs): + # xs: [..., 3] + prefix = xs.shape[:-1] + xs = xs.view(-1, 3) + feats = self.encoder_density(xs) + density = trunc_exp(self.mlp_density(feats)) + density = density.view(*prefix, 1) + return density + + def render_nerf(self, pose): + + pose = torch.from_numpy(pose.astype(np.float32)).to(self.device) + + # get rays + resolution = self.opt.output_size + rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy) + + # update occ grid + if self.training: + def occ_eval_fn(xs): + sigmas = self.get_density(xs) + return self.render_step_size * sigmas + + self.estimator.update_every_n_steps(self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8) + self.step += 1 + + # render + def sigma_fn(t_starts, t_ends, ray_indices): + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + sigmas = self.get_density(xs) + return sigmas.squeeze(-1) + + with torch.no_grad(): + ray_indices, t_starts, t_ends = self.estimator.sampling( + rays_o, + rays_d, + sigma_fn=sigma_fn, + near_plane=0.01, + far_plane=100, + render_step_size=self.render_step_size, + stratified=self.training, + cone_angle=0, + ) + + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + sigmas = self.get_density(xs).squeeze(-1) + rgbs = torch.sigmoid(self.mlp(self.encoder(xs))) + + n_rays=rays_o.shape[0] + weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays) + color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays) + alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays) + + color = color + 1 * (1.0 - alpha) + + color = color.view(resolution, resolution, 3).clamp(0, 1).permute(2, 0, 1).contiguous() + alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous() + + return color, alpha + + def fit_nerf(self, iters=512, resolution=128): + + self.opt.output_size = resolution + + optimizer = torch.optim.Adam([ + {'params': self.encoder_density.parameters(), 'lr': 1e-2}, + {'params': self.encoder.parameters(), 'lr': 1e-2}, + {'params': self.mlp_density.parameters(), 'lr': 1e-3}, + {'params': self.mlp.parameters(), 'lr': 1e-3}, + ]) + + print(f"[INFO] fitting nerf...") + pbar = tqdm.trange(iters) + for i in pbar: + + ver = np.random.randint(-45, 45) + hor = np.random.randint(-180, 180) + rad = np.random.uniform(1.5, 3.0) + + pose = orbit_camera(ver, hor, rad) + + image_gt, alpha_gt = self.render_gs(pose) + image_pred, alpha_pred = self.render_nerf(pose) + + # if i % 200 == 0: + # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred) + + loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt) + loss = loss_mse #+ 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss() + + loss.backward() + self.encoder_density.grad_total_variation(1e-8) + + optimizer.step() + optimizer.zero_grad() + + pbar.set_description(f"MSE = {loss_mse.item():.6f}") + + print(f"[INFO] finished fitting nerf!") + + def render_mesh(self, pose): + + h = w = self.opt.output_size + + v = self.v + self.deform + f = self.f + + pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) + + # get v_clip and render rgb + v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) + v_clip = v_cam @ self.proj.T + + rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) + + alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1] + alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) # [H, W] important to enable gradients! + + if self.albedo is None: + xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3] + xyzs = xyzs.view(-1, 3) + mask = (alpha > 0).view(-1) + image = torch.zeros_like(xyzs, dtype=torch.float32) + if mask.any(): + masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask].detach(), bound=1))) + image[mask] = masked_albedo.float() + else: + texc, texc_db = dr.interpolate(self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs='all') + image = torch.sigmoid(dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)) # [1, H, W, 3] + + image = image.view(1, h, w, 3) + # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1) + image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W] + image = alpha * image + (1 - alpha) + + return image, alpha + + def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4): + + self.opt.output_size = resolution + + # init mesh from nerf + grid_size = 256 + sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32) + + S = 128 + density_thresh = 10 + + X = torch.linspace(-1, 1, grid_size).split(S) + Y = torch.linspace(-1, 1, grid_size).split(S) + Z = torch.linspace(-1, 1, grid_size).split(S) + + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij') + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] + val = self.get_density(pts.to(self.device)) + sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] + + print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})') + + vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh) + vertices = vertices / (grid_size - 1.0) * 2 - 1 + + # clean + vertices = vertices.astype(np.float32) + triangles = triangles.astype(np.int32) + vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01) + if triangles.shape[0] > decimate_target: + vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False) + + self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) + self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) + self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device) + + # fit mesh from gs + lr_factor = 1 + optimizer = torch.optim.Adam([ + {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor}, + {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor}, + {'params': self.deform, 'lr': 1e-4}, + ]) + + print(f"[INFO] fitting mesh...") + pbar = tqdm.trange(iters) + for i in pbar: + + ver = np.random.randint(-10, 10) + hor = np.random.randint(-180, 180) + rad = self.opt.cam_radius # np.random.uniform(1, 2) + + pose = orbit_camera(ver, hor, rad) + + image_gt, alpha_gt = self.render_gs(pose) + image_pred, alpha_pred = self.render_mesh(pose) + + loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt) + # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f) + loss_normal = normal_consistency(self.v + self.deform, self.f) + loss_offsets = (self.deform ** 2).sum(-1).mean() + loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets + + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + # remesh periodically + if i > 0 and i % 512 == 0: + vertices = (self.v + self.deform).detach().cpu().numpy() + triangles = self.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01) + if triangles.shape[0] > decimate_target: + vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False) + self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) + self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) + self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device) + lr_factor *= 0.5 + optimizer = torch.optim.Adam([ + {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor}, + {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor}, + {'params': self.deform, 'lr': 1e-4}, + ]) + + pbar.set_description(f"MSE = {loss_mse.item():.6f}") + + # last clean + vertices = (self.v + self.deform).detach().cpu().numpy() + triangles = self.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=False) + self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) + self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) + self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device)) + + print(f"[INFO] finished fitting mesh!") + + # uv mesh refine + def fit_mesh_uv(self, iters=512, resolution=512, texture_resolution=1024, padding=2): + + self.opt.output_size = resolution + + # unwrap uv + print(f"[INFO] uv unwrapping...") + mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device) + mesh.auto_normal() + mesh.auto_uv() + + self.vt = mesh.vt + self.ft = mesh.ft + + # render uv maps + h = w = texture_resolution + uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] + uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] + + rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4] + xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3] + mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1] + + # masked query + xyzs = xyzs.view(-1, 3) + mask = (mask > 0).view(-1) + + albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32) + + if mask.any(): + print(f"[INFO] querying texture...") + + xyzs = xyzs[mask] # [M, 3] + + # batched inference to avoid OOM + batch = [] + head = 0 + while head < xyzs.shape[0]: + tail = min(head + 640000, xyzs.shape[0]) + batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()) + head += 640000 + + albedo[mask] = torch.cat(batch, dim=0) + + albedo = albedo.view(h, w, -1) + mask = mask.view(h, w) + albedo = uv_padding(albedo, mask, padding) + + # optimize texture + self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device) + + optimizer = torch.optim.Adam([ + {'params': self.albedo, 'lr': 1e-3}, + ]) + + print(f"[INFO] fitting mesh texture...") + pbar = tqdm.trange(iters) + for i in pbar: + + # shrink to front view as we care more about it... + ver = np.random.randint(-5, 5) + hor = np.random.randint(-15, 15) + rad = self.opt.cam_radius # np.random.uniform(1, 2) + + pose = orbit_camera(ver, hor, rad) + + image_gt, alpha_gt = self.render_gs(pose) + image_pred, alpha_pred = self.render_mesh(pose) + + loss_mse = F.mse_loss(image_pred, image_gt) + loss = loss_mse + + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + pbar.set_description(f"MSE = {loss_mse.item():.6f}") + + print(f"[INFO] finished fitting mesh texture!") + + + @torch.no_grad() + def export_mesh(self, path): + + mesh = Mesh(v=self.v, f=self.f, vt=self.vt, ft=self.ft, albedo=torch.sigmoid(self.albedo), device=self.device) + mesh.auto_normal() + mesh.write(path) + + +opt = tyro.cli(AllConfigs) + +# load a saved ply and convert to mesh +assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py' + +converter = Converter(opt).cuda() +converter.fit_nerf() +converter.fit_mesh() +converter.fit_mesh_uv() +converter.export_mesh(opt.test_path.replace('.ply', '.glb')) diff --git a/extern/LGM/core/__init__.py b/extern/LGM/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/extern/LGM/core/attention.py b/extern/LGM/core/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1382f65805a8650b3d3369c803dd4df0bc9dc8 --- /dev/null +++ b/extern/LGM/core/attention.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import os +import warnings + +from torch import Tensor +from torch import nn + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim: int, + dim_q: int, + dim_k: int, + dim_v: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) + self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) + self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # q: [B, N, Cq] + # k: [B, M, Ck] + # v: [B, M, Cv] + # return: [B, N, C] + + B, N, _ = q.shape + M = k.shape[1] + + q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh] + k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] + v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] + + attn = q @ k.transpose(-2, -1) # [B, nh, N, M] + + attn = attn.softmax(dim=-1) # [B, nh, N, M] + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C] + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttention(CrossAttention): + def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, _ = q.shape + M = k.shape[1] + + q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] + k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] + v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/extern/LGM/core/gs.py b/extern/LGM/core/gs.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a56c664578f1e682bc9ceb8389ecab185b62cf --- /dev/null +++ b/extern/LGM/core/gs.py @@ -0,0 +1,190 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, +) + +from core.options import Options + +import kiui + +class GaussianRenderer: + def __init__(self, opt: Options): + + self.opt = opt + self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") + + # intrinsics + self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[2, 3] = 1 + + def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): + # gaussians: [B, N, 14] + # cam_view, cam_view_proj: [B, V, 4, 4] + # cam_pos: [B, V, 3] + + device = gaussians.device + B, V = cam_view.shape[:2] + + # loop of loop... + images = [] + alphas = [] + for b in range(B): + + # pos, opacity, scale, rotation, shs + means3D = gaussians[b, :, 0:3].contiguous().float() + opacity = gaussians[b, :, 3:4].contiguous().float() + scales = gaussians[b, :, 4:7].contiguous().float() + rotations = gaussians[b, :, 7:11].contiguous().float() + rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] + + for v in range(V): + + # render novel views + view_matrix = cam_view[b, v].float() + view_proj_matrix = cam_view_proj[b, v].float() + campos = cam_pos[b, v].float() + + raster_settings = GaussianRasterizationSettings( + image_height=self.opt.output_size, + image_width=self.opt.output_size, + tanfovx=self.tan_half_fov, + tanfovy=self.tan_half_fov, + bg=self.bg_color if bg_color is None else bg_color, + scale_modifier=scale_modifier, + viewmatrix=view_matrix, + projmatrix=view_proj_matrix, + sh_degree=0, + campos=campos, + prefiltered=False, + debug=False, + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D=means3D, + means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), + shs=None, + colors_precomp=rgbs, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=None, + ) + + rendered_image = rendered_image.clamp(0, 1) + + images.append(rendered_image) + alphas.append(rendered_alpha) + + images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) + alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) + + return { + "image": images, # [B, V, 3, H, W] + "alpha": alphas, # [B, V, 1, H, W] + } + + + def save_ply(self, gaussians, path, compatible=True): + # gaussians: [B, N, 14] + # compatible: save pre-activated gaussians as in the original paper + + assert gaussians.shape[0] == 1, 'only support batch size 1' + + from plyfile import PlyData, PlyElement + + means3D = gaussians[0, :, 0:3].contiguous().float() + opacity = gaussians[0, :, 3:4].contiguous().float() + scales = gaussians[0, :, 4:7].contiguous().float() + rotations = gaussians[0, :, 7:11].contiguous().float() + shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] + + # prune by opacity + mask = opacity.squeeze(-1) >= 0.005 + means3D = means3D[mask] + opacity = opacity[mask] + scales = scales[mask] + rotations = rotations[mask] + shs = shs[mask] + + # invert activation to make it compatible with the original ply format + if compatible: + opacity = kiui.op.inverse_sigmoid(opacity) + scales = torch.log(scales + 1e-8) + shs = (shs - 0.5) / 0.28209479177387814 + + xyzs = means3D.detach().cpu().numpy() + f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = opacity.detach().cpu().numpy() + scales = scales.detach().cpu().numpy() + rotations = rotations.detach().cpu().numpy() + + l = ['x', 'y', 'z'] + # All channels except the 3 DC + for i in range(f_dc.shape[1]): + l.append('f_dc_{}'.format(i)) + l.append('opacity') + for i in range(scales.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(rotations.shape[1]): + l.append('rot_{}'.format(i)) + + dtype_full = [(attribute, 'f4') for attribute in l] + + elements = np.empty(xyzs.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + + PlyData([el]).write(path) + + def load_ply(self, path, compatible=True): + + from plyfile import PlyData, PlyElement + + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + print("Number of points at loading : ", xyz.shape[0]) + + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + shs = np.zeros((xyz.shape[0], 3)) + shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) + shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) + gaussians = torch.from_numpy(gaussians).float() # cpu + + if compatible: + gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) + gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) + gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 + + return gaussians \ No newline at end of file diff --git a/extern/LGM/core/models.py b/extern/LGM/core/models.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8bf322c77a1f1920830bb19b5acc1652239cb8 --- /dev/null +++ b/extern/LGM/core/models.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import kiui +from kiui.lpips import LPIPS + +from core.unet import UNet +from core.options import Options +from core.gs import GaussianRenderer + + +class LGM(nn.Module): + def __init__( + self, + opt: Options, + ): + super().__init__() + + self.opt = opt + + # unet + self.unet = UNet( + 9, 14, + down_channels=self.opt.down_channels, + down_attention=self.opt.down_attention, + mid_attention=self.opt.mid_attention, + up_channels=self.opt.up_channels, + up_attention=self.opt.up_attention, + ) + + # last conv + self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again + + # Gaussian Renderer + self.gs = GaussianRenderer(opt) + + # activations... + self.pos_act = lambda x: x.clamp(-1, 1) + self.scale_act = lambda x: 0.1 * F.softplus(x) + self.opacity_act = lambda x: torch.sigmoid(x) + self.rot_act = lambda x: F.normalize(x, dim=-1) + self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again + + # LPIPS loss + if self.opt.lambda_lpips > 0: + self.lpips_loss = LPIPS(net='vgg') + self.lpips_loss.requires_grad_(False) + + + def state_dict(self, **kwargs): + # remove lpips_loss + state_dict = super().state_dict(**kwargs) + for k in list(state_dict.keys()): + if 'lpips_loss' in k: + del state_dict[k] + return state_dict + + + def prepare_default_rays(self, device, elevation=0): + + from kiui.cam import orbit_camera + from core.utils import get_rays + + cam_poses = np.stack([ + orbit_camera(elevation, 0, radius=self.opt.cam_radius), + orbit_camera(elevation, 90, radius=self.opt.cam_radius), + orbit_camera(elevation, 180, radius=self.opt.cam_radius), + orbit_camera(elevation, 270, radius=self.opt.cam_radius), + ], axis=0) # [4, 4, 4] + cam_poses = torch.from_numpy(cam_poses) + + rays_embeddings = [] + for i in range(cam_poses.shape[0]): + rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] + rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] + rays_embeddings.append(rays_plucker) + + ## visualize rays for plotting figure + # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True) + + rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] + + return rays_embeddings + + + def forward_gaussians(self, images): + # images: [B, 4, 9, H, W] + # return: Gaussians: [B, dim_t] + + B, V, C, H, W = images.shape + images = images.view(B*V, C, H, W) + + x = self.unet(images) # [B*4, 14, h, w] + x = self.conv(x) # [B*4, 14, h, w] + + x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size) + + ## visualize multi-view gaussian features for plotting figure + # tmp_alpha = self.opacity_act(x[0, :, 3:4]) + # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha) + # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5 + # kiui.vis.plot_image(tmp_img_rgb, save=True) + # kiui.vis.plot_image(tmp_img_pos, save=True) + + x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14) + + pos = self.pos_act(x[..., 0:3]) # [B, N, 3] + opacity = self.opacity_act(x[..., 3:4]) + scale = self.scale_act(x[..., 4:7]) + rotation = self.rot_act(x[..., 7:11]) + rgbs = self.rgb_act(x[..., 11:]) + + gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14] + + return gaussians + + + def forward(self, data, step_ratio=1): + # data: output of the dataloader + # return: loss + + results = {} + loss = 0 + + images = data['input'] # [B, 4, 9, h, W], input features + + # use the first view to predict gaussians + gaussians = self.forward_gaussians(images) # [B, N, 14] + + results['gaussians'] = gaussians + + # always use white bg + bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) + + # use the other views for rendering and supervision + results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color) + pred_images = results['image'] # [B, V, C, output_size, output_size] + pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] + + results['images_pred'] = pred_images + results['alphas_pred'] = pred_alphas + + gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views + gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks + + gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) + + loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + loss = loss + loss_mse + + if self.opt.lambda_lpips > 0: + loss_lpips = self.lpips_loss( + # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, + # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, + # downsampled to at most 256 to reduce memory cost + F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + ).mean() + results['loss_lpips'] = loss_lpips + loss = loss + self.opt.lambda_lpips * loss_lpips + + results['loss'] = loss + + # metric + with torch.no_grad(): + psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) + results['psnr'] = psnr + + return results diff --git a/extern/LGM/core/options.py b/extern/LGM/core/options.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc31944f89ff14a4387204f1828edd785bc3498 --- /dev/null +++ b/extern/LGM/core/options.py @@ -0,0 +1,120 @@ +import tyro +from dataclasses import dataclass +from typing import Tuple, Literal, Dict, Optional + + +@dataclass +class Options: + ### model + # Unet image input size + input_size: int = 256 + # Unet definition + down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) + down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) + mid_attention: bool = True + up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) + up_attention: Tuple[bool, ...] = (True, True, True, False) + # Unet output size, dependent on the input_size and U-Net structure! + splat_size: int = 64 + # gaussian render size + output_size: int = 256 + + ### dataset + # data mode (only support s3 now) + data_mode: Literal['s3'] = 's3' + # fovy of the dataset + fovy: float = 49.1 + # camera near plane + znear: float = 0.5 + # camera far plane + zfar: float = 2.5 + # number of all views (input + output) + num_views: int = 12 + # number of views + num_input_views: int = 4 + # camera radius + cam_radius: float = 1.5 # to better use [-1, 1]^3 space + # num workers + num_workers: int = 8 + + ### training + # workspace + workspace: str = './workspace' + # resume + resume: Optional[str] = None + # batch size (per-GPU) + batch_size: int = 8 + # gradient accumulation + gradient_accumulation_steps: int = 1 + # training epochs + num_epochs: int = 30 + # lpips loss weight + lambda_lpips: float = 1.0 + # gradient clip + gradient_clip: float = 1.0 + # mixed precision + mixed_precision: str = 'bf16' + # learning rate + lr: float = 4e-4 + # augmentation prob for grid distortion + prob_grid_distortion: float = 0.5 + # augmentation prob for camera jitter + prob_cam_jitter: float = 0.5 + + ### testing + # test image path + test_path: Optional[str] = None + + ### misc + # nvdiffrast backend setting + force_cuda_rast: bool = False + # render fancy video with gaussian scaling effect + fancy_video: bool = False + + +# all the default settings +config_defaults: Dict[str, Options] = {} +config_doc: Dict[str, str] = {} + +config_doc['lrm'] = 'the default settings for LGM' +config_defaults['lrm'] = Options() + +config_doc['small'] = 'small model with lower resolution Gaussians' +config_defaults['small'] = Options( + input_size=256, + splat_size=64, + output_size=256, + batch_size=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['big'] = 'big model with higher resolution Gaussians' +config_defaults['big'] = Options( + input_size=256, + up_channels=(1024, 1024, 512, 256, 128), # one more decoder + up_attention=(True, True, True, False, False), + splat_size=128, + output_size=512, # render & supervise Gaussians at a higher resolution. + batch_size=8, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['tiny'] = 'tiny model for ablation' +config_defaults['tiny'] = Options( + input_size=256, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + splat_size=64, + output_size=256, + batch_size=16, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) diff --git a/extern/LGM/core/provider_objaverse.py b/extern/LGM/core/provider_objaverse.py new file mode 100644 index 0000000000000000000000000000000000000000..a90b773c75ccd5f21552f08cb8bff3630ef20782 --- /dev/null +++ b/extern/LGM/core/provider_objaverse.py @@ -0,0 +1,172 @@ +import os +import cv2 +import random +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset + +import kiui +from core.options import Options +from core.utils import get_rays, grid_distortion, orbit_camera_jitter + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + +class ObjaverseDataset(Dataset): + + def _warn(self): + raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)') + + def __init__(self, opt: Options, training=True): + + self.opt = opt + self.training = training + + # TODO: remove this barrier + self._warn() + + # TODO: load the list of objects for training + self.items = [] + with open('TODO: file containing the list', 'r') as f: + for line in f.readlines(): + self.items.append(line.strip()) + + # naive split + if self.training: + self.items = self.items[:-self.opt.batch_size] + else: + self.items = self.items[-self.opt.batch_size:] + + # default camera intrinsics + self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) + self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) + self.proj_matrix[2, 3] = 1 + + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + + uid = self.items[idx] + results = {} + + # load num_views images + images = [] + masks = [] + cam_poses = [] + + vid_cnt = 0 + + # TODO: choose views, based on your rendering settings + if self.training: + # input views are in (36, 72), other views are randomly selected + vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist() + else: + # fixed views + vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist() + + for vid in vids: + + image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png') + camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt') + + try: + # TODO: load data (modify self.client here) + image = np.frombuffer(self.client.get(image_path), np.uint8) + image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')] + c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4) + except Exception as e: + # print(f'[WARN] dataset {uid} {vid}: {e}') + continue + + # TODO: you may have a different camera system + # blender world + opencv cam --> opengl world & cam + c2w[1] *= -1 + c2w[[1, 2]] = c2w[[2, 1]] + c2w[:3, 1:3] *= -1 # invert up and forward direction + + # scale up radius to fully use the [-1, 1]^3 space! + c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale + + image = image.permute(2, 0, 1) # [4, 512, 512] + mask = image[3:4] # [1, 512, 512] + image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg + image = image[[2,1,0]].contiguous() # bgr to rgb + + images.append(image) + masks.append(mask.squeeze(0)) + cam_poses.append(c2w) + + vid_cnt += 1 + if vid_cnt == self.opt.num_views: + break + + if vid_cnt < self.opt.num_views: + print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!') + n = self.opt.num_views - vid_cnt + images = images + [images[-1]] * n + masks = masks + [masks[-1]] * n + cam_poses = cam_poses + [cam_poses[-1]] * n + + images = torch.stack(images, dim=0) # [V, C, H, W] + masks = torch.stack(masks, dim=0) # [V, H, W] + cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4] + + # normalized camera feats as in paper (transform the first pose to a fixed position) + transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0]) + cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4] + + images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W] + cam_poses_input = cam_poses[:self.opt.num_input_views].clone() + + # data augmentation + if self.training: + # apply random grid distortion to simulate 3D inconsistency + if random.random() < self.opt.prob_grid_distortion: + images_input[1:] = grid_distortion(images_input[1:]) + # apply camera jittering (only to input!) + if random.random() < self.opt.prob_cam_jitter: + cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:]) + + images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + # resize render ground-truth images, range still in [0, 1] + results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size] + results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size] + + # build rays for input views + rays_embeddings = [] + for i in range(self.opt.num_input_views): + rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] + rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] + rays_embeddings.append(rays_plucker) + + + rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w] + final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W] + results['input'] = final_input + + # opengl to colmap camera for gaussian renderer + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + results['cam_view'] = cam_view + results['cam_view_proj'] = cam_view_proj + results['cam_pos'] = cam_pos + + return results \ No newline at end of file diff --git a/extern/LGM/core/unet.py b/extern/LGM/core/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..4134e809d0bad8263874b77a217a7fef06309355 --- /dev/null +++ b/extern/LGM/core/unet.py @@ -0,0 +1,319 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from typing import Tuple, Literal +from functools import partial + +from core.attention import MemEffAttention + +class MVAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + groups: int = 32, + eps: float = 1e-5, + residual: bool = True, + skip_scale: float = 1, + num_frames: int = 4, # WARN: hardcoded! + ): + super().__init__() + + self.residual = residual + self.skip_scale = skip_scale + self.num_frames = num_frames + + self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) + self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) + + def forward(self, x): + # x: [B*V, C, H, W] + BV, C, H, W = x.shape + B = BV // self.num_frames # assert BV % self.num_frames == 0 + + res = x + x = self.norm(x) + + x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C) + x = self.attn(x) + x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W) + + if self.residual: + x = (x + res) * self.skip_scale + return x + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resample: Literal['default', 'up', 'down'] = 'default', + groups: int = 32, + eps: float = 1e-5, + skip_scale: float = 1, # multiplied to output + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.skip_scale = skip_scale + + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.act = F.silu + + self.resample = None + if resample == 'up': + self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + elif resample == 'down': + self.resample = nn.AvgPool2d(kernel_size=2, stride=2) + + self.shortcut = nn.Identity() + if self.in_channels != self.out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True) + + + def forward(self, x): + res = x + + x = self.norm1(x) + x = self.act(x) + + if self.resample: + res = self.resample(res) + x = self.resample(x) + + x = self.conv1(x) + x = self.norm2(x) + x = self.act(x) + x = self.conv2(x) + + x = (x + self.shortcut(res)) * self.skip_scale + + return x + +class DownBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample: bool = True, + attention: bool = True, + attention_heads: int = 16, + skip_scale: float = 1, + ): + super().__init__() + + nets = [] + attns = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale)) + if attention: + attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + self.downsample = None + if downsample: + self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + + def forward(self, x): + xs = [] + + for attn, net in zip(self.attns, self.nets): + x = net(x) + if attn: + x = attn(x) + xs.append(x) + + if self.downsample: + x = self.downsample(x) + xs.append(x) + + return x, xs + + +class MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + attention: bool = True, + attention_heads: int = 16, + skip_scale: float = 1, + ): + super().__init__() + + nets = [] + attns = [] + # first layer + nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) + # more layers + for i in range(num_layers): + nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) + if attention: + attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + def forward(self, x): + x = self.nets[0](x) + for attn, net in zip(self.attns, self.nets[1:]): + if attn: + x = attn(x) + x = net(x) + return x + + +class UpBlock(nn.Module): + def __init__( + self, + in_channels: int, + prev_out_channels: int, + out_channels: int, + num_layers: int = 1, + upsample: bool = True, + attention: bool = True, + attention_heads: int = 16, + skip_scale: float = 1, + ): + super().__init__() + + nets = [] + attns = [] + for i in range(num_layers): + cin = in_channels if i == 0 else out_channels + cskip = prev_out_channels if (i == num_layers - 1) else out_channels + + nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale)) + if attention: + attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + self.upsample = None + if upsample: + self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x, xs): + + for attn, net in zip(self.attns, self.nets): + res_x = xs[-1] + xs = xs[:-1] + x = torch.cat([x, res_x], dim=1) + x = net(x) + if attn: + x = attn(x) + + if self.upsample: + x = F.interpolate(x, scale_factor=2.0, mode='nearest') + x = self.upsample(x) + + return x + + +# it could be asymmetric! +class UNet(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024), + down_attention: Tuple[bool, ...] = (False, False, False, True, True), + mid_attention: bool = True, + up_channels: Tuple[int, ...] = (1024, 512, 256), + up_attention: Tuple[bool, ...] = (True, True, False), + layers_per_block: int = 2, + skip_scale: float = np.sqrt(0.5), + ): + super().__init__() + + # first + self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1) + + # down + down_blocks = [] + cout = down_channels[0] + for i in range(len(down_channels)): + cin = cout + cout = down_channels[i] + + down_blocks.append(DownBlock( + cin, cout, + num_layers=layers_per_block, + downsample=(i != len(down_channels) - 1), # not final layer + attention=down_attention[i], + skip_scale=skip_scale, + )) + self.down_blocks = nn.ModuleList(down_blocks) + + # mid + self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale) + + # up + up_blocks = [] + cout = up_channels[0] + for i in range(len(up_channels)): + cin = cout + cout = up_channels[i] + cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric + + up_blocks.append(UpBlock( + cin, cskip, cout, + num_layers=layers_per_block + 1, # one more layer for up + upsample=(i != len(up_channels) - 1), # not final layer + attention=up_attention[i], + skip_scale=skip_scale, + )) + self.up_blocks = nn.ModuleList(up_blocks) + + # last + self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5) + self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + + def forward(self, x): + # x: [B, Cin, H, W] + + # first + x = self.conv_in(x) + + # down + xss = [x] + for block in self.down_blocks: + x, xs = block(x) + xss.extend(xs) + + # mid + x = self.mid_block(x) + + # up + for block in self.up_blocks: + xs = xss[-len(block.nets):] + xss = xss[:-len(block.nets)] + x = block(x, xs) + + # last + x = self.norm_out(x) + x = F.silu(x) + x = self.conv_out(x) # [B, Cout, H', W'] + + return x diff --git a/extern/LGM/core/utils.py b/extern/LGM/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..17ab4f5116f6c40422efb65a9bd139ef07f9e41c --- /dev/null +++ b/extern/LGM/core/utils.py @@ -0,0 +1,109 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import roma +from kiui.op import safe_normalize + +def get_rays(pose, h, w, fovy, opengl=True): + + x, y = torch.meshgrid( + torch.arange(w, device=pose.device), + torch.arange(h, device=pose.device), + indexing="xy", + ) + x = x.flatten() + y = y.flatten() + + cx = w * 0.5 + cy = h * 0.5 + + focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) + + camera_dirs = F.pad( + torch.stack( + [ + (x - cx + 0.5) / focal, + (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), + ], + dim=-1, + ), + (0, 1), + value=(-1.0 if opengl else 1.0), + ) # [hw, 3] + + rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] + rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] + + rays_o = rays_o.view(h, w, 3) + rays_d = safe_normalize(rays_d).view(h, w, 3) + + return rays_o, rays_d + +def orbit_camera_jitter(poses, strength=0.1): + # poses: [B, 4, 4], assume orbit camera in opengl format + # random orbital rotate + + B = poses.shape[0] + rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) + rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) + + rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) + R = rot @ poses[:, :3, :3] + T = rot @ poses[:, :3, 3:] + + new_poses = poses.clone() + new_poses[:, :3, :3] = R + new_poses[:, :3, 3:] = T + + return new_poses + +def grid_distortion(images, strength=0.5): + # images: [B, C, H, W] + # num_steps: int, grid resolution for distortion + # strength: float in [0, 1], strength of distortion + + B, C, H, W = images.shape + + num_steps = np.random.randint(8, 17) + grid_steps = torch.linspace(-1, 1, num_steps) + + # have to loop batch... + grids = [] + for b in range(B): + # construct displacement + x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive + x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb + x_steps = (x_steps * W).long() # [num_steps] + x_steps[0] = 0 + x_steps[-1] = W + xs = [] + for i in range(num_steps - 1): + xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) + xs = torch.cat(xs, dim=0) # [W] + + y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive + y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb + y_steps = (y_steps * H).long() # [num_steps] + y_steps[0] = 0 + y_steps[-1] = H + ys = [] + for i in range(num_steps - 1): + ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) + ys = torch.cat(ys, dim=0) # [H] + + # construct grid + grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] + grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] + + grids.append(grid) + + grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] + + # grid sample + images = F.grid_sample(images, grids, align_corners=False) + + return images + diff --git a/extern/LGM/data_test/anya_rgba.png b/extern/LGM/data_test/anya_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..089499e16e410207c890b45bc865627352df967d Binary files /dev/null and b/extern/LGM/data_test/anya_rgba.png differ diff --git a/extern/LGM/data_test/bird_rgba.png b/extern/LGM/data_test/bird_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..a1940a7320ed524b047475a46d446fc7704044e4 Binary files /dev/null and b/extern/LGM/data_test/bird_rgba.png differ diff --git a/extern/LGM/data_test/catstatue_rgba.png b/extern/LGM/data_test/catstatue_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..3b44eb51645f4ecf9e288c53a46000c5a795af69 Binary files /dev/null and b/extern/LGM/data_test/catstatue_rgba.png differ diff --git a/extern/LGM/gui.py b/extern/LGM/gui.py new file mode 100644 index 0000000000000000000000000000000000000000..c5188c7b06800de5dc3675bc7503e10c56b06608 --- /dev/null +++ b/extern/LGM/gui.py @@ -0,0 +1,294 @@ + +import os +import tyro +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from core.options import AllConfigs, Options +from core.gs import GaussianRenderer + +import dearpygui.dearpygui as dpg + +import kiui +from kiui.cam import OrbitCamera + + +class GUI: + def __init__(self, opt: Options): + self.opt = opt + self.W = opt.output_size + self.H = opt.output_size + self.cam = OrbitCamera(self.W, self.H, r=opt.cam_radius, fovy=opt.fovy) + + self.device = torch.device("cuda") + + self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[2, 3] = 1 + + self.mode = "image" + + self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # update buffer_image + + # renderer + self.renderer = GaussianRenderer(opt) + self.gaussain_scale_factor = 1 + + self.gaussians = self.renderer.load_ply(opt.test_path).to(self.device) + + dpg.create_context() + self.register_dpg() + self.test_step() + + def __del__(self): + dpg.destroy_context() + + @torch.no_grad() + def test_step(self): + # ignore if no need to update + if not self.need_update: + return + + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + starter.record() + + # should update image + if self.need_update: + # render image + + cam_poses = torch.from_numpy(self.cam.pose).unsqueeze(0).to(self.device) + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + buffer_image = self.renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=self.gaussain_scale_factor)[self.mode] + buffer_image = buffer_image.squeeze(1) # [B, C, H, W] + + if self.mode in ['alpha']: + buffer_image = buffer_image.repeat(1, 3, 1, 1) + + buffer_image = F.interpolate( + buffer_image, + size=(self.H, self.W), + mode="bilinear", + align_corners=False, + ).squeeze(0) + + self.buffer_image = ( + buffer_image.permute(1, 2, 0) + .contiguous() + .clamp(0, 1) + .contiguous() + .detach() + .cpu() + .numpy() + ) + + self.need_update = False + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)") + dpg.set_value( + "_texture", self.buffer_image + ) # buffer must be contiguous, else seg fault! + + def register_dpg(self): + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture( + self.W, + self.H, + self.buffer_image, + format=dpg.mvFormat_Float_rgb, + tag="_texture", + ) + + ### register window + + # the rendered image, as the primary window + with dpg.window( + tag="_primary_window", + width=self.W, + height=self.H, + pos=[0, 0], + no_move=True, + no_title_bar=True, + no_scrollbar=True, + ): + # add the texture + dpg.add_image("_texture") + + # dpg.set_primary_window("_primary_window", True) + + # control window + with dpg.window( + label="Control", + tag="_control_window", + width=600, + height=self.H, + pos=[self.W, 0], + no_move=True, + no_title_bar=True, + ): + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + # timer stuff + with dpg.group(horizontal=True): + dpg.add_text("Infer time: ") + dpg.add_text("no data", tag="_log_infer_time") + + # rendering options + with dpg.collapsing_header(label="Rendering", default_open=True): + # mode combo + def callback_change_mode(sender, app_data): + self.mode = app_data + self.need_update = True + + dpg.add_combo( + ("image", "alpha"), + label="mode", + default_value=self.mode, + callback=callback_change_mode, + ) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = np.deg2rad(app_data) + self.need_update = True + + dpg.add_slider_int( + label="FoV (vertical)", + min_value=1, + max_value=120, + format="%d deg", + default_value=np.rad2deg(self.cam.fovy), + callback=callback_set_fovy, + ) + + def callback_set_gaussain_scale(sender, app_data): + self.gaussain_scale_factor = app_data + self.need_update = True + + dpg.add_slider_float( + label="gaussain scale", + min_value=0, + max_value=1, + format="%.2f", + default_value=self.gaussain_scale_factor, + callback=callback_set_gaussain_scale, + ) + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.orbit(dx, dy) + self.need_update = True + + def callback_camera_wheel_scale(sender, app_data): + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + def callback_camera_drag_pan(sender, app_data): + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + with dpg.handler_registry(): + # for camera moving + dpg.add_mouse_drag_handler( + button=dpg.mvMouseButton_Left, + callback=callback_camera_drag_rotate, + ) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler( + button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan + ) + + dpg.create_viewport( + title="Gaussian3D", + width=self.W + 600, + height=self.H + (45 if os.name == "nt" else 0), + resizable=False, + ) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style( + dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core + ) + dpg.add_theme_style( + dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core + ) + dpg.add_theme_style( + dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core + ) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + + dpg.setup_dearpygui() + + ### register a larger font + # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf + if os.path.exists("LXGWWenKai-Regular.ttf"): + with dpg.font_registry(): + with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font: + dpg.bind_font(default_font) + + # dpg.show_metrics() + + dpg.show_viewport() + + def render(self): + while dpg.is_dearpygui_running(): + # update texture every frame + self.test_step() + dpg.render_dearpygui_frame() + + +opt = tyro.cli(AllConfigs) + +# load a saved ply and visualize +assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py' + +gui = GUI(opt) +gui.render() \ No newline at end of file diff --git a/extern/LGM/infer.py b/extern/LGM/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..205e4c5b1160c16867af51d91897df90c16c054d --- /dev/null +++ b/extern/LGM/infer.py @@ -0,0 +1,157 @@ + +import os +import tyro +import glob +import imageio +import numpy as np +import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from safetensors.torch import load_file +import rembg + +import kiui +from kiui.op import recenter +from kiui.cam import orbit_camera + +from core.options import AllConfigs, Options +from core.models import LGM +from mvdream.pipeline_mvdream import MVDreamPipeline + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +opt = tyro.cli(AllConfigs) + +# model +model = LGM(opt) + +# resume pretrained checkpoint +if opt.resume is not None: + if opt.resume.endswith('safetensors'): + ckpt = load_file(opt.resume, device='cpu') + else: + ckpt = torch.load(opt.resume, map_location='cpu') + model.load_state_dict(ckpt, strict=False) + print(f'[INFO] Loaded checkpoint from {opt.resume}') +else: + print(f'[WARN] model randomly initialized, are you sure?') + +# device +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = model.half().to(device) +model.eval() + +rays_embeddings = model.prepare_default_rays(device) + +tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) +proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) +proj_matrix[0, 0] = 1 / tan_half_fov +proj_matrix[1, 1] = 1 / tan_half_fov +proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) +proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) +proj_matrix[2, 3] = 1 + +# load image dream +pipe = MVDreamPipeline.from_pretrained( + "ashawkey/imagedream-ipmv-diffusers", # remote weights + torch_dtype=torch.float16, + trust_remote_code=True, + # local_files_only=True, +) +pipe = pipe.to(device) + +# load rembg +bg_remover = rembg.new_session() + +# process function +def process(opt: Options, path): + name = os.path.splitext(os.path.basename(path))[0] + print(f'[INFO] Processing {path} --> {name}') + os.makedirs(opt.workspace, exist_ok=True) + + input_image = kiui.read_image(path, mode='uint8') + + # bg removal + carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4] + mask = carved_image[..., -1] > 0 + + # recenter + image = recenter(carved_image, mask, border_ratio=0.2) + + # generate mv + image = image.astype(np.float32) / 255.0 + + # rgba to rgb white bg + if image.shape[-1] == 4: + image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) + + mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0) + mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32 + + # generate gaussians + input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] + input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False) + input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W] + + with torch.no_grad(): + with torch.autocast(device_type='cuda', dtype=torch.float16): + # generate gaussians + gaussians = model.forward_gaussians(input_image) + + # save gaussians + model.gs.save_ply(gaussians, os.path.join(opt.workspace, name + '.ply')) + + # render 360 video + images = [] + elevation = 0 + + if opt.fancy_video: + + azimuth = np.arange(0, 720, 4, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + scale = min(azi / 360, 1) + + image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + else: + azimuth = np.arange(0, 360, 2, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + + images = np.concatenate(images, axis=0) + imageio.mimwrite(os.path.join(opt.workspace, name + '.mp4'), images, fps=30) + + +assert opt.test_path is not None +if os.path.isdir(opt.test_path): + file_paths = glob.glob(os.path.join(opt.test_path, "*")) +else: + file_paths = [opt.test_path] +for path in file_paths: + process(opt, path) diff --git a/extern/LGM/main.py b/extern/LGM/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8d658e78efcf480da907aa7ead7735e8910fcaa8 --- /dev/null +++ b/extern/LGM/main.py @@ -0,0 +1,185 @@ +import tyro +import time +import random + +import torch +from core.options import AllConfigs +from core.models import LGM +from accelerate import Accelerator, DistributedDataParallelKwargs +from safetensors.torch import load_file + +import kiui + +def main(): + opt = tyro.cli(AllConfigs) + + # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + + accelerator = Accelerator( + mixed_precision=opt.mixed_precision, + gradient_accumulation_steps=opt.gradient_accumulation_steps, + # kwargs_handlers=[ddp_kwargs], + ) + + # model + model = LGM(opt) + + # resume + if opt.resume is not None: + if opt.resume.endswith('safetensors'): + ckpt = load_file(opt.resume, device='cpu') + else: + ckpt = torch.load(opt.resume, map_location='cpu') + + # tolerant load (only load matching shapes) + # model.load_state_dict(ckpt, strict=False) + state_dict = model.state_dict() + for k, v in ckpt.items(): + if k in state_dict: + if state_dict[k].shape == v.shape: + state_dict[k].copy_(v) + else: + accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.') + else: + accelerator.print(f'[WARN] unexpected param {k}: {v.shape}') + + # data + if opt.data_mode == 's3': + from core.provider_objaverse import ObjaverseDataset as Dataset + else: + raise NotImplementedError + + train_dataset = Dataset(opt, training=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=opt.batch_size, + shuffle=True, + num_workers=opt.num_workers, + pin_memory=True, + drop_last=True, + ) + + test_dataset = Dataset(opt, training=False) + test_dataloader = torch.utils.data.DataLoader( + test_dataset, + batch_size=opt.batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + drop_last=False, + ) + + # optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) + + # scheduler (per-iteration) + # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=1e-6) + total_steps = opt.num_epochs * len(train_dataloader) + pct_start = 3000 / total_steps + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start) + + # accelerate + model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare( + model, optimizer, train_dataloader, test_dataloader, scheduler + ) + + # loop + for epoch in range(opt.num_epochs): + # train + model.train() + total_loss = 0 + total_psnr = 0 + for i, data in enumerate(train_dataloader): + with accelerator.accumulate(model): + + optimizer.zero_grad() + + step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs + + out = model(data, step_ratio) + loss = out['loss'] + psnr = out['psnr'] + accelerator.backward(loss) + + # gradient clipping + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip) + + optimizer.step() + scheduler.step() + + total_loss += loss.detach() + total_psnr += psnr.detach() + + if accelerator.is_main_process: + # logging + if i % 100 == 0: + mem_free, mem_total = torch.cuda.mem_get_info() + print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f}") + + # save log images + if i % 500 == 0: + gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] + gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] + kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images) + + # gt_alphas = data['masks_output'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] + # gt_alphas = gt_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, gt_alphas.shape[1] * gt_alphas.shape[3], 1) + # kiui.write_image(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', gt_alphas) + + pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] + pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) + kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images) + + # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] + # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1) + # kiui.write_image(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', pred_alphas) + + total_loss = accelerator.gather_for_metrics(total_loss).mean() + total_psnr = accelerator.gather_for_metrics(total_psnr).mean() + if accelerator.is_main_process: + total_loss /= len(train_dataloader) + total_psnr /= len(train_dataloader) + accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}") + + # checkpoint + # if epoch % 10 == 0 or epoch == opt.num_epochs - 1: + accelerator.wait_for_everyone() + accelerator.save_model(model, opt.workspace) + + # eval + with torch.no_grad(): + model.eval() + total_psnr = 0 + for i, data in enumerate(test_dataloader): + + out = model(data) + + psnr = out['psnr'] + total_psnr += psnr.detach() + + # save some images + if accelerator.is_main_process: + gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] + gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] + kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images) + + pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] + pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) + kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images) + + # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] + # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1) + # kiui.write_image(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', pred_alphas) + + torch.cuda.empty_cache() + + total_psnr = accelerator.gather_for_metrics(total_psnr).mean() + if accelerator.is_main_process: + total_psnr /= len(test_dataloader) + accelerator.print(f"[eval] epoch: {epoch} psnr: {psnr:.4f}") + + + +if __name__ == "__main__": + main() diff --git a/extern/LGM/mvdream/mv_unet.py b/extern/LGM/mvdream/mv_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9ad4def5910394eb64b36f9f76c98e8eaf80ae --- /dev/null +++ b/extern/LGM/mvdream/mv_unet.py @@ -0,0 +1,1005 @@ +import math +import numpy as np +from inspect import isfunction +from typing import Optional, Any, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.modeling_utils import ModelMixin + +# require xformers! +import xformers +import xformers.ops + +from kiui.cam import orbit_camera + +def get_camera( + num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False, +): + angle_gap = azimuth_span / num_frames + cameras = [] + for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): + + pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4] + + # opengl to blender + if blender_coord: + pose[2] *= -1 + pose[[1, 2]] = pose[[2, 1]] + + cameras.append(pose.flatten()) + + if extra_view: + cameras.append(np.zeros_like(cameras[0])) + + return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16] + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + # import pdb; pdb.set_trace() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def default(val, d): + if val is not None: + return val + return d() if isfunction(d) else d + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.ip_dim = ip_dim + self.ip_weight = ip_weight + + if self.ip_dim > 0: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None): + q = self.to_q(x) + context = default(context, x) + + if self.ip_dim > 0: + # context: [B, 77 + 16(ip), 1024] + token_len = context.shape[1] + context_ip = context[:, -self.ip_dim :, :] + k_ip = self.to_k_ip(context_ip) + v_ip = self.to_v_ip(context_ip) + context = context[:, : (token_len - self.ip_dim), :] + + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + if self.ip_dim > 0: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + # actually compute the attention, what we cannot get enough of + out_ip = xformers.ops.memory_efficient_attention( + q, k_ip, v_ip, attn_bias=None, op=self.attention_op + ) + out = out + self.ip_weight * out_ip + + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock3D(nn.Module): + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim, + dropout=0.0, + gated_ff=True, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + self.attn1 = MemoryEfficientCrossAttention( + query_dim=dim, + context_dim=None, # self-attention + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = MemoryEfficientCrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + # ip only applies to cross-attention + ip_dim=ip_dim, + ip_weight=ip_weight, + ) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + + def forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = self.attn1(self.norm1(x), context=None) + x + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + + def __init__( + self, + in_channels, + n_heads, + d_head, + context_dim, # cross attention input dim + depth=1, + dropout=0.0, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + if not isinstance(context_dim, list): + context_dim = [context_dim] + + self.in_channels = in_channels + + inner_dim = n_heads * d_head + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock3D( + inner_dim, + n_heads, + d_head, + context_dim=context_dim[d], + dropout=dropout, + ip_dim=ip_dim, + ip_weight=ip_weight, + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + + return x + x_in + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q, k, v = map( + lambda t: t.reshape(b, t.shape[1], self.heads, -1) + .transpose(1, 2) + .reshape(b, self.heads, t.shape[1], -1) + .contiguous(), + (q, k, v), + ) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * ff_mult, bias=False), + nn.GELU(), + nn.Linear(dim * ff_mult, dim, bias=False), + ) + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class CondSequential(nn.Sequential): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, num_frames=1): + for layer in self: + if isinstance(layer, ResBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer3D): + x = layer(x, context, num_frames=num_frames) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class MultiViewUNetModel(ModelMixin, ConfigMixin): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + transformer_depth=1, + context_dim=None, + n_embed=None, + num_attention_blocks=None, + adm_in_channels=None, + camera_dim=None, + ip_dim=0, # imagedream uses ip_dim > 0 + ip_weight=1.0, + **kwargs, + ): + super().__init__() + assert context_dim is not None + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + self.ip_dim = ip_dim + self.ip_weight = ip_weight + + if self.ip_dim > 0: + self.image_embed = Resampler( + dim=context_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=ip_dim, # num token + embedding_dim=1280, + output_dim=context_dim, + ff_mult=4, + ) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if camera_dim is not None: + time_embed_dim = model_channels * 4 + self.camera_embed = nn.Sequential( + nn.Linear(camera_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) + elif self.num_classes == "continuous": + # print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + CondSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers: List[Any] = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if num_attention_blocks is None or nr < num_attention_blocks[level]: + layers.append( + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ) + ) + self.input_blocks.append(CondSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + CondSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = CondSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if num_attention_blocks is None or i < num_attention_blocks[level]: + layers.append( + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(CondSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + nn.GroupNorm(32, ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + nn.GroupNorm(32, ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + ip=None, + ip_img=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :param num_frames: a integer indicating number of frames for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "input batch size must be dividable by num_frames!" + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y is not None + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + emb = emb + self.camera_embed(camera) + + # imagedream variant + if self.ip_dim > 0: + x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9] + ip_emb = self.image_embed(ip) + context = torch.cat((context, ip_emb), 1) + + h = x + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) \ No newline at end of file diff --git a/extern/LGM/mvdream/pipeline_mvdream.py b/extern/LGM/mvdream/pipeline_mvdream.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7b3a72558aa03f77a1626e26d4f87fd830dd90 --- /dev/null +++ b/extern/LGM/mvdream/pipeline_mvdream.py @@ -0,0 +1,559 @@ +import torch +import torch.nn.functional as F +import inspect +import numpy as np +from typing import Callable, List, Optional, Union +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor +from diffusers import AutoencoderKL, DiffusionPipeline +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, +) +from diffusers.configuration_utils import FrozenDict +from diffusers.schedulers import DDIMScheduler +from diffusers.utils.torch_utils import randn_tensor + +from mvdream.mv_unet import MultiViewUNetModel, get_camera + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MVDreamPipeline(DiffusionPipeline): + + _optional_components = ["feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + unet: MultiViewUNetModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + scheduler: DDIMScheduler, + # imagedream variant + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + requires_safety_checker: bool = False, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + unet=unet, + scheduler=scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError( + "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_offload` requires `accelerate v0.17.0` or higher." + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook + ) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance: bool, + negative_prompt=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError( + f"`prompt` should be either a string or a list of strings, but got {type(prompt)}." + ) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if image.dtype == np.float32: + image = (image * 255).astype(np.uint8) + + image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + return torch.zeros_like(image_embeds), image_embeds + + def encode_image_latents(self, image, device, num_images_per_prompt): + + dtype = next(self.image_encoder.parameters()).dtype + + image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W] + image = 2 * image - 1 + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image = image.to(dtype=dtype) + + posterior = self.vae.encode(image).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + return torch.zeros_like(latents), latents + + @torch.no_grad() + def __call__( + self, + prompt: str = "", + image: Optional[np.ndarray] = None, + height: int = 256, + width: int = 256, + elevation: float = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.0, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "numpy", # pil, numpy, latents + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + num_frames: int = 4, + device=torch.device("cuda:0"), + ): + self.unet = self.unet.to(device=device) + self.vae = self.vae.to(device=device) + self.text_encoder = self.text_encoder.to(device=device) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # imagedream variant + if image is not None: + assert isinstance(image, np.ndarray) and image.dtype == np.float32 + self.image_encoder = self.image_encoder.to(device=device) + image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt) + image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt) + + _prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + ) # type: ignore + prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2) + + # Prepare latent variables + actual_num_frames = num_frames if image is None else num_frames + 1 + latents: torch.Tensor = self.prepare_latents( + actual_num_frames * num_images_per_prompt, + 4, + height, + width, + prompt_embeds_pos.dtype, + device, + generator, + None, + ) + + if image is not None: + camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device) + else: + camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device) + camera = camera.repeat_interleave(num_images_per_prompt, dim=0) + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + multiplier = 2 if do_classifier_free_guidance else 1 + latent_model_input = torch.cat([latents] * multiplier) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + unet_inputs = { + 'x': latent_model_input, + 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device), + 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames), + 'num_frames': actual_num_frames, + 'camera': torch.cat([camera] * multiplier), + } + + if image is not None: + unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames) + unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat + + # predict the noise residual + noise_pred = self.unet.forward(**unet_inputs) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents: torch.Tensor = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # type: ignore + + # Post-processing + if output_type == "latent": + image = latents + elif output_type == "pil": + image = self.decode_latents(latents) + image = self.numpy_to_pil(image) + else: # numpy + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image \ No newline at end of file diff --git a/extern/LGM/readme.md b/extern/LGM/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..e29cb2bfa8c9461130365d9ac6e68a2c631c6887 --- /dev/null +++ b/extern/LGM/readme.md @@ -0,0 +1,108 @@ + +## Large Multi-View Gaussian Model + +This is the official implementation of *LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation*. + +### [Project Page](https://me.kiui.moe/lgm/) | [Arxiv](https://arxiv.org/abs/2402.05054) | [Weights](https://huggingface.co./ashawkey/LGM) | + +https://github.com/3DTopia/LGM/assets/25863658/cf64e489-29f3-4935-adba-e393a24c26e8 + +### News +[2024.4.3] Thanks to [@yxymessi](https://github.com/yxymessi) and [@florinshen](https://github.com/florinshen), we have fixed a **severe bug in rotation normalization** [here](https://github.com/3DTopia/LGM/commit/9a0797cdbacf8e6216d0108cb00cbe43b9cb3d81). We have finetuned the model with correct normalization for 30 more epochs and uploaded new checkpoints. + +### Replicate Demo: +* gaussians: [demo](https://replicate.com/camenduru/lgm) | [code](https://github.com/camenduru/LGM-replicate) +* mesh: [demo](https://replicate.com/camenduru/lgm-ply-to-glb) | [code](https://github.com/camenduru/LGM-ply-to-glb-replicate) + +Thanks to [@camenduru](https://github.com/camenduru)! + +### Install + +```bash +# xformers is required! please refer to https://github.com/facebookresearch/xformers for details. +# for example, we use torch 2.1.0 + cuda 11.8 +pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118 +pip install -U xformers --index-url https://download.pytorch.org/whl/cu118 + +# a modified gaussian splatting (+ depth, alpha rendering) +git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization +pip install ./diff-gaussian-rasterization + +# for mesh extraction +pip install git+https://github.com/NVlabs/nvdiffrast + +# other dependencies +pip install -r requirements.txt +``` + +### Pretrained Weights + +Our pretrained weight can be downloaded from [huggingface](https://huggingface.co./ashawkey/LGM). + +For example, to download the fp16 model for inference: +```bash +mkdir pretrained && cd pretrained +wget https://huggingface.co./ashawkey/LGM/resolve/main/model_fp16_fixrot.safetensors +cd .. +``` + +For [MVDream](https://github.com/bytedance/MVDream) and [ImageDream](https://github.com/bytedance/ImageDream), we use a [diffusers implementation](https://github.com/ashawkey/mvdream_diffusers). +Their weights will be downloaded automatically. + +### Inference + +Inference takes about 10GB GPU memory (loading all imagedream, mvdream, and our LGM). + +```bash +### gradio app for both text/image to 3D +python app.py big --resume pretrained/model_fp16.safetensors + +### test +# --workspace: folder to save output (*.ply and *.mp4) +# --test_path: path to a folder containing images, or a single image +python infer.py big --resume pretrained/model_fp16.safetensors --workspace workspace_test --test_path data_test + +### local gui to visualize saved ply +python gui.py big --output_size 800 --test_path workspace_test/saved.ply + +### mesh conversion +python convert.py big --test_path workspace_test/saved.ply +``` + +For more options, please check [options](./core/options.py). + +### Training + +**NOTE**: +Since the dataset used in our training is based on AWS, it cannot be directly used for training in a new environment. +We provide the necessary training code framework, please check and modify the [dataset](./core/provider_objaverse.py) implementation! + +We also provide the **~80K subset of [Objaverse](https://objaverse.allenai.org/objaverse-1.0)** used to train LGM in [objaverse_filter](https://github.com/ashawkey/objaverse_filter). + +```bash +# debug training +accelerate launch --config_file acc_configs/gpu1.yaml main.py big --workspace workspace_debug + +# training (use slurm for multi-nodes training) +accelerate launch --config_file acc_configs/gpu8.yaml main.py big --workspace workspace +``` + +### Acknowledgement + +This work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing! + +- [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization) +- [nvdiffrast](https://github.com/NVlabs/nvdiffrast) +- [dearpygui](https://github.com/hoffstadt/DearPyGui) +- [tyro](https://github.com/brentyi/tyro) + +### Citation + +``` +@article{tang2024lgm, + title={LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation}, + author={Tang, Jiaxiang and Chen, Zhaoxi and Chen, Xiaokang and Wang, Tengfei and Zeng, Gang and Liu, Ziwei}, + journal={arXiv preprint arXiv:2402.05054}, + year={2024} +} +``` diff --git a/extern/LGM/requirements.txt b/extern/LGM/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ade4e49c6b22e2c0ecdca2e0b8ac507a94cd634d --- /dev/null +++ b/extern/LGM/requirements.txt @@ -0,0 +1,28 @@ +torch +numpy +tyro +diffusers +dearpygui +einops +accelerate +gradio +imageio +imageio-ffmpeg +lpips +matplotlib +packaging +Pillow +pygltflib +rembg[gpu,cli] +rich +safetensors +scikit-image +scikit-learn +scipy +tqdm +transformers +trimesh +kiui >= 0.2.3 +xatlas +roma +plyfile diff --git a/extern/LGM/scripts/convert_all.py b/extern/LGM/scripts/convert_all.py new file mode 100644 index 0000000000000000000000000000000000000000..163ae27a414a4925d8f8b7b2b55a8c3571f04792 --- /dev/null +++ b/extern/LGM/scripts/convert_all.py @@ -0,0 +1,15 @@ +import os +import glob +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('dir', default='workspace', type=str) +parser.add_argument('--gpu', default=0, type=int, help='ID of GPU to use') +args = parser.parse_args() + +files = glob.glob(f'{args.dir}/*.ply') + +for file in files: + name = file.replace('.ply', '') + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python convert.py big --test_path {file}') + # os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} kire {name}.glb --save_video {name}_mesh.mp4 --wogui') \ No newline at end of file diff --git a/extern/LGM/scripts/examples.sh b/extern/LGM/scripts/examples.sh new file mode 100644 index 0000000000000000000000000000000000000000..d68897eda043eb73cc87199fe560919bf7ead42f --- /dev/null +++ b/extern/LGM/scripts/examples.sh @@ -0,0 +1,17 @@ +# debug training +accelerate launch --config_file acc_configs/gpu1.yaml main.py big --workspace workspace_debug + +# training (should use slurm) +accelerate launch --config_file acc_configs/gpu8.yaml main.py big --workspace workspace + +# test +python infer.py big --workspace workspace_test --resume workspace/model.safetensors --test_path data_test + +# gradio app +python app.py big --resume workspace/model.safetensors + +# local gui +python gui.py big --output_size 800 --test_path workspace_test/anya_rgba.ply + +# mesh conversion +python convert.py big --test_path workspace_test/anya_rgba.ply \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a6f4d75f1a266bea6ebe8fe4a4ccb7ff434e2030 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,49 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.1.2 +xformers + +mmcv==1.7.0 +git+https://github.com/huggingface/diffusers +timm==0.6.12 +accelerate +tensorboard +tensorboardX +transformers +sentencepiece~=0.1.99 +protobuf==3.20.2 +yapf==0.40.1 +peft==0.6.2 +ftfy +beautifulsoup4 +gradio==4.20.1 +opencv-python +bs4 +einops +optimum +came-pytorch +wandb +scipy +roma +kornia +numpy +tyro +dearpygui +imageio +imageio-ffmpeg +lpips +matplotlib +packaging +Pillow +pygltflib +rembg[gpu,cli] +rich +safetensors +scikit-image +scikit-learn +tqdm +trimesh +kiui >= 0.2.3 +xatlas +plyfile +gradio_imageslider +https://huggingface.co./spaces/ashawkey/LGM/resolve/main/wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true \ No newline at end of file diff --git a/src/configs/config.py b/src/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..086f4e462db40c3bfadc71d3a250253f76420d62 --- /dev/null +++ b/src/configs/config.py @@ -0,0 +1,29 @@ +image_size = 512 + +# model setting +model = 'PixArtMS_XL_2' +use_crossview_module = True + +mixed_precision = 'bf16' # ['fp16', 'fp32', 'bf16'] +fp32_attention = True +pipeline_load_from = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" +aspect_ratio_type = 'ASPECT_RATIO_512' +pe_interpolation = 1.0 + +# pixart-sigma +scale_factor = 0.13025 +model_max_length = 300 +kv_compress = False +kv_compress_config = { + 'sampling': 'conv', # ['conv', 'uniform', 'ave'] + 'scale_factor': 2, + 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], +} +qk_norm = False +micro_condition = False + +# controlnet +copy_blocks_num = 13 + +# diffusion sampling +train_sampling_steps = 1000 \ No newline at end of file diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54508d56cbff959d571efbc6c1a8041e0ace22a1 --- /dev/null +++ b/src/diffusion/__init__.py @@ -0,0 +1,8 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from .iddpm import IDDPM +from .dpm_solver import DPMS +from .sa_sampler import SASolverSampler diff --git a/src/diffusion/dpm_solver.py b/src/diffusion/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..232449c276973b75c7c9c93f8904bf017a42ac39 --- /dev/null +++ b/src/diffusion/dpm_solver.py @@ -0,0 +1,36 @@ +import torch +from .model import gaussian_diffusion as gd +from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP + + +def DPMS( + model, + condition, + uncondition, + cfg_scale, + model_type='noise', # or "x_start" or "v" or "score" + noise_schedule="linear", + guidance_type='classifier-free', + model_kwargs={}, + diffusion_steps=1000 +): + betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) + + ## 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas) + + ## 2. Convert your discrete-time `model` to the continuous-time + ## noise prediction model. Here is an example for a diffusion model + ## `model` with the noise prediction type ("noise") . + model_fn = model_wrapper( + model, + noise_schedule, + model_type=model_type, + model_kwargs=model_kwargs, + guidance_type=guidance_type, + condition=condition, + unconditional_condition=uncondition, + guidance_scale=cfg_scale, + ) + ## 3. Define dpm-solver and sample by multistep DPM-Solver. + return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") \ No newline at end of file diff --git a/src/diffusion/iddpm.py b/src/diffusion/iddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..c9459f4c807d2318a699392d51a86bc10bbe318f --- /dev/null +++ b/src/diffusion/iddpm.py @@ -0,0 +1,53 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +from diffusion.model.respace import SpacedDiffusion, space_timesteps +from .model import gaussian_diffusion as gd + + +def IDDPM( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + pred_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + snr=False, + return_startx=False, +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + (( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ) + if pred_sigma + else None + ), + loss_type=loss_type, + snr=snr, + return_startx=return_startx, + # rescale_timesteps=rescale_timesteps, + ) \ No newline at end of file diff --git a/src/diffusion/lcm_scheduler.py b/src/diffusion/lcm_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..9d69dcedcc33ec494b3b152b8922c2cb3f976bc9 --- /dev/null +++ b/src/diffusion/lcm_scheduler.py @@ -0,0 +1,459 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers import ConfigMixin, SchedulerMixin +from diffusers.configuration_utils import register_to_config +from diffusers.utils import BaseOutput + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class LCMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + denoised: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class LCMScheduler(SchedulerMixin, ConfigMixin): + """ + `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co./papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + # _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # LCM Timesteps Setting: # Linear Spacing + c = self.config.num_train_timesteps // lcm_origin_steps + lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device) + + def get_scalings_for_boundary_condition_discrete(self, t): + self.sigma_data = 0.5 # Default: 0.5 + + # By dividing 0.1: This is almost a delta function at t=0. + c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2) + c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5) + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timeindex: int, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[LCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # 1. get previous step value + prev_timeindex = timeindex + 1 + if prev_timeindex < len(self.timesteps): + prev_timestep = self.timesteps[prev_timeindex] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Different Parameterization: + parameterization = self.config.prediction_type + + if parameterization == "epsilon": # noise-prediction + pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + + elif parameterization == "sample": # x-prediction + pred_x0 = model_output + + elif parameterization == "v_prediction": # v-prediction + pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + + # 4. Denoise model output using boundary conditions + denoised = c_out * pred_x0 + c_skip * sample + + # 5. Sample z ~ N(0, I), For MultiStep Inference + # Noise is not used for one-step sampling. + if len(self.timesteps) > 1: + noise = torch.randn(model_output.shape).to(model_output.device) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + if not return_dict: + return (prev_sample, denoised) + + return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + diff --git a/src/diffusion/model/__init__.py b/src/diffusion/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0d2755ad3a52d9d304763f96bf8b8c13dbbe76 --- /dev/null +++ b/src/diffusion/model/__init__.py @@ -0,0 +1 @@ +from .nets import * diff --git a/src/diffusion/model/builder.py b/src/diffusion/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..22821d03ef3410325885ad89a289c8ba1c032bac --- /dev/null +++ b/src/diffusion/model/builder.py @@ -0,0 +1,14 @@ +from mmcv import Registry + +from diffusion.model.utils import set_grad_checkpoint + +MODELS = Registry('models') + + +def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs): + if isinstance(cfg, str): + cfg = dict(type=cfg) + model = MODELS.build(cfg, default_args=kwargs) + if use_grad_checkpoint: + set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step) + return model diff --git a/src/diffusion/model/diffusion_utils.py b/src/diffusion/model/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cedd4fa2433f32c34df1157839b423ecb444e403 --- /dev/null +++ b/src/diffusion/model/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/diffusion/model/dpm_solver.py b/src/diffusion/model/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..bd738f0100939c756f72509e2dc79bcd2d08993b --- /dev/null +++ b/src/diffusion/model/dpm_solver.py @@ -0,0 +1,1337 @@ +import torch +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1. + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1. + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, disable_progress_ui=False + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', + 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', + 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, + solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), disable=disable_progress_ui): + t = timesteps[step] + # We only use lower order for steps < 10 + # if lower_order_final and steps < 10: + if lower_order_final: # recommended by Shuchen Xue + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, + solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, + order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, + device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/src/diffusion/model/edm_sample.py b/src/diffusion/model/edm_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f930bd3be0e533f3829d83813479763229941881 --- /dev/null +++ b/src/diffusion/model/edm_sample.py @@ -0,0 +1,171 @@ +import random +import numpy as np +from tqdm import tqdm + +from diffusion.model.utils import * + + +# ---------------------------------------------------------------------------- +# Proposed EDM sampler (Algorithm 2). + +def edm_sampler( + net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs +): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( + sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next + + +# ---------------------------------------------------------------------------- +# Generalized ablation sampler, representing the superset of all sampling +# methods discussed in the paper. + +def ablation_sampler( + net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=None, sigma_max=None, rho=7, + solver='heun', discretization='edm', schedule='linear', scaling='none', + epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, +): + assert solver in ['euler', 'heun'] + assert discretization in ['vp', 've', 'iddpm', 'edm'] + assert schedule in ['vp', 've', 'linear'] + assert scaling in ['vp', 'none'] + + # Helper functions for VP & VE noise level schedules. + vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 + vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) + vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * ( + sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma ** 2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) + sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) + sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) + vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == 'vp': + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == 've': + orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == 'iddpm': + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] + else: + assert discretization == 'edm' + sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( + sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + + # Define noise level schedule. + if schedule == 'vp': + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == 've': + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + assert schedule == 'linear' + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == 'vp': + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + assert scaling == 'none' + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s( + t_hat) * S_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to( + torch.float64) + d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s( + t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == 'euler' or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + assert solver == 'heun' + denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to( + torch.float64) + d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv( + t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) + + return x_next diff --git a/src/diffusion/model/gaussian_diffusion.py b/src/diffusion/model/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1c0ccdb5a133a44fb968dddefe64f4145fd37e --- /dev/null +++ b/src/diffusion/model/gaussian_diffusion.py @@ -0,0 +1,1048 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import enum +import math + +import numpy as np +import torch as th +import torch.nn.functional as F + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + snr=False, + return_startx=False, + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.snr = snr + self.return_startx = return_startx + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + + if False: + target_resolution = 128 # 1024:128; 512:64; 256:32; + reference_resolution = 64 # Reference resolution (e.g., 64x64) + scaling_factor = (target_resolution / reference_resolution) ** 2 + print('scaling_factor', scaling_factor) + + # Adjust alphas and betas according to the scaling factor + alpha_cumprod_snr_shift = self.alphas_cumprod / (scaling_factor * (1 - self.alphas_cumprod) + self.alphas_cumprod) + alpha_cuspord_rmove1 = np.concatenate([np.ones([1]), alpha_cumprod_snr_shift[:999]]) + alpha_snr_shift = alpha_cumprod_snr_shift / alpha_cuspord_rmove1 + + betas_snr_shift = 1 - alpha_snr_shift + + # Update the class attributes with adjusted values + snr_ref = (self.alphas_cumprod / (1 - self.alphas_cumprod)) + snr_cur = (alpha_cumprod_snr_shift / (1 - alpha_cumprod_snr_shift)) + + self.betas = betas_snr_shift + self.alphas_cumprod = np.cumprod(alpha_snr_shift, axis=0) + + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + else: + model_variance = th.zeros_like(model_output) + model_log_variance = th.zeros_like(model_output) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False, loss_weight_mask=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + t = timestep + if model_kwargs is None: + model_kwargs = {} + if skip_noise: + x_t = x_start + else: + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + if isinstance(model_output, dict) and model_output.get('x', None) is not None: + output = model_output['x'] + else: + output = model_output + + if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output = th.split(output, C, dim=1)[0] + return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output, model_var_values = th.split(output, C, dim=1) + # Learn the variance using the variational bound, but don't let it affect our mean prediction. + frozen_out = th.cat([output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out, **kwargs: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert output.shape == target.shape == x_start.shape + if self.snr: + if self.model_mean_type == ModelMeanType.START_X: + pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output) + pred_startx = output + elif self.model_mean_type == ModelMeanType.EPSILON: + pred_noise = output + pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output) + # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2) + # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2) + + t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32] + # best + target = th.where(t > 249, noise, x_start) + output = th.where(t > 249, pred_noise, pred_startx) + loss = (target - output) ** 2 + + # Changed by Yihang: for Face SR weighted loss: + if loss_weight_mask is not None: + weights = th.ones_like(loss) + weights[loss_weight_mask.expand_as(loss) == 1] = 5 + loss *= weights + + if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0: + assert 'mask' in model_output + loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1) + mask = model_output['mask'] + unmask = 1 - mask + terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1) + if model_kwargs['mask_loss_coef'] > 0: + terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1) + else: + terms["mse"] = mean_flat(loss) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + if "mae" in terms: + terms["loss"] = terms["loss"] + terms["mae"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + t = timestep + if model_kwargs is None: + model_kwargs = {} + if skip_noise: + x_t = x_start + else: + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0] + + if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output = th.split(output, C, dim=1)[0] + return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert output.shape == (B, C * 2, *x_t.shape[2:]) + output, model_var_values = th.split(output, C, dim=1) + # Learn the variance using the variational bound, but don't let it affect our mean prediction. + frozen_out = th.cat([output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out, **kwargs: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert output.shape == target.shape == x_start.shape + if self.snr: + if self.model_mean_type == ModelMeanType.START_X: + pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output) + pred_startx = output + elif self.model_mean_type == ModelMeanType.EPSILON: + pred_noise = output + pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output) + # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2) + # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2) + + t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32] + # best + target = th.where(t > 249, noise, x_start) + output = th.where(t > 249, pred_noise, pred_startx) + loss = (target - output) ** 2 + terms["mse"] = mean_flat(loss) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + if "mae" in terms: + terms["loss"] = terms["loss"] + terms["mae"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/diffusion/model/nets/PixArt.py b/src/diffusion/model/nets/PixArt.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4249e6026334b463f4834bc2a94eebcdf7aaef --- /dev/null +++ b/src/diffusion/model/nets/PixArt.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math +import torch +import torch.nn as nn +import os +import numpy as np +from timm.models.layers import DropPath +from timm.models.vision_transformer import PatchEmbed, Mlp + +from diffusion.model.builder import MODELS +from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, + sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + self.sampling = sampling + self.sr_ratio = sr_ratio + + def forward(self, x, y, t, mask=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +############################################################################# +# Core PixArt Model # +################################################################################# +@MODELS.register_module() +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=1.0, + config=None, + model_max_length=120, + qk_norm=False, + kv_compress_config=None, + **kwargs, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.depth = depth + self.hidden_size = hidden_size + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + num_patches = self.x_embedder.num_patches + self.base_size = input_size // self.patch_size + # Will use fixed sin-cos embedding: + self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length) + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.kv_compress_config = kv_compress_config + if kv_compress_config is None: + self.kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + sampling=self.kv_compress_config['sampling'], + sr_ratio=int( + self.kv_compress_config['scale_factor'] + ) if i in self.kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + self.initialize_weights() + + + def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + pos_embed = self.pos_embed.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep.to(x.dtype)) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, mask) + return model_out.chunk(2, dim=1)[0] + + def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): + """ + Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, y, mask, kwargs) + model_out = model_out['x'] if isinstance(model_out, dict) else model_out + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), + pe_interpolation=self.pe_interpolation, base_size=self.base_size + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + @property + def dtype(self): + return next(self.parameters()).dtype + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = to_2tuple(grid_size) + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# PixArt Configs # +################################################################################# +@MODELS.register_module() +def PixArt_XL_2(**kwargs): + return PixArt(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) diff --git a/src/diffusion/model/nets/PixArtMS.py b/src/diffusion/model/nets/PixArtMS.py new file mode 100644 index 0000000000000000000000000000000000000000..0fdd1f8b4876ae365d178d17753a9ce12d964d3c --- /dev/null +++ b/src/diffusion/model/nets/PixArtMS.py @@ -0,0 +1,386 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import torch +import torch.nn as nn +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +from diffusion.model.builder import MODELS +from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, SizeEmbedder +from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed + +from torch.nn import Module, Linear, init + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class PixArtMSBlock(nn.Module): + """ + A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None, + sampling=None, sr_ratio=1, qk_norm=False, use_crossview_module=False, **block_kwargs): + super().__init__() + self.hidden_size = hidden_size + self.use_crossview_module = use_crossview_module + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # if use row attention, keep original self attention and add an extra attention layer to do row attention, + # otherwise, change orginal self attention to global self attention + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, **block_kwargs + ) + if self.use_crossview_module: + self.row_attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, use_crossview_module=True, **block_kwargs + ) + self.norm_row_attn = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + if self.use_crossview_module: + self.scale_shift_row_attn_table = nn.Parameter(torch.randn(3, hidden_size) / hidden_size ** 0.5) + + def forward(self, x, y, t, mask=None, HW=None, qkv_list=None, epipolar_constrains=None, cam_distances=None, n_views=None, **kwargs): + B, N, C = x.shape + + if qkv_list is not None: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW, qkv_cond=qkv_list)) + if self.use_crossview_module: + shift_ra, scale_ra, gate_ra = (self.scale_shift_row_attn_table[None] + t.reshape(B, 3, 2, -1).sum(dim=2)).chunk(3, dim=1) + x = x + self.drop_path(gate_ra * self.row_attn.forward_with_cross_view_optimized(t2i_modulate(self.norm_row_attn(x), shift_ra, scale_ra), HW=HW, qkv_cond=qkv_list, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views)) + x = x + auto_grad_checkpoint(self.cross_attn, x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) + if self.use_crossview_module: + shift_ra, scale_ra, gate_ra = (self.scale_shift_row_attn_table[None] + t.reshape(B, 3, 2, -1).sum(dim=2)).chunk(3, dim=1) + x = x + self.drop_path(gate_ra * self.row_attn.forward_with_cross_view_optimized(t2i_modulate(self.norm_row_attn(x), shift_ra, scale_ra), HW=HW, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +class PixArtMSBlockControl(nn.Module): + """ + A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, sr_ratio=2, **block_kwargs): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionKVCompress(hidden_size, num_heads=num_heads, qkv_bias=True, + sr_ratio=sr_ratio, return_qkv=True, **block_kwargs) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.window_size = window_size + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + + def forward(self, x, y, t, mask=None, HW=None, **kwargs): + B, N, C = x.shape + H, W = HW + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x, qkv_list = self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW) + x = x + self.drop_path(gate_msa * x) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x, qkv_list + + +class ControlSRDitBlockHalf(Module): + def __init__(self, base_block: PixArtMSBlockControl, block_index: 0) -> None: + super().__init__() + self.copied_block = base_block + self.block_index = block_index + + self.hidden_size = hidden_size = base_block.hidden_size + if self.block_index == 0: + self.before_proj = Linear(hidden_size, hidden_size) + init.zeros_(self.before_proj.weight) + init.zeros_(self.before_proj.bias) + self.after_proj = Linear(hidden_size, hidden_size) + init.zeros_(self.after_proj.weight) + init.zeros_(self.after_proj.bias) + + def forward(self, x, y, t, mask=None, c=None, HW=None): + if self.block_index == 0: + # the first block + c = self.before_proj(c) + c, qkv_list = self.copied_block(x + c, y, t, mask, HW=HW) + c_skip = self.after_proj(c) + else: + # load from previous c and produce the c for skip connection + c, qkv_list = self.copied_block(c, y, t, mask, HW=HW) + c_skip = self.after_proj(c) + + return c, c_skip, qkv_list + + +############################################################################# +# Core PixArt Model # +################################################################################# +@MODELS.register_module() +class PixArtMS(PixArt): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=1., + config=None, + model_max_length=120, + micro_condition=False, + qk_norm=False, + kv_compress_config=None, + use_crossview_module=False, + **kwargs, + ): + super().__init__( + input_size=input_size, + patch_size=patch_size, + in_channels=in_channels, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + class_dropout_prob=class_dropout_prob, + learn_sigma=learn_sigma, + pred_sigma=pred_sigma, + drop_path=drop_path, + pe_interpolation=pe_interpolation, + config=config, + model_max_length=model_max_length, + qk_norm=qk_norm, + kv_compress_config=kv_compress_config, + **kwargs, + ) + self.h = self.w = 0 + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True) + self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length) + self.micro_conditioning = micro_condition + if self.micro_conditioning: + self.csize_embedder = SizeEmbedder(hidden_size//3) # c_size embed + self.ar_embedder = SizeEmbedder(hidden_size//3) # aspect ratio embed + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + if kv_compress_config is None: + kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtMSBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + sampling=kv_compress_config['sampling'], + sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + use_crossview_module=use_crossview_module, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + self.initialize() + + def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + bs = x.shape[0] + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + pos_embed = torch.from_numpy( + get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation, + base_size=self.base_size + ) + ).unsqueeze(0).to(x.device).to(self.dtype) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep) # (N, D) + + if self.micro_conditioning: + c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype) + csize = self.csize_embedder(c_size, bs) # (N, D) + ar = self.ar_embedder(ar, bs) # (N, D) + t = t + torch.cat([csize, ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs) # (N, T, D) #support grad checkpoint + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + return x + + def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs) + return model_out.chunk(2, dim=1)[0] + + def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, mask=None, **kwargs): + """ + Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, y, mask, data_info=data_info, **kwargs) + model_out = model_out['x'] if isinstance(model_out, dict) else model_out + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + assert self.h * self.w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p)) + return imgs + + def initialize(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + if self.micro_conditioning: + nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + if hasattr(block, 'row_attn'): + nn.init.constant_(block.row_attn.proj.weight, 0) + nn.init.constant_(block.row_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +################################################################################# +# PixArt Configs # +################################################################################# +@MODELS.register_module() +def PixArtMS_XL_2(**kwargs): + return PixArtMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) \ No newline at end of file diff --git a/src/diffusion/model/nets/PixArt_blocks.py b/src/diffusion/model/nets/PixArt_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..203a53f98061c721e291d3aaacc4bade46dd3e82 --- /dev/null +++ b/src/diffusion/model/nets/PixArt_blocks.py @@ -0,0 +1,628 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import xformers.ops +from einops import rearrange +from timm.models.vision_transformer import Mlp, Attention as Attention_ + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +def batch_cosine_sim(x, y): + if type(x) is list: + x = torch.cat(x, dim=0) + if type(y) is list: + y = torch.cat(y, dim=0) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + + y = rearrange(y, "b n c -> b c n") + + similarity = x @ y + return similarity + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model*2) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(d_model, d_model) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class AttentionKVCompress(Attention_): + """Multi-head Attention block with KV token compression and qk norm.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + sampling='conv', + sr_ratio=1, + qk_norm=False, + return_qkv=False, + use_crossview_module=False, + **block_kwargs, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + """ + super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs) + + self.sampling = sampling # ['conv', 'ave', 'uniform', 'uniform_every'] + self.sr_ratio = sr_ratio + self.return_qkv = return_qkv + self.use_crossview_module = use_crossview_module + + if sr_ratio > 1 and sampling == 'conv': + # Avg Conv Init. + self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio) + self.sr.weight.data.fill_(1/sr_ratio**2) + self.sr.bias.data.zero_() + self.norm = nn.LayerNorm(dim) + if qk_norm: + self.q_norm = nn.LayerNorm(dim) + self.k_norm = nn.LayerNorm(dim) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + self.key_frames_dict = dict() + + def downsample_2d(self, tensor, H, W, scale_factor, sampling=None): + if sampling is None or scale_factor == 1: + return tensor + B, N, C = tensor.shape + + if sampling == 'uniform_every': + return tensor[:, ::scale_factor], int(N // scale_factor) + + tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2) + new_H, new_W = int(H / scale_factor), int(W / scale_factor) + new_N = new_H * new_W + + if sampling == 'ave': + tensor = F.interpolate( + tensor, scale_factor=1 / scale_factor, mode='nearest' + ).permute(0, 2, 3, 1) + elif sampling == 'uniform': + tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1) + elif sampling == 'conv': + tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1) + tensor = self.norm(tensor) + else: + raise ValueError + + return tensor.reshape(B, new_N, C).contiguous(), new_N + + def forward(self, x, mask=None, HW=None, block_id=None, qkv_cond=None, n_views=None): + if self.use_crossview_module: + # for multi-view row attention + h = int((x.shape[1])**0.5) + x = rearrange(x, "(b v) (h w) c -> (b h) (v w) c", v=n_views, h=h) + + B, N, C = x.shape + if HW is None: + H = W = int(N ** 0.5) + else: + H, W = HW + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) + dtype = q.dtype + q = self.q_norm(q) + k = self.k_norm(k) + + new_N = N + # KV compression + if self.sr_ratio > 1: + k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling) + v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling) + + q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype) + k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + + use_fp32_attention = getattr(self, 'fp32_attention', False) # necessary for NAN loss + + if qkv_cond is not None: + assert mask is None + if use_fp32_attention: + q, k, v = q.float(), k.float(), v.float() + qkv_cond = [item.float() for item in qkv_cond] + + v = v + qkv_cond[2] + attn_bias = None + x_temp = xformers.ops.memory_efficient_attention(qkv_cond[1], k, v, p=self.attn_drop.p, attn_bias=attn_bias) + x = xformers.ops.memory_efficient_attention(q, qkv_cond[0], x_temp, p=self.attn_drop.p, attn_bias=attn_bias) + else: + if use_fp32_attention: + q, k, v = q.float(), k.float(), v.float() + + attn_bias = None + if mask is not None: + attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) + attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf')) + + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + x = x.view(B, N, C) + + if self.use_crossview_module: + x = rearrange(x, "(b h) (v w) c -> (b v) (h w) c", v=n_views, h=h) + + x = self.proj(x) + x = self.proj_drop(x) + if self.return_qkv: + return x, [v, k, q] + else: + return x + + def forward_with_cross_view(self, x, mask=None, HW=None, block_id=None, qkv_cond=None, epipolar_constrains=None, cam_distances=None, n_views=None): + B, N, C = x.shape # (b v) (h w) c + h = int(N**0.5) + + # get multi-view row attention results + if self.return_qkv: + x, [v, k, q] = self.forward(x, mask, HW, block_id, qkv_cond, n_views=n_views) # (b v) (h w) c + else: + x = self.forward(x, mask, HW, block_id, qkv_cond, n_views=n_views) # (b v) (h w) c + + x = rearrange(x, "(b v) (h w) c -> b v (h w) c", v=n_views, h=h) + epipolar_constrains = rearrange(epipolar_constrains, "(b v) kv ... -> b v kv ...", v=n_views, kv=2) + cam_distances = rearrange(cam_distances, "(b v) kv -> b v kv", v=n_views, kv=2) + + # get near-view aggragation results + x_agg = x.clone() + for i in range(n_views): + # near two views are the key views + kv_idx = [(i-1)%n_views, (i+1)%n_views] + + nv = x_agg[:, [i]] # b 1 (h w) c + kv = x_agg[:, kv_idx] # b 2 (h w) c + + # sim: b (1 h w) (2 h w) + with torch.no_grad(): + sim = batch_cosine_sim( + rearrange(nv, "b k (h w) c -> b (k h w) c", h=h, k=1), + rearrange(kv, "b k (h w) c -> b (k h w) c", h=h, k=2) + ) + + sims = sim.chunk(2, dim=2) # [b 1hw 1hw, b 1hw 1hw] + + idxs = [] + sim_l = [] + for j, sim in enumerate(sims): + idx_epipolar = epipolar_constrains[:, i, j] # b hw hw + sim[idx_epipolar] = 0 + sim, sim_idx = sim.max(dim=-1) # b 1hw + + sim = (sim + 1.) / 2. + sim_l.append(((sim)).view(-1, 1 * N, 1).repeat(1, 1, C)) # b 1hw c + idxs.append(sim_idx.view(-1, 1 * N, 1).repeat(1, 1, C)) # b 1hw c + + attn_1, attn_2 = kv[:, 0], kv[:, 1] + attn_output1 = attn_1.gather(dim=1, index=idxs[0]) # b 1hw c + attn_output2 = attn_2.gather(dim=1, index=idxs[1]) # b 1hw c + + d1 = cam_distances[:, i, 0] # b + d2 = cam_distances[:, i, 1] # b + w1 = d2 / (d1 + d2) + w1 = (w1.unsqueeze(-1).unsqueeze(-1)).to(attn_output1.dtype) + + w1 = (w1 * sim_l[0]) / (w1 * sim_l[0] + (1-w1) * sim_l[1]) + + nv_output = w1 * attn_output1 + (1-w1) * attn_output2 + nv_output = rearrange(nv_output, "b (k h w) c -> b k (h w) c", k=1, h=h) # b 1 hw c + + x_agg[:, [i]] = nv + (nv_output - nv).detach() + + x = (x_agg + x) / 2. + x = rearrange(x, "b v (h w) c -> (b v) (h w) c", v=n_views, h=h) + + if self.return_qkv: + return x, [v, k, q] + else: + return x + + + def forward_with_cross_view_optimized(self, x, mask=None, HW=None, block_id=None, qkv_cond=None, epipolar_constrains=None, cam_distances=None, n_views=None): + B, N, C = x.shape # (b v) (h w) c + h = int(N**0.5) + + # get multi-view row attention results + if self.return_qkv: + x, [v, k, q] = self.forward(x, mask, HW, block_id, qkv_cond, n_views=n_views) # (b v) (h w) c + else: + x = self.forward(x, mask, HW, block_id, qkv_cond, n_views=n_views) # (b v) (h w) c + + x = rearrange(x, "(b v) (h w) c -> b v (h w) c", v=n_views, h=h) + epipolar_constrains = rearrange(epipolar_constrains, "(b v) kv ... -> b v kv ...", v=n_views, kv=2) + cam_distances = rearrange(cam_distances, "(b v) kv -> b v kv", v=n_views, kv=2) + + # get near-view aggragation results + x_agg = x.clone() + for i in range(n_views): + # near two views are the key views + kv_idx = [(i-1)%n_views, (i+1)%n_views] + + nv = x_agg[:, [i]] # b 1 (h w) c + kv = x_agg[:, kv_idx] # b 2 (h w) c + + # sim: b (1 h w) (2 h w) + with torch.no_grad(): + sim = batch_cosine_sim( + rearrange(nv, "b k (h w) c -> b (k h w) c", h=h, k=1), + rearrange(kv, "b k (h w) c -> b (k h w) c", h=h, k=2) + ) + + sim = sim.chunk(2, dim=2) # [b 1hw 1hw, b 1hw 1hw] + sim = torch.stack(sim, dim=1) # b 2 hw hw + + idx_epipolar = epipolar_constrains[:, i, :] # b 2 hw hw + sim[idx_epipolar] = 0 + + sim, sim_idx = sim.max(dim=-1) # b 2 hw + sim = (sim + 1.) / 2. + + sim = sim.unsqueeze(-1).repeat(1, 1, 1, C) # b 2 1hw c + idx = sim_idx.unsqueeze(-1).repeat(1, 1, 1, C) # b 2 1hw c + + attn_output1 = kv[:, 0].gather(dim=1, index=idx[:, 0]) # b 1hw c + attn_output2 = kv[:, 1].gather(dim=1, index=idx[:, 1]) # b 1hw c + + d1 = cam_distances[:, i, 0] # b + d2 = cam_distances[:, i, 1] # b + w1 = d2 / (d1 + d2) + w1 = w1.unsqueeze(-1).unsqueeze(-1).to(attn_output1.dtype) + w1 = (w1 * sim[:, 0]) / (w1 * sim[:, 0] + (1-w1) * sim[:, 1]) + + nv_output = w1 * attn_output1 + (1-w1) * attn_output2 + nv_output = rearrange(nv_output, "b (k h w) c -> b k (h w) c", k=1, h=h) # b 1 hw c + + x_agg[:, [i]] = nv + (nv_output - nv).detach() + + x = (x_agg + x) / 2. + x = rearrange(x, "b v (h w) c -> (b v) (h w) c", v=n_views, h=h) + + if self.return_qkv: + return x, [v, k, q] + else: + return x + + +################################################################################# +# AMP attention with fp32 softmax to fix loss NaN problem during training # +################################################################################# +class Attention(Attention_): + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + use_fp32_attention = getattr(self, 'fp32_attention', False) + if use_fp32_attention: + q, k = q.float(), k.float() + with torch.cuda.amp.autocast(enabled=not use_fp32_attention): + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) + self.out_channels = out_channels + + def forward(self, x, t): + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MaskFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DecoderLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, decoder_hidden_size): + super().__init__() + self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_decoder(x), shift, scale) + x = self.linear(x) + return x + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs//s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120): + super().__init__() + self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class CaptionEmbedderDoubleBr(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120): + super().__init__() + self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) + self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5) + self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5) + self.uncond_prob = uncond_prob + + def token_drop(self, global_caption, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption) + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return global_caption, caption + + def forward(self, caption, train, force_drop_ids=None): + assert caption.shape[2: ] == self.y_embedding.shape + global_caption = caption.mean(dim=2).squeeze() + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids) + y_embed = self.proj(global_caption) + return y_embed, caption \ No newline at end of file diff --git a/src/diffusion/model/nets/__init__.py b/src/diffusion/model/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e29b9a04cd7eacb53db8c59da218af3e8ef34364 --- /dev/null +++ b/src/diffusion/model/nets/__init__.py @@ -0,0 +1,4 @@ +from .encoder import Encoder, MVEncoder, ResnetBlock +from .PixArt import PixArt, PixArt_XL_2 +from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock +from .pixart_controlnet import ControlPixArtHalf, ControlPixArtMSHalf, ControlPixArtMSMVHalfWithEncoder diff --git a/src/diffusion/model/nets/encoder.py b/src/diffusion/model/nets/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..afe729c9f5997cca6862b0b3dfd2110b3609561d --- /dev/null +++ b/src/diffusion/model/nets/encoder.py @@ -0,0 +1,656 @@ +from typing import Optional, Any +from inspect import isfunction +import numbers + +import torch +import torch.nn as nn +from torch import einsum +import torch.nn.functional as F +from einops import rearrange, repeat + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +def exists(val): + return val is not None + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states.to(input_dtype) + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, enable_rmsnorm=False, qk_norm=False): + super().__init__() + # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + # f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + + # if enable_rmsnorm: + # self.q_rmsnorm = RMSNorm(query_dim, eps=1e-5) + # self.k_rmsnorm = RMSNorm(context_dim, eps=1e-5) + + self.q_norm = RMSNorm(self.dim_head, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(self.dim_head, elementwise_affine=True, eps=1e-5) if qk_norm else nn.Identity() + + # self.enable_rmsnorm = enable_rmsnorm + + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + # self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + # self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + # self.attention_op: Optional[Any] = MemoryEfficientAttentionFlashAttentionOp + + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + q, k = self.q_norm(q), self.k_norm(k) # for stable amp training + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + # return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + return self._forward(x, context) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerBlock3D(BasicTransformerBlock): + + def forward(self, x, context=None, num_frames=1): + # return checkpoint(self._forward, (x, context, num_frames), self.parameters(), self.checkpoint) + return self._forward(x, context, num_frames) # , self.parameters(), self.checkpoint + + def _forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + ''' 3D self-attention ''' + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] * depth + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none", "mv-vanilla"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "mv-vanilla": + assert attn_kwargs is not None + return SpatialTransformer3D(in_channels, **attn_kwargs) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=0): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, + z_channels, double_z=True, + use_linear_attn=False, attn_type="vanilla", + attn_kwargs={}, + z_downsample_size=1, + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=z_downsample_size, + padding=1) + + def forward(self, x, **kwargs): + # timestep embedding + temb = None + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions-1: + h = (self.down[i_level].downsample(h)) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class MVEncoder(Encoder): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="mv-vanilla", z_downsample_size=1, **ignore_kwargs): + super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, in_channels=in_channels, resolution=resolution, z_channels=z_channels, double_z=double_z, use_linear_attn=use_linear_attn, attn_type=attn_type, + z_downsample_size=z_downsample_size, + add_fusion_layer=False, + **ignore_kwargs) + + def forward(self, x, n_views): + return super().forward(x, num_frames=n_views) \ No newline at end of file diff --git a/src/diffusion/model/nets/pixart_controlnet.py b/src/diffusion/model/nets/pixart_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..431f29d0fe63f9ed79fcfa3645cfc043040e7b47 --- /dev/null +++ b/src/diffusion/model/nets/pixart_controlnet.py @@ -0,0 +1,376 @@ +import re +import torch +import torch.nn as nn + +from copy import deepcopy +from torch import Tensor +from torch.nn import Module, Linear, init +from typing import Any, Mapping + +from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt, MVEncoder +from diffusion.model.nets.PixArt import get_2d_sincos_pos_embed +from diffusion.model.utils import auto_grad_checkpoint + + +# The implementation of ControlNet-Half architrecture +# https://github.com/lllyasviel/ControlNet/discussions/188 +class ControlT2IDitBlockHalf(Module): + def __init__(self, base_block: PixArtMSBlock, block_index: 0, zero_init=True, base_size=None) -> None: + super().__init__() + self.copied_block = deepcopy(base_block) + self.block_index = block_index + + for p in self.copied_block.parameters(): + p.requires_grad_(True) + + self.copied_block.load_state_dict(base_block.state_dict()) + self.copied_block.train() + + self.hidden_size = hidden_size = base_block.hidden_size + if self.block_index == 0: + self.before_proj = Linear(hidden_size, hidden_size) + # we still keep the before_proj as zero initialed + init.zeros_(self.before_proj.weight) + init.zeros_(self.before_proj.bias) + self.after_proj = Linear(hidden_size, hidden_size) + if zero_init: + init.zeros_(self.after_proj.weight) + init.zeros_(self.after_proj.bias) + + def forward(self, x, y, t, mask=None, c=None, epipolar_constrains=None, cam_distances=None, n_views=None): + if self.block_index == 0: + # the first block + c = self.before_proj(c) + c = self.copied_block(x + c, y, t, mask, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views) + c_skip = self.after_proj(c) + else: + # load from previous c and produce the c for skip connection + c = self.copied_block(c, y, t, mask, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views) + c_skip = self.after_proj(c) + + return c, c_skip + + +# The implementation of ControlPixArtHalf net +class ControlPixArtHalf(Module): + # only support single res model + def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None: + super().__init__() + self.base_model = base_model.eval() + self.controlnet = [] + self.copy_blocks_num = copy_blocks_num + self.total_blocks_num = len(base_model.blocks) + for p in self.base_model.parameters(): + p.requires_grad_(False) + + # Copy first copy_blocks_num block + for i in range(copy_blocks_num): + self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i)) + self.controlnet = nn.ModuleList(self.controlnet) + + def __getattr__(self, name: str) -> Tensor or Module: + if name in [ + 'base_model', + 'controlnet', + 'encoder', + 'controlnet_t_block', + 'noise_embedding', + ]: + return super().__getattr__(name) + else: + return getattr(self.base_model, name) + + def forward_c(self, c): + self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size + pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype) + return self.x_embedder(c) + pos_embed if c is not None else c + + # def forward(self, x, t, c, **kwargs): + # return self.base_model(x, t, c=self.forward_c(c), **kwargs) + def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs): + # modify the original PixArtMS forward function + if c is not None: + c = c.to(self.dtype) + c = self.forward_c(c) + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + pos_embed = self.pos_embed.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep.to(x.dtype)) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # define the first layer + x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint + + if c is not None: + # update c + for index in range(1, self.copy_blocks_num + 1): + c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs) + x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs) + + # update x + for index in range(self.copy_blocks_num + 1, self.total_blocks_num): + x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs) + else: + for index in range(1, self.total_blocks_num): + x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs): + model_out = self.forward(x, t, y, data_info=data_info, c=c, **kwargs) + return model_out.chunk(2, dim=1)[0] + + def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, c, **kwargs): + """ + Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, y, data_info=data_info, c=c) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if all((k.startswith(('base_model', 'controlnet', 'encoder', 'controlnet_t_block', 'noise_embedding'))) for k in state_dict.keys()): + return super().load_state_dict(state_dict, strict) + else: + new_key = {} + for k in state_dict.keys(): + new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k) + for k, v in new_key.items(): + if k != v: + print(f"replace {k} to {v}") + state_dict[v] = state_dict.pop(k) + + return self.base_model.load_state_dict(state_dict, strict) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + assert self.h * self.w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p)) + return imgs + + @property + def dtype(self): + return next(self.parameters()).dtype + + +# The implementation for PixArtMS_Half + 1024 resolution +class ControlPixArtMSHalf(ControlPixArtHalf): + # support multi-scale res model (multi-scale model can also be applied to single reso training & inference) + def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None: + super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num) + + def forward(self, x, timestep, y, mask=None, data_info=None, c=None, need_forward_c=True, **kwargs): + # modify the original PixArtMS forward function + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + if c is not None and need_forward_c: + c = c.to(self.dtype) + c = self.forward_c(c) + + bs = x.shape[0] + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + pos_embed = torch.from_numpy( + get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation, + base_size=self.base_size + ) + ).unsqueeze(0).to(x.device).to(self.dtype) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep) # (N, D) + + if self.micro_conditioning: + c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype) + csize = self.csize_embedder(c_size, bs) # (N, D) + ar = self.ar_embedder(ar, bs) # (N, D) + t = t + torch.cat([csize, ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + y_lens = [int(item) for item in y_lens] + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # define the first layer + x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint + + if c is not None: + # update c + for index in range(1, self.copy_blocks_num + 1): + c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs) + x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs) + + # update x + for index in range(self.copy_blocks_num + 1, self.total_blocks_num): + x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs) + else: + for index in range(1, self.total_blocks_num): + x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + +# 3DEnhancer Backbone +class ControlPixArtMSMVHalfWithEncoder(ControlPixArtMSHalf): + def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None: + super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num) + + self.encoder = MVEncoder( + double_z=False, + resolution=512, + in_channels=9, + ch=64, + ch_mult=[1, 2, 4, 4], + num_res_blocks=1, + dropout=0.0, + attn_resolutions=[], + out_ch=3, # unused + z_channels=self.hidden_size, + attn_kwargs = { + 'n_heads': 8, + 'd_head': 64, + }, + z_downsample_size=2, + ) + + self.noise_embedding = nn.Embedding(500, self.hidden_size) + self.noise_embedding.weight.data.fill_(0) + + self.controlnet_t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True) + ) + + self.attetion_token_num = self.base_size**2 + + def encode(self, input_img, camera_pose, n_views): + # fuse this two on 2nd dim + # input_img: b3hw, camera_pose: b6hw (b%4==0) + z_lq = torch.cat((input_img, camera_pose), dim=1) + z_lq = self.encoder(z_lq, n_views) + z_lq = z_lq.permute(0, 2, 3, 1).reshape(-1, self.attetion_token_num, self.hidden_size) + + return z_lq + + def forward(self, x, timestep, y, mask=None, data_info=None, input_img=None, camera_pose=None, c=None, noise_level=None, epipolar_constrains=None, cam_distances=None, n_views=None, **kwargs): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + + c = self.encode(input_img, camera_pose, n_views).to(x.dtype) if c is None else c + + bs = x.shape[0] + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size + pos_embed = torch.from_numpy( + get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation, + base_size=self.base_size + ) + ).unsqueeze(0).to(x.device).to(self.dtype) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep) # (N, D) + + noise_level = self.noise_embedding(noise_level) + controlnet_t = t + noise_level + + if self.micro_conditioning: + c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype) + csize = self.csize_embedder(c_size, bs) # (N, D) + ar = self.ar_embedder(ar, bs) # (N, D) + t = t + torch.cat([csize, ar], dim=1) + + t0 = self.t_block(t) + controlnet_t0 = self.controlnet_t_block(controlnet_t) + y = self.y_embedder(y, self.training) # (N, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + y_lens = [int(item) for item in y_lens] + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs) # (N, T, D) #support grad checkpoint + + if c is not None: + # update c + for index in range(1, self.copy_blocks_num + 1): + c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, controlnet_t0, y_lens, c, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views, **kwargs) + x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs) + + # update x + for index in range(self.copy_blocks_num + 1, self.total_blocks_num): + x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs) + else: + for index in range(1, self.total_blocks_num): + x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_dpmsolver(self, x, t, y, data_info, c, noise_level, epipolar_constrains, cam_distances, n_views, **kwargs): + model_out = self.forward(x, t, y, data_info=data_info, c=c, noise_level=noise_level, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views, **kwargs) + return model_out.chunk(2, dim=1)[0] \ No newline at end of file diff --git a/src/diffusion/model/nets/sft.py b/src/diffusion/model/nets/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..483b01a6ef7c6adc6e17e916aaffe500bf1b840f --- /dev/null +++ b/src/diffusion/model/nets/sft.py @@ -0,0 +1,79 @@ +# From https://github.com/Fanghua-Yu/SUPIR/blob/master/SUPIR/modules/SUPIR_v0.py + +import torch +import torch as th +import torch.nn as nn + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + # return super().forward(x.float()).type(x.dtype) + return super().forward(x) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, nhidden=128, norm=True, mask=False, zero_init=True): + super().__init__() + + # param_free_norm_type = str(parsed.group(1)) + ks = 3 + pw = ks // 2 + + self.norm = norm + if self.norm: + self.param_free_norm = normalization(norm_nc) + else: + self.param_free_norm = nn.Identity() + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.SiLU() + ) + + if zero_init: + self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)) + self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)) + else: + self.zero_mul = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.zero_add = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, c, h, control_scale=1): + h_raw = h + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + h = self.param_free_norm(h) * (gamma + 1) + beta + + return h * control_scale + h_raw * (1 - control_scale) \ No newline at end of file diff --git a/src/diffusion/model/respace.py b/src/diffusion/model/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..61cd9dfb329741399f6ab7d0d3e7394bcb2c3112 --- /dev/null +++ b/src/diffusion/model/respace.py @@ -0,0 +1,134 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def training_losses_diffusers( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, timestep, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype) + new_ts = map_tensor[timestep] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, timestep=new_ts, **kwargs) diff --git a/src/diffusion/model/sa_solver.py b/src/diffusion/model/sa_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..e51cdc8c8b941f68682b9bac576d0d35a496ea7a --- /dev/null +++ b/src/diffusion/model/sa_solver.py @@ -0,0 +1,1149 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32, + ): + """Thanks to DPM-Solver for their code base""" + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + def edm_sigma(self, t): + return self.marginal_std(t) / self.marginal_alpha(t) + + def edm_inverse_sigma(self, edmsigma): + alpha = 1 / (edmsigma ** 2 + 1).sqrt() + sigma = alpha * edmsigma + lambda_t = torch.log(alpha / sigma) + t = self.inverse_lambda(lambda_t) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Thanks to DPM-Solver for their code base""" + """Create a wrapper function for the noise prediction model. + SA-Solver needs to solve the continuous-time diffusion SDEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for SA-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t[0] * output) / sigma_t[0] + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t[0] * output + sigma_t[0] * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t[0] * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class SASolver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="data_prediction", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995 + ): + """ + Construct a SA-Solver + The default value for algorithm_type is "data_prediction" and we recommend not to change it to + "noise_prediction". For details, please see Appendix A.2.4 in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["data_prediction", "noise_prediction"] + + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + self.predict_x0 = algorithm_type == "data_prediction" + + self.sigma_min = float(self.noise_schedule.edm_sigma(torch.tensor([1e-3]))) + self.sigma_max = float(self.noise_schedule.edm_sigma(torch.tensor([1]))) + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, order, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = lambda_T + torch.linspace(torch.tensor(0.).cpu().item(), + (lambda_0 - lambda_T).cpu().item() ** (1. / order), N + 1).pow( + order).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time': + t = torch.linspace(t_T ** (1. / order), t_0 ** (1. / order), N + 1).pow(order).to(device) + return t + elif skip_type == 'karras': + sigma_min = max(0.002, self.sigma_min) + sigma_max = min(80, self.sigma_max) + sigma_steps = torch.linspace(sigma_max ** (1. / 7), sigma_min ** (1. / 7), N + 1).pow(7).to(device) + t = self.noise_schedule.edm_inverse_sigma(sigma_steps) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time' or 'karras'".format(skip_type)) + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def get_coefficients_exponential_negative(self, order, interval_start, interval_end): + """ + Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end + For calculating the coefficient of gradient terms after the lagrange interpolation, + see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + For noise_prediction formula. + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + if order == 0: + return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1) + elif order == 1: + return torch.exp(-interval_end) * ( + (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1)) + elif order == 2: + return torch.exp(-interval_end) * ( + (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - ( + interval_end ** 2 + 2 * interval_end + 2)) + elif order == 3: + return torch.exp(-interval_end) * ( + (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp( + interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6)) + + def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau): + """ + Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end + For calculating the coefficient of gradient terms after the lagrange interpolation, + see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + For data_prediction formula. + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + # after change of variable(cov) + interval_end_cov = (1 + tau ** 2) * interval_end + interval_start_cov = (1 + tau ** 2) * interval_start + + if order == 0: + return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / ( + (1 + tau ** 2)) + elif order == 1: + return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2) + elif order == 2: + return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - ( + interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3) + elif order == 3: + return torch.exp(interval_end_cov) * ( + (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - ( + interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4) + + def lagrange_polynomial_coefficient(self, order, lambda_list): + """ + Calculate the coefficient of lagrange polynomial + For lagrange interpolation + """ + assert order in [0, 1, 2, 3] + assert order == len(lambda_list) - 1 + if order == 0: + return [[1]] + elif order == 1: + return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], + [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]] + elif order == 2: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2]) / denominator1, + lambda_list[1] * lambda_list[2] / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2]) / denominator2, + lambda_list[0] * lambda_list[2] / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1]) / denominator3, + lambda_list[0] * lambda_list[1] / denominator3] + ] + elif order == 3: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * ( + lambda_list[0] - lambda_list[3]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * ( + lambda_list[1] - lambda_list[3]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * ( + lambda_list[2] - lambda_list[3]) + denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * ( + lambda_list[3] - lambda_list[2]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, + (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator1, + (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, + (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator2, + (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[ + 3]) / denominator3, + (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3], + + [1 / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[ + 2]) / denominator4, + (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4] + + ] + + def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau): + """ + Calculate the coefficient of gradients. + """ + assert order in [1, 2, 3, 4] + assert order == len(lambda_list), 'the length of lambda list must be equal to the order' + coefficients = [] + lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list) + for i in range(order): + coefficient = 0 + for j in range(order): + if self.predict_x0: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive( + order - 1 - j, interval_start, interval_end, tau) + else: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative( + order - 1 - j, interval_start, interval_end) + coefficients.append(coefficient) + assert len(coefficients) == order, 'the length of coefficients does not match the order' + return coefficients + + def adams_bashforth_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Predictor, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_prev_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def adams_moulton_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + t_list = t_prev_list + [t] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def adams_bashforth_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_prev_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda( + t_prev_list[-2])) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda( + t_prev_list[-2])) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def adams_moulton_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t): + """ + SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + """ + + assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4" + + # get noise schedule + ns = self.noise_schedule + alpha_t = ns.marginal_alpha(t) + sigma_t = ns.marginal_std(t) + lambda_t = ns.marginal_lambda(t) + alpha_prev = ns.marginal_alpha(t_prev_list[-1]) + sigma_prev = ns.marginal_std(t_prev_list[-1]) + gradient_part = torch.zeros_like(x) + h = lambda_t - ns.marginal_lambda(t_prev_list[-1]) + lambda_list = [] + t_list = t_prev_list + [t] + for i in range(order): + lambda_list.append(ns.marginal_lambda(t_list[-(i + 1)])) + gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t, + lambda_list, tau) + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part + + return x_t + + def sample_few_steps(self, x, tau, steps=5, t_start=None, t_end=None, skip_type='time', skip_order=1, + predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False + ): + """ + For the PC-mode, please refer to the wiki page + https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode + 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations + We recommend use pc_mode='PEC' for NFEs is limited. 'PECE' mode is only for test with sufficient NFEs. + """ + + skip_first_step = False + skip_final_step = True + lower_order_final = True + denoise_to_zero = False + + assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE' + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + + device = x.device + intermediates = [] + with torch.no_grad(): + assert steps >= max(predictor_order, corrector_order - 1) + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order, + device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + noise = torch.randn_like(x) + t_prev_list = [t] + # do not evaluate if skip_first_step + if skip_first_step: + if self.predict_x0: + alpha_t = self.noise_schedule.marginal_alpha(t) + sigma_t = self.noise_schedule.marginal_std(t) + model_prev_list = [(1 - sigma_t) / alpha_t * x] + else: + model_prev_list = [x] + else: + model_prev_list = [self.model_fn(x, t)] + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # determine the first several values + for step in tqdm(range(1, max(predictor_order, corrector_order - 1))): + + t = timesteps[step] + predictor_order_used = min(predictor_order, step) + corrector_order_used = min(corrector_order, step + 1) + noise = torch.randn_like(x) + # predictor step + x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + # evaluation step + model_x = self.model_fn(x_p, t) + + # update model_list + model_prev_list.append(model_x) + # corrector step + if corrector_order > 0: + x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + else: + x = x_p + + # evaluation step if correction and mode = pece + if corrector_order > 0: + if pc_mode == 'PECE': + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + + for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)): + if lower_order_final: + predictor_order_used = min(predictor_order, steps - step + 1) + corrector_order_used = min(corrector_order, steps - step + 2) + + else: + predictor_order_used = predictor_order + corrector_order_used = corrector_order + t = timesteps[step] + noise = torch.randn_like(x) + + # predictor step + if skip_final_step and step == steps and not denoise_to_zero: + x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=0, + model_prev_list=model_prev_list, + t_prev_list=t_prev_list, noise=noise, t=t) + else: + x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, + t_prev_list=t_prev_list, noise=noise, t=t) + + # evaluation step + # do not evaluate if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_x = self.model_fn(x_p, t) + + # update model_list + # do not update if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_prev_list.append(model_x) + + # corrector step + # do not correct if skip_final_step and step = steps + if corrector_order > 0: + if not skip_final_step or step < steps: + x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, + t_prev_list=t_prev_list, noise=noise, t=t) + else: + x = x_p + else: + x = x_p + + # evaluation step if mode = pece and step != steps + if corrector_order > 0: + if pc_mode == 'PECE' and step < steps: + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + del model_prev_list[0] + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + def sample_more_steps(self, x, tau, steps=20, t_start=None, t_end=None, skip_type='time', skip_order=1, + predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False + ): + """ + For the PC-mode, please refer to the wiki page + https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode + 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations + We recommend use pc_mode='PEC' for NFEs is limited. 'PECE' mode is only for test with sufficient NFEs. + """ + + skip_first_step = False + skip_final_step = False + lower_order_final = True + denoise_to_zero = True + + assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE' + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + + device = x.device + intermediates = [] + with torch.no_grad(): + assert steps >= max(predictor_order, corrector_order - 1) + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order, + device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + noise = torch.randn_like(x) + t_prev_list = [t] + # do not evaluate if skip_first_step + if skip_first_step: + if self.predict_x0: + alpha_t = self.noise_schedule.marginal_alpha(t) + sigma_t = self.noise_schedule.marginal_std(t) + model_prev_list = [(1 - sigma_t) / alpha_t * x] + else: + model_prev_list = [x] + else: + model_prev_list = [self.model_fn(x, t)] + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # determine the first several values + for step in tqdm(range(1, max(predictor_order, corrector_order - 1))): + + t = timesteps[step] + predictor_order_used = min(predictor_order, step) + corrector_order_used = min(corrector_order, step + 1) + noise = torch.randn_like(x) + # predictor step + x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, noise=noise, + t=t) + # evaluation step + model_x = self.model_fn(x_p, t) + + # update model_list + model_prev_list.append(model_x) + # corrector step + if corrector_order > 0: + x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, noise=noise, + t=t) + else: + x = x_p + + # evaluation step if mode = pece + if corrector_order > 0: + if pc_mode == 'PECE': + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + + for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)): + if lower_order_final: + predictor_order_used = min(predictor_order, steps - step + 1) + corrector_order_used = min(corrector_order, steps - step + 2) + + else: + predictor_order_used = predictor_order + corrector_order_used = corrector_order + t = timesteps[step] + noise = torch.randn_like(x) + + # predictor step + if skip_final_step and step == steps and not denoise_to_zero: + x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=0, + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + else: + x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + + # evaluation step + # do not evaluate if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_x = self.model_fn(x_p, t) + + # update model_list + # do not update if skip_final_step and step = steps + if not skip_final_step or step < steps: + model_prev_list.append(model_x) + + # corrector step + # do not correct if skip_final_step and step = steps + if corrector_order > 0: + if not skip_final_step or step < steps: + x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t), + model_prev_list=model_prev_list, t_prev_list=t_prev_list, + noise=noise, t=t) + else: + x = x_p + else: + x = x_p + + # evaluation step if mode = pece and step != steps + if corrector_order > 0: + if pc_mode == 'PECE' and step < steps: + model_x = self.model_fn(x, t) + del model_prev_list[-1] + model_prev_list.append(model_x) + + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + t_prev_list.append(t) + del model_prev_list[0] + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + def sample(self, mode, x, tau, steps, t_start=None, t_end=None, skip_type='time', skip_order=1, predictor_order=3, + corrector_order=4, pc_mode='PEC', return_intermediate=False + ): + """ + For the PC-mode, please refer to the wiki page + https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode + 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations + We recommend use pc_mode='PEC' for NFEs is limited. 'PECE' mode is only for test with sufficient NFEs. + + 'few_steps' mode is recommended. The differences between 'few_steps' and 'more_steps' are as below: + 1) 'few_steps' do not correct at final step and do not denoise to zero, while 'more_steps' do these two. + Thus the NFEs for 'few_steps' = steps, NFEs for 'more_steps' = steps + 2 + For most of the experiments and tasks, we find these two operations do not have much help to sample quality. + 2) 'few_steps' use a rescaling trick as in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf + We find it will slightly improve the sample quality especially in few steps. + """ + assert mode in ['few_steps', 'more_steps'], "mode must be either 'few_steps' or 'more_steps'" + if mode == 'few_steps': + return self.sample_few_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type, + skip_order=skip_order, predictor_order=predictor_order, + corrector_order=corrector_order, pc_mode=pc_mode, + return_intermediate=return_intermediate) + else: + return self.sample_more_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type, + skip_order=skip_order, predictor_order=predictor_order, + corrector_order=corrector_order, pc_mode=pc_mode, + return_intermediate=return_intermediate) + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/src/diffusion/model/t5.py b/src/diffusion/model/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..19cb6f93378cf2419c14ac03174949819e1a5826 --- /dev/null +++ b/src/diffusion/model/t5.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +import os +import re +import html +import urllib.parse as ul + +import ftfy +import torch +from bs4 import BeautifulSoup +from transformers import T5EncoderModel, AutoTokenizer +from huggingface_hub import hf_hub_download + +class T5Embedder: + + available_models = ['t5-v1_1-xxl'] + bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + + def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, + t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + if t5_model_kwargs is None: + t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} + if use_offload_folder is not None: + t5_model_kwargs['offload_folder'] = use_offload_folder + t5_model_kwargs['device_map'] = { + 'shared': self.device, + 'encoder.embed_tokens': self.device, + 'encoder.block.0': self.device, + 'encoder.block.1': self.device, + 'encoder.block.2': self.device, + 'encoder.block.3': self.device, + 'encoder.block.4': self.device, + 'encoder.block.5': self.device, + 'encoder.block.6': self.device, + 'encoder.block.7': self.device, + 'encoder.block.8': self.device, + 'encoder.block.9': self.device, + 'encoder.block.10': self.device, + 'encoder.block.11': self.device, + 'encoder.block.12': 'disk', + 'encoder.block.13': 'disk', + 'encoder.block.14': 'disk', + 'encoder.block.15': 'disk', + 'encoder.block.16': 'disk', + 'encoder.block.17': 'disk', + 'encoder.block.18': 'disk', + 'encoder.block.19': 'disk', + 'encoder.block.20': 'disk', + 'encoder.block.21': 'disk', + 'encoder.block.22': 'disk', + 'encoder.block.23': 'disk', + 'encoder.final_layer_norm': 'disk', + 'encoder.dropout': 'disk', + } + else: + t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_') + self.dir_or_name = dir_or_name + tokenizer_path, path = dir_or_name, dir_or_name + if local_cache: + cache_dir = os.path.join(self.cache_dir, dir_or_name) + tokenizer_path, path = cache_dir, cache_dir + elif dir_or_name in self.available_models: + cache_dir = os.path.join(self.cache_dir, dir_or_name) + for filename in [ + 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin' + ]: + hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + tokenizer_path, path = cache_dir, cache_dir + else: + cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') + for filename in [ + 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + ]: + hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + tokenizer_path = cache_dir + + print(tokenizer_path) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + + text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] + text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] + + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=text_tokens_and_mask['input_ids'].to(self.device), + attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), + )['last_hidden_state'].detach() + return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device) + + def text_preprocessing(self, text): + if self.use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = self.clean_caption(text) + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." + caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() diff --git a/src/diffusion/model/timestep_sampler.py b/src/diffusion/model/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ec6a583841ea7e39eb02d5027a5cf2f52890195b --- /dev/null +++ b/src/diffusion/model/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs, device=local_ts.device) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs, device=local_losses.device) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/diffusion/model/utils.py b/src/diffusion/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9de176aed8c04823490db42d30509bf96140b75 --- /dev/null +++ b/src/diffusion/model/utils.py @@ -0,0 +1,512 @@ +import os +import sys +import torch.nn as nn +from torch.utils.checkpoint import checkpoint, checkpoint_sequential +import torch.nn.functional as F +import torch +import torch.distributed as dist +import re +import math +from collections.abc import Iterable +from itertools import repeat +from torchvision import transforms as T +import random +from PIL import Image + + +def _ntuple(n): + def parse(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) + +def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): + assert isinstance(model, nn.Module) + + def set_attr(module): + module.grad_checkpointing = True + module.fp32_attention = use_fp32_attention + module.grad_checkpointing_step = gc_step + model.apply(set_attr) + + +def auto_grad_checkpoint(module, *args, **kwargs): + if getattr(module, 'grad_checkpointing', False): + if isinstance(module, Iterable): + gc_step = module[0].grad_checkpointing_step + return checkpoint_sequential(module, gc_step, *args, **kwargs) + else: + return checkpoint(module, *args, **kwargs) + return module(*args, **kwargs) + + +def checkpoint_sequential(functions, step, input, *args, **kwargs): + + # Hack for keyword-only parameter in a python 2.7-compliant way + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs: + raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + + def run_function(start, end, functions): + def forward(input): + for j in range(start, end + 1): + input = functions[j](input, *args) + return input + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = list(functions.children()) + + # the last chunk has to be non-volatile + end = -1 + segment = len(functions) // step + for start in range(0, step * (segment - 1), step): + end = start + step - 1 + input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) + return run_function(end + 1, len(functions) - 1, functions)(input) + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size, k_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + +def mean_flat(tensor): + return tensor.mean(dim=list(range(1, tensor.ndim))) + + +################################################################################# +# Token Masking and Unmasking # +################################################################################# +def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0): + """ + Get the binary mask for the input sequence. + Args: + - batch: batch size + - length: sequence length + - mask_ratio: ratio of tokens to mask + - data_info: dictionary with info for reconstruction + return: + mask_dict with following keys: + - mask: binary mask, 0 is keep, 1 is remove + - ids_keep: indices of tokens to keep + - ids_restore: indices to restore the original order + """ + assert mask_type in ['random', 'fft', 'laplacian', 'group'] + mask = torch.ones([batch, length], device=device) + len_keep = int(length * (1 - mask_ratio)) - extra_len + + if mask_type == 'random' or mask_type == 'group': + noise = torch.rand(batch, length, device=device) # noise in [0, 1] + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_removed = ids_shuffle[:, len_keep:] + + elif mask_type in ['fft', 'laplacian']: + if 'strength' in data_info: + strength = data_info['strength'] + + else: + N = data_info['N'][0] + img = data_info['ori_img'] + # 获取原图的尺寸信息 + _, C, H, W = img.shape + if mask_type == 'fft': + # 对图片进行reshape,将其变为patch (3, H/N, N, W/N, N) + reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N)) + fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5)) + # 取绝对值并求和获取频率强度 + strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape((batch, -1,)) + elif type == 'laplacian': + laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(1, 1, 3, 3) + laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1) + # 对图片进行reshape,将其变为patch (3, H/N, N, W/N, N) + reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N) + laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C) + strength = laplacian_response.sum(dim=[1, 2, 3]).reshape((batch, -1,)) + + # 对频率强度进行归一化,然后使用torch.multinomial进行采样 + probabilities = strength / (strength.max(dim=1)[0][:, None]+1e-5) + ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False) + ids_keep = ids_shuffle[:, :len_keep] + ids_restore = torch.argsort(ids_shuffle, dim=1) + ids_removed = ids_shuffle[:, len_keep:] + + mask[:, :len_keep] = 0 + mask = torch.gather(mask, dim=1, index=ids_restore) + + return {'mask': mask, + 'ids_keep': ids_keep, + 'ids_restore': ids_restore, + 'ids_removed': ids_removed} + + +def mask_out_token(x, ids_keep, ids_removed=None): + """ + Mask out the tokens specified by ids_keep. + Args: + - x: input sequence, [N, L, D] + - ids_keep: indices of tokens to keep + return: + - x_masked: masked sequence + """ + N, L, D = x.shape # batch, length, dim + x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + if ids_removed is not None: + x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D)) + return x_remain, x_masked + else: + return x_remain + + +def mask_tokens(x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + +def unmask_tokens(x, ids_restore, mask_token): + # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D] + mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x = torch.cat([x, mask_tokens], dim=1) + x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + return x + + +# Parse 'None' to None and others to float value +def parse_float_none(s): + assert isinstance(s, str) + return None if s == 'None' else float(s) + + +#---------------------------------------------------------------------------- +# Parse a comma separated list of numbers or ranges and return a list of ints. +# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + +def parse_int_list(s): + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + + +def init_processes(fn, args): + """ Initialize the distributed environment. """ + os.environ['MASTER_ADDR'] = args.master_address + os.environ['MASTER_PORT'] = str(random.randint(2000, 6000)) + print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}') + print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}') + torch.cuda.set_device(args.local_rank) + dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size) + fn(args) + if args.global_size > 1: + cleanup() + + +def mprint(*args, **kwargs): + """ + Print only from rank 0. + """ + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def cleanup(): + """ + End DDP training. + """ + dist.barrier() + mprint("Done!") + dist.barrier() + dist.destroy_process_group() + + +#---------------------------------------------------------------------------- +# logging info. +class Logger(object): + """ + Redirect stderr to stdout, optionally print stdout to a file, + and optionally force flushing on both stdout and the file. + """ + + def __init__(self, file_name=None, file_mode="w", should_flush=True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def write(self, text): + """Write text to stdout (and a file) and optionally flush.""" + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self): + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self): + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + + +class StackedRandomGenerator: + def __init__(self, device, seeds): + super().__init__() + self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) + + def randn_like(self, input): + return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) + + +def prepare_prompt_ar(prompt, ratios, device='cpu', show=True): + # get aspect_ratio or ar + aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt) + ars = re.findall(r"--ar\s+(\d+:\d+)", prompt) + custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt) + if show: + print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw) + prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0] + if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show: + print("Wrong prompt format. Set to default ar: 1. change your prompt into format '--ar h:w or --hw h:w' for correct generating") + if len(aspect_ratios) != 0: + ar = float(aspect_ratios[0].split(':')[0]) / float(aspect_ratios[0].split(':')[1]) + elif len(ars) != 0: + ar = float(ars[0].split(':')[0]) / float(ars[0].split(':')[1]) + else: + ar = 1. + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + if len(custom_hw) != 0: + custom_hw = [float(custom_hw[0].split(':')[0]), float(custom_hw[0].split(':')[1])] + else: + custom_hw = ratios[closest_ratio] + default_hw = ratios[closest_ratio] + prompt_show = f'prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}' + return prompt_clean, prompt_show, torch.tensor(default_hw, device=device)[None], torch.tensor([float(closest_ratio)], device=device)[None], torch.tensor(custom_hw, device=device)[None] + + +def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int): + orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int) + custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int) + + if (orig_hw != custom_hw).all(): + ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1]) + resized_width = int(orig_hw[1] * ratio) + resized_height = int(orig_hw[0] * ratio) + + transform = T.Compose([ + T.Resize((resized_height, resized_width)), + T.CenterCrop(custom_hw.tolist()) + ]) + return transform(samples) + else: + return samples + + +def resize_and_crop_img(img: Image, new_width, new_height): + orig_width, orig_height = img.size + + ratio = max(new_width/orig_width, new_height/orig_height) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + img = img.resize((resized_width, resized_height), Image.LANCZOS) + + left = (resized_width - new_width)/2 + top = (resized_height - new_height)/2 + right = (resized_width + new_width)/2 + bottom = (resized_height + new_height)/2 + + img = img.crop((left, top, right, bottom)) + + return img + + + +def mask_feature(emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index + else: + masked_feature = emb * mask[:, None, :, None] + return masked_feature, emb.shape[2] \ No newline at end of file diff --git a/src/diffusion/sa_sampler.py b/src/diffusion/sa_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..09372337f0e5868c5cfc5081418d2bf5907c7a59 --- /dev/null +++ b/src/diffusion/sa_sampler.py @@ -0,0 +1,94 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np + +from diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver +from .model import gaussian_diffusion as gd + + +class SASolverSampler(object): + def __init__(self, model, + noise_schedule="linear", + diffusion_steps=1000, + device='cpu', + ): + super().__init__() + self.model = model + self.device = device + to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) + betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps)) + alphas = 1.0 - betas + self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0))) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + model_kwargs={}, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + device = self.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + self.model, + ns, + model_type="noise", + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + model_kwargs=model_kwargs, + ) + + sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction") + + tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0 + + x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False) + + return x.to(device), None \ No newline at end of file diff --git a/src/diffusion/sa_solver_diffusers.py b/src/diffusion/sa_solver_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..8e57e9365bf4659136711a9cc2d9566bc16bafe0 --- /dev/null +++ b/src/diffusion/sa_solver_diffusers.py @@ -0,0 +1,856 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2309.05019 +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union, Callable + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class SASolverScheduler(SchedulerMixin, ConfigMixin): + """ + `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + predictor_order (`int`, defaults to 2): + The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for guided + sampling, and `predictor_order=3` for unconditional sampling. + corrector_order (`int`, defaults to 2): + The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided + sampling, and `corrector_order=3` for unconditional sampling. + predictor_corrector_mode (`str`, defaults to `PEC`): + The predictor-corrector mode can be `PEC` or 'PECE'. It is recommended to use `PEC` mode for fast + sampling, and `PECE` for high-quality sampling (PECE needs around twice model evaluations as PEC). + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `data_prediction`): + Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction` + with `solver_order=2` for guided sampling like in Stable Diffusion. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Default = True. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co./papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + predictor_order: int = 2, + corrector_order: int = 2, + predictor_corrector_mode: str = 'PEC', + prediction_type: str = "epsilon", + tau_func: Callable = lambda t: 1 if t >= 200 and t <= 800 else 0, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "data_prediction", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if algorithm_type not in ["data_prediction", "noise_prediction"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.timestep_list = [None] * max(predictor_order, corrector_order - 1) + self.model_outputs = [None] * max(predictor_order, corrector_order - 1) + + self.tau_func = tau_func + self.predict_x0 = algorithm_type == "data_prediction" + self.lower_order_nums = 0 + self.last_sample = None + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * max(self.config.predictor_order, self.config.corrector_order - 1) + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + + # SA-Solver_data_prediction needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["data_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["noise_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def get_coefficients_exponential_negative(self, order, interval_start, interval_end): + """ + Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + if order == 0: + return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1) + elif order == 1: + return torch.exp(-interval_end) * ( + (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1)) + elif order == 2: + return torch.exp(-interval_end) * ( + (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - ( + interval_end ** 2 + 2 * interval_end + 2)) + elif order == 3: + return torch.exp(-interval_end) * ( + (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp( + interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6)) + + def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau): + """ + Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + # after change of variable(cov) + interval_end_cov = (1 + tau ** 2) * interval_end + interval_start_cov = (1 + tau ** 2) * interval_start + + if order == 0: + return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / ( + (1 + tau ** 2)) + elif order == 1: + return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2) + elif order == 2: + return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - ( + interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3) + elif order == 3: + return torch.exp(interval_end_cov) * ( + (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - ( + interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp( + -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4) + + def lagrange_polynomial_coefficient(self, order, lambda_list): + """ + Calculate the coefficient of lagrange polynomial + """ + + assert order in [0, 1, 2, 3] + assert order == len(lambda_list) - 1 + if order == 0: + return [[1]] + elif order == 1: + return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])], + [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]] + elif order == 2: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2]) / denominator1, + lambda_list[1] * lambda_list[2] / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2]) / denominator2, + lambda_list[0] * lambda_list[2] / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1]) / denominator3, + lambda_list[0] * lambda_list[1] / denominator3] + ] + elif order == 3: + denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * ( + lambda_list[0] - lambda_list[3]) + denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * ( + lambda_list[1] - lambda_list[3]) + denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * ( + lambda_list[2] - lambda_list[3]) + denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * ( + lambda_list[3] - lambda_list[2]) + return [[1 / denominator1, + (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, + (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator1, + (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1], + + [1 / denominator2, + (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, + (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[ + 3]) / denominator2, + (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2], + + [1 / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[ + 3]) / denominator3, + (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3], + + [1 / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[ + 2]) / denominator4, + (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4] + + ] + + def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau): + assert order in [1, 2, 3, 4] + assert order == len(lambda_list), 'the length of lambda list must be equal to the order' + coefficients = [] + lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list) + for i in range(order): + coefficient = 0 + for j in range(order): + if self.predict_x0: + + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive( + order - 1 - j, interval_start, interval_end, tau) + else: + coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative( + order - 1 - j, interval_start, interval_end) + coefficients.append(coefficient) + assert len(coefficients) == order, 'the length of coefficients does not match the order' + return coefficients + + def stochastic_adams_bashforth_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + noise: torch.FloatTensor, + order: int, + tau: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the SA-Predictor. + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of SA-Predictor at this timestep. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + + assert noise is not None + timestep_list = self.timestep_list + model_output_list = self.model_outputs + s0, t = self.timestep_list[-1], prev_timestep + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + gradient_part = torch.zeros_like(sample) + h = lambda_t - lambda_s0 + lambda_list = [] + + for i in range(order): + lambda_list.append(self.lambda_t[timestep_list[-(i + 1)]]) + + gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau) + + x = sample + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[ + timestep_list[-2]]) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[ + timestep_list[-2]]) + + for i in range(order): + if self.predict_x0: + + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_output_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def stochastic_adams_moulton_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + last_noise: torch.FloatTensor, + this_sample: torch.FloatTensor, + order: int, + tau: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the SA-Corrector. + + Args: + this_model_output (`torch.FloatTensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.FloatTensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.FloatTensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The order of SA-Corrector at this step. + + Returns: + `torch.FloatTensor`: + The corrected sample tensor at the current timestep. + """ + + assert last_noise is not None + timestep_list = self.timestep_list + model_output_list = self.model_outputs + s0, t = self.timestep_list[-1], this_timestep + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + gradient_part = torch.zeros_like(this_sample) + h = lambda_t - lambda_s0 + t_list = timestep_list + [this_timestep] + lambda_list = [] + for i in range(order): + lambda_list.append(self.lambda_t[t_list[-(i + 1)]]) + + model_prev_list = model_output_list + [this_model_output] + + gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau) + + x = last_sample + + if self.predict_x0: + if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. + # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * ( + h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ( + (1 + tau ** 2) ** 2 * h)) + + for i in range(order): + if self.predict_x0: + gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[ + i] * model_prev_list[-(i + 1)] + else: + gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)] + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * last_noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise + + if self.predict_x0: + x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the SA-Solver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = ( + step_index > 0 and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + + if use_corrector: + current_tau = self.tau_func(self.timestep_list[-1]) + sample = self.stochastic_adams_moulton_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + last_noise=self.last_noise, + this_sample=sample, + order=self.this_corrector_order, + tau=current_tau, + ) + + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + + if self.config.lower_order_final: + this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - step_index) + this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - step_index + 1) + else: + this_predictor_order = self.config.predictor_order + this_corrector_order = self.config.corrector_order + + self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1) # warmup for multistep + self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2) # warmup for multistep + assert self.this_predictor_order > 0 + assert self.this_corrector_order > 0 + + self.last_sample = sample + self.last_noise = noise + + current_tau = self.tau_func(self.timestep_list[-1]) + prev_sample = self.stochastic_adams_bashforth_update( + model_output=model_output_convert, + prev_timestep=prev_timestep, + sample=sample, + noise=noise, + order=self.this_predictor_order, + tau=current_tau, + ) + + if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1): + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/src/diffusion/utils/__init__.py b/src/diffusion/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/diffusion/utils/data.py b/src/diffusion/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..27fae41e9e759e446b3f5f1fad13abd0520aa396 --- /dev/null +++ b/src/diffusion/utils/data.py @@ -0,0 +1,134 @@ + +ASPECT_RATIO_2880 = { + '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0], + '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0], + '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0], + '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0], + '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0], + '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0], + '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0], + '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0], + '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0], + '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0] +} + +ASPECT_RATIO_2048 = { + '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0], + '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], + '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], + '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], + '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], + '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], + '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], + '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], + '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0], + '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0] +} + +ASPECT_RATIO_1024 = { + '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], + '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], + '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], + '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], + '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], + '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], + '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], + '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], + '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], + '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], +} + +ASPECT_RATIO_512 = { + '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], + '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] + } + +ASPECT_RATIO_256 = { + '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0], + '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], + '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], + '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], + '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], + '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], + '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], + '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], + '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0], + '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0] +} + +ASPECT_RATIO_256_TEST = { + '0.25': [128.0, 512.0], '0.28': [128.0, 464.0], + '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], + '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], + '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], + '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], + '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], + '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], + '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], + '2.5': [400.0, 160.0], '3.0': [432.0, 144.0], + '4.0': [512.0, 128.0] +} + +ASPECT_RATIO_512_TEST = { + '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], '3.0': [864.0, 288.0], + '4.0': [1024.0, 256.0] + } + +ASPECT_RATIO_1024_TEST = { + '0.25': [512., 2048.], '0.28': [512., 1856.], + '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], + '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], + '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], + '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], + '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], + '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], + '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], + '2.5': [1600., 640.], '3.0': [1728., 576.], + '4.0': [2048., 512.], +} + +ASPECT_RATIO_2048_TEST = { + '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], + '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], + '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], + '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], + '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], + '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], + '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], + '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], + '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0], + '4.0': [4096.0, 1024.0] +} + +ASPECT_RATIO_2880_TEST = { + '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0], + '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0], + '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0], + '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0], + '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0], + '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0], + '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0], + '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0], + '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0], + '4.0': [8192.0, 2048.0], +} + +def get_chunks(lst, n): + for i in range(0, len(lst), n): + yield lst[i:i + n] diff --git a/src/diffusion/utils/misc.py b/src/diffusion/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c4efea5d9e1ff39331ea1f717c3227406a0c68ea --- /dev/null +++ b/src/diffusion/utils/misc.py @@ -0,0 +1,358 @@ +import collections +import os +import random + +import numpy as np +import torch +import torch.distributed as dist +from mmcv import Config +from mmcv.runner import get_dist_info + + +os.environ["MOX_SILENT_MODE"] = "1" # mute moxing log + + +def read_config(file): + # solve config loading conflict when multi-processes + import time + while True: + config = Config.fromfile(file) + if len(config) == 0: + time.sleep(0.1) + continue + break + return config + + +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to prevent some potential bugs. + + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2 ** 31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +class DebugUnderflowOverflow: + """ + This debug class helps detect and understand where the model starts getting very large or very small, and more + importantly `nan` or `inf` weight and activation elements. + There are 2 working modes: + 1. Underflow/overflow detection (default) + 2. Specific batch absolute min/max tracing without detection + Mode 1: Underflow/overflow detection + To activate the underflow/overflow detection, initialize the object with the model : + ```python + debug_overflow = DebugUnderflowOverflow(model) + ``` + then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or + output elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this + event, each frame reporting + 1. the fully qualified module name plus the class name whose `forward` was run + 2. the absolute min and max value of all elements for each module weights, and the inputs and output + For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16 mixed precision : + ``` + Detected inf/nan during batch_number=0 + Last 21 forward frames: + abs min abs max metadata + [...] + encoder.block.2.layer.1.DenseReluDense.wi_0 Linear + 2.17e-07 4.50e+00 weight + 1.79e-06 4.65e+00 input[0] + 2.68e-06 3.70e+01 output + encoder.block.2.layer.1.DenseReluDense.wi_1 Linear + 8.08e-07 2.66e+01 weight + 1.79e-06 4.65e+00 input[0] + 1.27e-04 2.37e+02 output + encoder.block.2.layer.1.DenseReluDense.wo Linear + 1.01e-06 6.44e+00 weight + 0.00e+00 9.74e+03 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense + 1.79e-06 4.65e+00 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.dropout Dropout + 3.18e-04 6.27e+04 input[0] + 0.00e+00 inf output + ``` + You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value + was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which + renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than + 64K, and we get an overlow. + As you can see it's the previous frames that we need to look into when the numbers start going into very large for + fp16 numbers. + The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed. + By default the last 21 frames are printed. You can change the default to adjust for your needs. For example : + ```python + debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100) + ``` + To validate that you have set up this debugging feature correctly, and you intend to use it in a training that may + take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in the next + section. + Mode 2. Specific batch absolute min/max tracing without detection + The second work mode is per-batch tracing with the underflow/overflow detection feature turned off. + Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a + given batch, and only do that for batches 1 and 3. Then you instantiate this class as : + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3]) + ``` + And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed. + This is helpful if you know that the program starts misbehaving after a certain batch number, so you can + fast-forward right to that area. + Early stopping: + You can also specify the batch number after which to stop the training, with : + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3) + ``` + This feature is mainly useful in the tracing mode, but you can use it for any mode. + **Performance**: + As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the + training down. Therefore remember to turn it off once the debugging needs have been met. + Args: + model (`nn.Module`): + The model to debug. + max_frames_to_save (`int`, *optional*, defaults to 21): + How many frames back to record + trace_batch_nums(`List[int]`, *optional*, defaults to `[]`): + Which batch numbers to trace (turns detection off) + abort_after_batch_num (`int``, *optional*): + Whether to abort after a certain batch number has finished + """ + + def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None): + self.model = model + self.trace_batch_nums = trace_batch_nums + self.abort_after_batch_num = abort_after_batch_num + + # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence + self.frames = collections.deque([], max_frames_to_save) + self.frame = [] + self.batch_number = 0 + self.total_calls = 0 + self.detected_overflow = False + self.prefix = " " + + self.analyse_model() + + self.register_forward_hook() + + def save_frame(self, frame=None): + if frame is not None: + self.expand_frame(frame) + self.frames.append("\n".join(self.frame)) + self.frame = [] # start a new frame + + def expand_frame(self, line): + self.frame.append(line) + + def trace_frames(self): + print("\n".join(self.frames)) + self.frames = [] + + def reset_saved_frames(self): + self.frames = [] + + def dump_saved_frames(self): + print(f"\nDetected inf/nan during batch_number={self.batch_number} " + f"Last {len(self.frames)} forward frames:" + f"{'abs min':8} {'abs max':8} metadata" + f"'\n'.join(self.frames)" + f"\n\n") + self.frames = [] + + def analyse_model(self): + # extract the fully qualified module names, to be able to report at run time. e.g.: + # encoder.block.2.layer.0.SelfAttention.o + # + # for shared weights only the first shared module name will be registered + self.module_names = {m: name for name, m in self.model.named_modules()} + # self.longest_module_name = max(len(v) for v in self.module_names.values()) + + def analyse_variable(self, var, ctx): + if torch.is_tensor(var): + self.expand_frame(self.get_abs_min_max(var, ctx)) + if self.detect_overflow(var, ctx): + self.detected_overflow = True + elif var is None: + self.expand_frame(f"{'None':>17} {ctx}") + else: + self.expand_frame(f"{'not a tensor':>17} {ctx}") + + def batch_start_frame(self): + self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***") + self.expand_frame(f"{'abs min':8} {'abs max':8} metadata") + + def batch_end_frame(self): + self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n") + + def create_frame(self, module, input, output): + self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}") + + # params + for name, p in module.named_parameters(recurse=False): + self.analyse_variable(p, name) + + # inputs + if isinstance(input, tuple): + for i, x in enumerate(input): + self.analyse_variable(x, f"input[{i}]") + else: + self.analyse_variable(input, "input") + + # outputs + if isinstance(output, tuple): + for i, x in enumerate(output): + # possibly a tuple of tuples + if isinstance(x, tuple): + for j, y in enumerate(x): + self.analyse_variable(y, f"output[{i}][{j}]") + else: + self.analyse_variable(x, f"output[{i}]") + else: + self.analyse_variable(output, "output") + + self.save_frame() + + def register_forward_hook(self): + self.model.apply(self._register_forward_hook) + + def _register_forward_hook(self, module): + module.register_forward_hook(self.forward_hook) + + def forward_hook(self, module, input, output): + # - input is a tuple of packed inputs (could be non-Tensors) + # - output could be a Tensor or a tuple of Tensors and non-Tensors + + last_frame_of_batch = False + + trace_mode = True if self.batch_number in self.trace_batch_nums else False + if trace_mode: + self.reset_saved_frames() + + if self.total_calls == 0: + self.batch_start_frame() + self.total_calls += 1 + + # count batch numbers - the very first forward hook of the batch will be called when the + # batch completes - i.e. it gets called very last - we know this batch has finished + if module == self.model: + self.batch_number += 1 + last_frame_of_batch = True + + self.create_frame(module, input, output) + + # if last_frame_of_batch: + # self.batch_end_frame() + + if trace_mode: + self.trace_frames() + + if last_frame_of_batch: + self.batch_start_frame() + + if self.detected_overflow and not trace_mode: + self.dump_saved_frames() + + # now we can abort, as it's pointless to continue running + raise ValueError( + "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. " + "Please scroll up above this traceback to see the activation values prior to this event." + ) + + # abort after certain batch if requested to do so + if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num: + raise ValueError( + f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg" + ) + + @staticmethod + def get_abs_min_max(var, ctx): + abs_var = var.abs() + return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}" + + @staticmethod + def detect_overflow(var, ctx): + """ + Report whether the tensor contains any `nan` or `inf` entries. + This is useful for detecting overflows/underflows and best to call right after the function that did some math that + modified the tensor in question. + This function contains a few other helper features that you can enable and tweak directly if you want to track + various other things. + Args: + var: the tensor variable to check + ctx: the message to print as a context + Return: + `True` if `inf` or `nan` was detected, `False` otherwise + """ + detected = False + if torch.isnan(var).any().item(): + detected = True + print(f"{ctx} has nans") + if torch.isinf(var).any().item(): + detected = True + print(f"{ctx} has infs") + if var.dtype == torch.float32 and torch.ge(var.abs(), 65535).any().item(): + detected = True + print(f"{ctx} has overflow values {var.abs().max().item()}.") + # if needed to monitor large elements can enable the following + if 0: # and detected: + n100 = var[torch.ge(var.abs(), 100)] + if n100.numel() > 0: + print(f"{ctx}: n100={n100.numel()}") + n1000 = var[torch.ge(var.abs(), 1000)] + if n1000.numel() > 0: + print(f"{ctx}: n1000={n1000.numel()}") + n10000 = var[torch.ge(var.abs(), 10000)] + if n10000.numel() > 0: + print(f"{ctx}: n10000={n10000.numel()}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})") + + return detected diff --git a/src/enhancer.py b/src/enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2efa4dee0c15c88df063c218fe4be6f80a3a0c --- /dev/null +++ b/src/enhancer.py @@ -0,0 +1,139 @@ +import math +import torch +import torch.nn.functional as F +from diffusers.models import AutoencoderKL +from transformers import T5EncoderModel, T5Tokenizer +from safetensors.torch import load_model + +from diffusion import IDDPM, DPMS +from diffusion.utils.misc import read_config +from diffusion.model.nets import PixArtMS_XL_2, ControlPixArtMSMVHalfWithEncoder +from diffusion.utils.data import ASPECT_RATIO_512_TEST +from utils.camera import get_camera_poses +from utils.postprocess import adaptive_instance_normalization, wavelet_reconstruction + + +class Enhancer: + def __init__(self, model_path, config_path): + self.config = read_config(config_path) + + self.image_size = self.config.image_size + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.weight_dtype = torch.float16 + + self._load_model(model_path, self.config.pipeline_load_from) + + def _load_model(self, model_path, pipeline_load_from): + self.tokenizer = T5Tokenizer.from_pretrained(pipeline_load_from, subfolder="tokenizer") + self.text_encoder = T5EncoderModel.from_pretrained(pipeline_load_from, subfolder="text_encoder", torch_dtype=self.weight_dtype).to(self.device) + + self.vae = AutoencoderKL.from_pretrained(pipeline_load_from, subfolder="vae", torch_dtype=self.weight_dtype).to(self.device) + del self.vae.encoder # we do not use vae encoder + + # only support fixed latent size currently + latent_size = self.image_size // 8 + lewei_scale = {512: 1, 1024: 2} + model_kwargs = { + "model_max_length": self.config.model_max_length, + "qk_norm": self.config.qk_norm, + "kv_compress_config": self.config.kv_compress_config if self.config.kv_compress else None, + "micro_condition": self.config.micro_condition, + "use_crossview_module": getattr(self.config, 'use_crossview_module', False), + } + model = PixArtMS_XL_2(input_size=latent_size, pe_interpolation=lewei_scale[self.image_size], **model_kwargs).to(self.device) + model = ControlPixArtMSMVHalfWithEncoder(model).to(self.weight_dtype).to(self.device) + load_model(model, model_path) + model.eval() + self.model = model + + self.noise_maker = IDDPM(str(self.config.train_sampling_steps)) + + @torch.no_grad() + def _encode_prompt(self, text_prompt, n_views): + txt_tokens = self.tokenizer( + text_prompt, + max_length=self.config.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt" + ).to(self.device) + caption_embs = self.text_encoder( + txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask)[0][:, None] + emb_masks = txt_tokens.attention_mask + + caption_embs = caption_embs.repeat_interleave(n_views, dim=0).to(self.weight_dtype) + emb_masks = emb_masks.repeat_interleave(n_views, dim=0).to(self.weight_dtype) + + return caption_embs, emb_masks + + @torch.no_grad() + def inference(self, mv_imgs, c2ws, prompt="", fov=math.radians(49.1), noise_level=120, cfg_scale=4.5, sample_steps=20, color_shift=None): + mv_imgs = F.interpolate(mv_imgs, size=(512, 512), mode='bilinear', align_corners=False) + + n_views = mv_imgs.shape[0] + # pixle-sigma input tensor range is [-1, 1] + mv_imgs = 2.*mv_imgs - 1. + + originial_mv_imgs = mv_imgs.clone().to(self.device) + if noise_level == 0: + noise_level = torch.zeros((n_views,)).long().to(self.device) + else: + noise_level = noise_level * torch.ones((n_views,)).long().to(self.device) + mv_imgs = self.noise_maker.q_sample(mv_imgs.to(self.device), noise_level-1) + + cur_camera_pose, epipolar_constrains, cam_distances = get_camera_poses(c2ws=c2ws, fov=fov, h=mv_imgs.size(-2), w=mv_imgs.size(-1)) + epipolar_constrains = epipolar_constrains.to(self.device) + cam_distances = cam_distances.to(self.weight_dtype).to(self.device) + + caption_embs, emb_masks = self._encode_prompt(prompt, n_views) + null_y = self.model.y_embedder.y_embedding[None].repeat(n_views, 1, 1)[:, None] + + latent_size_h, latent_size_w = mv_imgs.size(-2) // 8, mv_imgs.size(-1) // 8 + z = torch.randn(n_views, 4, latent_size_h, latent_size_w, device=self.device) + z_lq = self.model.encode( + mv_imgs.to(self.weight_dtype).to(self.device), + cur_camera_pose.to(self.weight_dtype).to(self.device), + n_views=n_views, + ) + + model_kwargs = dict( + c=torch.cat([z_lq] * 2), + data_info={}, + mask=emb_masks, + noise_level=torch.cat([noise_level] * 2), + epipolar_constrains=torch.cat([epipolar_constrains] * 2), + cam_distances=torch.cat([cam_distances] * 2), + n_views=n_views, + ) + dpm_solver = DPMS( + self.model.forward_with_dpmsolver, + condition=caption_embs, + uncondition=null_y, + cfg_scale=cfg_scale, + model_kwargs=model_kwargs + ) + samples = dpm_solver.sample( + z, + steps=sample_steps, + order=2, + skip_type="time_uniform", + method="multistep", + disable_progress_ui=False, + ) + + samples = samples.to(self.weight_dtype) + + output_mv_imgs = self.vae.decode(samples / self.vae.config.scaling_factor).sample + + if color_shift == "adain": + for i, output_mv_img in enumerate(output_mv_imgs): + output_mv_imgs[i] = adaptive_instance_normalization(output_mv_img.unsqueeze(0), originial_mv_imgs[i:i+1]).squeeze(0) + elif color_shift == "wavelet": + for i, output_mv_img in enumerate(output_mv_imgs): + output_mv_imgs[i] = wavelet_reconstruction(output_mv_img.unsqueeze(0), originial_mv_imgs[i:i+1]).squeeze(0) + + output_mv_imgs = torch.clamp((output_mv_imgs + 1.) / 2., 0, 1) + + torch.cuda.empty_cache() + return output_mv_imgs \ No newline at end of file diff --git a/src/utils/camera.py b/src/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea8db689038740b927cf390903b6b6fbf467f98 --- /dev/null +++ b/src/utils/camera.py @@ -0,0 +1,338 @@ +import torch +from kornia.core import Tensor, concatenate + +import torch +import math +import numpy as np +from torch import nn +from kiui.cam import orbit_camera + + +# gaussian splatting utils.graphics_utils +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) + + +# gaussian splatting scene.camera +class Camera(nn.Module): + def __init__(self, R, T, FoVx, FoVy, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0 + ): + super(Camera, self).__init__() + + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1) + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1) + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + + +# gaussian splatting utils.camera_utils +def loadCam(c2w, fovx, image_height=512, image_width=512): + # load_camera + w2c = np.linalg.inv(c2w) + + R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code + T = w2c[:3, 3] + + fovy = focal2fov(fov2focal(fovx, image_width), image_height) + FovY = fovy + FovX = fovx + + return Camera(R=R, T=T, + FoVx=FovX, FoVy=FovY) + + +# epipolar calculation related +@torch.no_grad() +def fundamental_from_projections(P1: Tensor, P2: Tensor) -> Tensor: + r"""Get the Fundamental matrix from Projection matrices. + + Args: + P1: The projection matrix from first camera with shape :math:`(*, 3, 4)`. + P2: The projection matrix from second camera with shape :math:`(*, 3, 4)`. + + Returns: + The fundamental matrix with shape :math:`(*, 3, 3)`. + """ + if not (len(P1.shape) >= 2 and P1.shape[-2:] == (3, 4)): + raise AssertionError(P1.shape) + if not (len(P2.shape) >= 2 and P2.shape[-2:] == (3, 4)): + raise AssertionError(P2.shape) + if P1.shape[:-2] != P2.shape[:-2]: + raise AssertionError + + def vstack(x: Tensor, y: Tensor) -> Tensor: + return concatenate([x, y], dim=-2) + + X1 = P1[..., 1:, :] + X2 = vstack(P1[..., 2:3, :], P1[..., 0:1, :]) + X3 = P1[..., :2, :] + + Y1 = P2[..., 1:, :] + Y2 = vstack(P2[..., 2:3, :], P2[..., 0:1, :]) + Y3 = P2[..., :2, :] + + X1Y1, X2Y1, X3Y1 = vstack(X1, Y1), vstack(X2, Y1), vstack(X3, Y1) + X1Y2, X2Y2, X3Y2 = vstack(X1, Y2), vstack(X2, Y2), vstack(X3, Y2) + X1Y3, X2Y3, X3Y3 = vstack(X1, Y3), vstack(X2, Y3), vstack(X3, Y3) + + F_vec = torch.cat( + [ + X1Y1.det().reshape(-1, 1), + X2Y1.det().reshape(-1, 1), + X3Y1.det().reshape(-1, 1), + X1Y2.det().reshape(-1, 1), + X2Y2.det().reshape(-1, 1), + X3Y2.det().reshape(-1, 1), + X1Y3.det().reshape(-1, 1), + X2Y3.det().reshape(-1, 1), + X3Y3.det().reshape(-1, 1), + ], + dim=1, + ) + + return F_vec.view(*P1.shape[:-2], 3, 3) + + +def get_fundamental_matrix_with_H(cam1, cam2, current_H, current_W): + NDC_2_pixel = torch.tensor([[current_W / 2, 0, current_W / 2], [0, current_H / 2, current_H / 2], [0, 0, 1]]) + # NDC_2_pixel_inversed = torch.inverse(NDC_2_pixel) + NDC_2_pixel = NDC_2_pixel.float() + cam_1_tranformation = cam1.full_proj_transform[:, [0,1,3]].T.float() + cam_2_tranformation = cam2.full_proj_transform[:, [0,1,3]].T.float() + cam_1_pixel = NDC_2_pixel@cam_1_tranformation + cam_2_pixel = NDC_2_pixel@cam_2_tranformation + + # print(NDC_2_pixel.dtype, cam_1_tranformation.dtype, cam_2_tranformation.dtype, cam_1_pixel.dtype, cam_2_pixel.dtype) + + cam_1_pixel = cam_1_pixel.float() + cam_2_pixel = cam_2_pixel.float() + # print("cam_1", cam_1_pixel.dtype, cam_1_pixel.shape) + # print("cam_2", cam_2_pixel.dtype, cam_2_pixel.shape) + # print(NDC_2_pixel@cam_1_tranformation, NDC_2_pixel@cam_2_tranformation) + return fundamental_from_projections(cam_1_pixel, cam_2_pixel) + + +def point_to_line_dist(points, lines): + """ + Calculate the distance from points to lines in 2D. + points: Nx3 + lines: Mx3 + + return distance: NxM + """ + numerator = torch.abs(lines @ points.T) + denominator = torch.linalg.norm(lines[:,:2], dim=1, keepdim=True) + return numerator / denominator + + +def compute_epipolar_constrains(cam1, cam2, current_H=64, current_W=64): + n_frames = 1 + # sequence_length = current_W * current_H + fundamental_matrix_1 = [] + + fundamental_matrix_1.append(get_fundamental_matrix_with_H(cam1, cam2, current_H, current_W)) + fundamental_matrix_1 = torch.stack(fundamental_matrix_1, dim=0) + + x = torch.arange(current_W) + y = torch.arange(current_H) + x, y = torch.meshgrid(x, y, indexing='xy') + x = x.reshape(-1) + y = y.reshape(-1) + heto_cam2 = torch.stack([x, y, torch.ones(size=(len(x),))], dim=1).view(-1, 3) + heto_cam1 = torch.stack([x, y, torch.ones(size=(len(x),))], dim=1).view(-1, 3) + # epipolar_line: n_frames X seq_len, 3 + line1 = (heto_cam2.unsqueeze(0).repeat(n_frames, 1, 1) @ fundamental_matrix_1).view(-1, 3) + + distance1 = point_to_line_dist(heto_cam1, line1) + + idx1_epipolar = distance1 > 1 # sequence_length x sequence_lengths + + return idx1_epipolar + + +def compute_camera_distance(cams, key_cams): + cam_centers = [cam.camera_center for cam in cams] + key_cam_centers = [cam.camera_center for cam in key_cams] + cam_centers = torch.stack(cam_centers) + key_cam_centers = torch.stack(key_cam_centers) + cam_distance = torch.cdist(cam_centers, key_cam_centers) + + return cam_distance + + +def get_intri(target_im=None, h=None, w=None, normalize=False): + if target_im is None: + assert (h is not None and w is not None) + else: + h, w = target_im.shape[:2] + + fx = fy = 1422.222 + res_raw = 1024 + f_x = f_y = fx * h / res_raw + K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) + if normalize: # center is [0.5, 0.5], eg3d renderer tradition + K[:2] /= h + return K + + +def normalize_camera(c, c_frame0): + B = c.shape[0] + camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 + canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4) + inverse_canonical_pose = np.linalg.inv(canonical_camera_poses) + inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0) + + cam_radius = np.linalg.norm( + c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], + axis=-1, + keepdims=False) # since g-buffer adopts dynamic radius here. + + frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) + frame1_fixed_pos[:, 2, -1] = -cam_radius + + transform = frame1_fixed_pos @ inverse_canonical_pose + + new_camera_poses = np.repeat( + transform, 1, axis=0 + ) @ camera_poses # [v, 4, 4]. np.repeat() is th.repeat_interleave() + + c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], + axis=-1) + + return c + + +def gen_rays(c2w, intrinsics, h, w): + # Generate rays + yy, xx = torch.meshgrid( + torch.arange(h, dtype=torch.float32) + 0.5, + torch.arange(w, dtype=torch.float32) + 0.5, + indexing='ij') + + # normalize to 0-1 pixel range + yy = yy / h + xx = xx / w + + cx, cy, fx, fy = intrinsics[2], intrinsics[ + 5], intrinsics[0], intrinsics[4] + + xx = (xx - cx) / fx + yy = (yy - cy) / fy + zz = torch.ones_like(xx) + dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention + dirs /= torch.norm(dirs, dim=-1, keepdim=True) + dirs = dirs.reshape(-1, 3, 1) + del xx, yy, zz + + dirs = (c2w[None, :3, :3] @ dirs)[..., 0] + + origins = c2w[None, :3, 3].expand(h * w, -1).contiguous() + origins = origins.view(h, w, 3) + dirs = dirs.view(h, w, 3) + + return origins, dirs + + +def get_c2ws(elevations, amuziths, camera_radius=1.5): + c2ws = np.stack([ + orbit_camera(elevation, amuzith, radius=camera_radius) for elevation, amuzith in zip(elevations, amuziths) + ], axis=0) + + # change kiui opengl camera system to our camera system + c2ws[:, :3, 1:3] *= -1 + c2ws[:, [0, 1, 2], :] = c2ws[:, [2, 0, 1], :] + c2ws = c2ws.reshape(-1, 16) + + return c2ws + + +def get_camera_poses(c2ws, fov, h, w, intrinsics=None): + if intrinsics is None: + intrinsics = get_intri(h=64, w=64, normalize=True).reshape(9) + + c2ws = normalize_camera(c2ws, c2ws[0:1]) + + rays_pluckers = [] + c2ws = c2ws.reshape((-1, 4, 4)) + c2ws = torch.from_numpy(c2ws).float() + + gs_cams = [] + for i, c2w in enumerate(c2ws): + gs_cams.append(loadCam(c2w.numpy(), fov, h, w)) + rays_o, rays_d = gen_rays(c2w, intrinsics, h, w) + rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], + dim=-1) # [h, w, 6] + rays_pluckers.append(rays_plucker.permute(2, 0, 1)) # [6, h, w] + + n_views = len(gs_cams) + epipolar_constrains = [] + cam_distances = [] + for i in range(n_views): + cur_epipolar_constrains = [] + kv_idxs = [(i-1)%n_views, (i+1)%n_views] + for kv_idx in kv_idxs: + # False means that the position is on the epipolar line + cam_epipolar_constrain = compute_epipolar_constrains(gs_cams[kv_idx], gs_cams[i], current_H=h//16, current_W=w//16) + cur_epipolar_constrains.append(cam_epipolar_constrain) + + cam_distances.append(compute_camera_distance([gs_cams[i]], [gs_cams[kv_idxs[0]], gs_cams[kv_idxs[1]]])) # 1, 2 + epipolar_constrains.append(torch.stack(cur_epipolar_constrains, dim=0)) + + rays_pluckers = torch.stack(rays_pluckers) # [v, 6, h, w] + cam_distances = torch.cat(cam_distances, dim=0) # [v, 2] + epipolar_constrains = torch.stack(epipolar_constrains, dim=0) # [v, 2, 1024, 1024] + + return rays_pluckers, epipolar_constrains, cam_distances \ No newline at end of file diff --git a/src/utils/postprocess.py b/src/utils/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed718c71cffcd8d93c301dbe7e7a7d6b80579d8 --- /dev/null +++ b/src/utils/postprocess.py @@ -0,0 +1,84 @@ +import torch +from torch.nn import functional as F + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def wavelet_blur(image, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + + +def wavelet_decomposition(image, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + + +def wavelet_reconstruction(content_feat, style_feat): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq \ No newline at end of file