import torch
from torchvision.utils import make_grid
import math
from PIL import Image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
import gradio as gr
from imagenet_class_data import IMAGENET_1K_CLASSES
from download import find_model
from models import DiT_XL_2


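# Build a DiT-XL/2 model at the requested resolution and load its pretrained
# ImageNet checkpoint. The latent grid is image_size // 8 on each side because the
# Stable Diffusion VAE used for decoding downsamples images by a factor of 8.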
def load_model(image_size=256):
    assert image_size in [256, 512]
    latent_size = image_size // 8
    model = DiT_XL_2(input_size=latent_size).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    model.eval()
    return model


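# Global inference setup: gradients are disabled for the whole app, the 256x256 model
# and the ft-MSE VAE are loaded up front, and the 512x512 checkpoint is pre-fetched so
# switching resolutions later only has to rebuild the model.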
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
find_model("DiT-XL-2-512x512.pt")
model = load_model(image_size=256)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
current_image_size = 256
current_vae_model = "stabilityai/sd-vae-ft-mse"


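# Generate a single sample of the requested ImageNet class. The DiT model and VAE
# decoder are swapped lazily when the user changes the resolution or VAE choice, and
# sampling runs with classifier-free guidance by batching conditional and
# unconditional inputs together.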
def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, seed):
    n = 1
    image_size = int(image_size.split("x")[0])
    global current_image_size
    if image_size != current_image_size:
        global model
        model = model.to("cpu")
        del model
        if device == "cuda":
            torch.cuda.empty_cache()
        model = load_model(image_size=image_size)
        current_image_size = image_size

    global current_vae_model
    if vae_model != current_vae_model:
        global vae
        if device == "cuda":
            vae.to("cpu")
        del vae
        vae = AutoencoderKL.from_pretrained(vae_model).to(device)

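    # Fix the RNG seed for reproducibility and build a diffusion process with the
    # requested number of sampling steps.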
    torch.manual_seed(seed)
    diffusion = create_diffusion(str(num_sampling_steps))

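    # Draw Gaussian noise in the VAE latent space (4 channels, image_size // 8 per side)
    # and the conditioning class label.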
    latent_size = image_size // 8
    z = torch.randn(n, 4, latent_size, latent_size, device=device)
    y = torch.tensor([class_label] * n, device=device)

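    # Set up classifier-free guidance: duplicate the noise and append the null class
    # (index 1000) so forward_with_cfg can blend conditional and unconditional
    # predictions with weight cfg_scale.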
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([1000] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)

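    # Run the reverse-diffusion sampling loop, keep only the conditional half of the
    # CFG batch, and decode the latents to pixels with the VAE (0.18215 is the
    # Stable Diffusion latent scale factor).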
    samples = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True,
        device=device
    )
    samples, _ = samples.chunk(2, dim=0)
    samples = vae.decode(samples / 0.18215).sample

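    # Map decoded images from [-1, 1] to uint8 HWC arrays and wrap them as PIL images.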
    samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
    samples = [Image.fromarray(sample) for sample in samples]
    return samples


description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with
transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''

duplicate = '''Skip the queue by duplicating this space and upgrading to GPU in settings
<a href="https://huggingface.co./spaces/wpeebles/DiT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>'''

project_links = '''
<p style="text-align: center">
<a href="https://www.wpeebles.com/DiT.html">Project Page</a> ·
<a href="http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb">Colab</a> ·
<a href="http://arxiv.org/abs/2212.09748">Paper</a> ·
<a href="https://github.com/facebookresearch/DiT">GitHub</a></p>'''

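# Cached example inputs, in the same order as the Gradio inputs:
# (resolution, VAE, class name, cfg scale, sampling steps, seed).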
examples = [
    ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 1000],
    ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 1],
    ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 1],
    ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 7],
    ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 0],
    ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200, 1],
    ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 3],
    ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 2],
]

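# Gradio UI: model/VAE selectors, ImageNet class picker, guidance and step sliders,
# a seed box, and a gallery for the generated images.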
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>Scalable Diffusion Models with Transformers (DiT)</h1>")
    gr.Markdown(project_links)
    gr.Markdown(description)
    gr.Markdown(duplicate)

    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image_size = gr.inputs.Radio(choices=["256x256", "512x512"], default="256x256", label='DiT Model Resolution')
                        vae_model = gr.inputs.Radio(choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
                                                    default="stabilityai/sd-vae-ft-mse", label='VAE Decoder')
                    with gr.Row():
                        i1k_class = gr.inputs.Dropdown(
                            list(IMAGENET_1K_CLASSES.values()),
                            default='golden retriever',
                            type="index", label='ImageNet-1K Class'
                        )
                        cfg_scale = gr.inputs.Slider(minimum=1, maximum=25, step=0.1, default=4.0, label='Classifier-free Guidance Scale')
                        steps = gr.inputs.Slider(minimum=4, maximum=1000, step=1, default=75, label='Sampling Steps')

                    seed = gr.inputs.Number(default=0, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
                    button.click(generate, inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, seed], outputs=[output])
            with gr.Row():
                ex = gr.Examples(examples=examples, fn=generate,
                                 inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, seed],
                                 outputs=[output],
                                 cache_examples=True)

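# Queue incoming requests so concurrent users are served in order, then launch the app.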
demo.queue()
demo.launch()