import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import gradio as gr
from model import (
    UNet,
    VQVAE,
    LinearNoiseScheduler,
    get_tokenizer_and_model,
    get_text_representation,
    dataset_params,
    diffusion_params,
    ldm_params,
    autoencoder_params,
    train_params,
)
from huggingface_hub import hf_hub_download
import spaces
import json

print("Gradio version:", gr.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Currently running on {device}")
if device.type == "cuda":
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Download config and checkpoint files from HF Hub
config_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
)
with open(config_path, "r") as f:
    config = json.load(f)

ldm_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
)
vae_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
)
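# hf_hub_download caches each file locally and returns the cached path, so
# restarts of the Space only re-download when the file changes upstream.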
# Instantiate and load the models
unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
vae = VQVAE(
    config["dataset_params"]["image_channels"], config["autoencoder_params"]
).to(device)

unet_state = torch.load(ldm_ckpt_path, map_location=device)
unet.load_state_dict(unet_state["model_state_dict"])
print(f"Loaded U-Net checkpoint from epoch {unet_state['epoch']}")

vae_state = torch.load(vae_ckpt_path, map_location=device)
vae.load_state_dict(vae_state["model_state_dict"])
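# eval() disables dropout and freezes normalization statistics for inference.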
unet.eval()
vae.eval()
print("Model and checkpoints loaded successfully!")
@spaces.GPU  # Request a ZeroGPU slot for the duration of each call
def sample_ddpm_inference(text_prompt):
    """
    Given a text prompt, sample from the diffusion model and yield the
    generated image (as a PIL image) at intervals during reverse diffusion.
    The mask-image condition is fixed to None here because the UI only
    exposes a text input; a zero mask is substituted below.
    """
    mask_image_pil = None
    guidance_scale = 2.0
    image_display_rate = 4
    # Create noise scheduler
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_params["num_timesteps"],
        beta_start=diffusion_params["beta_start"],
        beta_end=diffusion_params["beta_end"],
    )
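    # Assuming LinearNoiseScheduler follows the standard DDPM linear schedule:
    # betas are linearly spaced from beta_start to beta_end over num_timesteps,
    # alpha_bar_t = prod_{s<=t} (1 - beta_s), and the forward process is
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.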
    # Get conditioning config from ldm_params
    condition_config = ldm_params.get("condition_config", None)
    condition_types = (
        condition_config.get("condition_types", [])
        if condition_config is not None
        else []
    )
    # Load text tokenizer/model for conditioning
    text_model_type = condition_config["text_condition_config"]["text_embed_model"]
    text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
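    # Note: the tokenizer and text encoder are re-loaded on every call; caching
    # them at module scope would avoid that cost on repeated generations.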
    # Get empty text representation for classifier-free guidance
    empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)

    # Get text representation of the input prompt
    text_prompt_embed = get_text_representation(
        [text_prompt], text_tokenizer, text_model, device
    )
    # Prepare image conditioning:
    # If the user uploaded a mask image (should be a PIL image), convert it;
    # otherwise, use zeros.
    if "image" in condition_types:
        img_cond_config = ldm_params["condition_config"]["image_condition_config"]
        mask_h = img_cond_config["image_condition_h"]
        mask_w = img_cond_config["image_condition_w"]
        if mask_image_pil is not None:
            mask_transform = transforms.Compose(
                [
                    transforms.Resize((mask_h, mask_w)),
                    transforms.ToTensor(),
                ]
            )
            mask_tensor = (
                mask_transform(mask_image_pil).unsqueeze(0).to(device)
            )  # (1, channels, H, W)
        else:
            # Create a zero mask with the required number of channels (e.g. 18)
            ic = img_cond_config["image_condition_input_channels"]
            mask_tensor = torch.zeros((1, ic, mask_h, mask_w), device=device)
    else:
        mask_tensor = None
    # Build conditioning dictionaries for classifier-free guidance:
    # For unconditional, we use empty text and zero mask.
    uncond_input = {}
    cond_input = {}
    if "text" in condition_types:
        uncond_input["text"] = empty_text_embed
        cond_input["text"] = text_prompt_embed
    if "image" in condition_types:
        # Use zeros for unconditioning, and the provided mask for conditioning.
        uncond_input["image"] = torch.zeros_like(mask_tensor)
        cond_input["image"] = mask_tensor
    # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
    # For example, if image_size is 256 and there are 3 downsamplings,
    # H_lat = 256 // 8 = 32.
    latent_size = dataset_params["image_size"] // (
        2 ** sum(autoencoder_params["down_sample"])
    )
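    # (Assuming "down_sample" is a list of booleans, sum() counts the True
    # entries, i.e. the number of 2x downsampling stages in the encoder.)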
    batch = train_params["num_samples"]
    z_channels = autoencoder_params["z_channels"]

    # Sample initial latent noise
    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
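    # x_T ~ N(0, I) in latent space; the reverse process below gradually
    # denoises it into a sample from the learned data distribution.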
    # Sampling loop (reverse diffusion)
    T = diffusion_params["num_timesteps"]
    for i in reversed(range(T)):
        t = torch.full((batch,), i, dtype=torch.long, device=device)
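        # Classifier-free guidance (Ho & Salimans): extrapolate away from the
        # unconditional prediction,
        # eps = eps_uncond + s * (eps_cond - eps_uncond),
        # where s > 1 strengthens adherence to the prompt.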
        with torch.no_grad():
            # Get conditional noise prediction
            noise_pred_cond = unet(xt, t, cond_input)

            if guidance_scale > 1:
                noise_pred_uncond = unet(xt, t, uncond_input)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = noise_pred_cond

            xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)

        if i % image_display_rate == 0 or i == 0:
            # Decode current latent into an image preview. Yield outside the
            # no_grad block so the disabled-grad state does not leak to the
            # caller while the generator is suspended.
            with torch.no_grad():
                generated = vae.decode(xt)
            generated = torch.clamp(generated, -1, 1)
            generated = (generated + 1) / 2  # scale to [0, 1]
            grid = make_grid(generated, nrow=1)
            pil_img = transforms.ToPILImage()(grid.cpu())
            yield pil_img
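# Because sample_ddpm_inference is a generator, Gradio streams each yielded
# PIL image to the output component, so the preview updates as denoising runs.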
css_str = """
.title {
    font-size: 48px;
    text-align: center;
    margin-top: 20px;
}
.description {
    font-size: 20px;
    text-align: center;
    margin-bottom: 40px;
}
"""
with gr.Blocks(css=css_str) as demo:
    gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
    gr.Markdown(
        "<div class='description'>Enter a text prompt; the generated image will update as the reverse diffusion progresses.</div>"
    )
    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            lines=2,
            placeholder="E.g., 'He is a man with brown hair.'",
        )
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image", type="pil")

    generate_button.click(
        fn=sample_ddpm_inference,
        inputs=[text_input],
        outputs=[output_image],
    )
if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces, which already host the app
    # publicly, so launch with the defaults.
    demo.launch()