# app.py -- conditioned latent diffusion on CelebA-HQ (Hugging Face Space, running on ZeroGPU)
import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import gradio as gr
from model import (
    UNet,
    VQVAE,
    LinearNoiseScheduler,
    get_tokenizer_and_model,
    get_text_representation,
    dataset_params,
    diffusion_params,
    ldm_params,
    autoencoder_params,
    train_params,
)
from huggingface_hub import hf_hub_download
import spaces
import json
print("Gradio version:", gr.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Currently running on {device}")
if device.type == "cuda":
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Download config and checkpoint files from HF Hub
config_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
)
with open(config_path, "r") as f:
    config = json.load(f)

ldm_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
)
vae_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
)
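# hf_hub_download caches files under ~/.cache/huggingface/hub, so repeated
# restarts reuse the already-downloaded config and checkpoints.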
# Instantiate and load the models
unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
vae = VQVAE(
    config["dataset_params"]["image_channels"], config["autoencoder_params"]
).to(device)

unet_state = torch.load(ldm_ckpt_path, map_location=device)
unet.load_state_dict(unet_state["model_state_dict"])
print(f"Loaded UNet checkpoint from epoch {unet_state['epoch']}")

vae_state = torch.load(vae_ckpt_path, map_location=device)
vae.load_state_dict(vae_state["model_state_dict"])

unet.eval()
vae.eval()
print("Model and checkpoints loaded successfully!")
@spaces.GPU
def sample_ddpm_inference(text_prompt):
"""
Given a text prompt and (optionally) an image condition (as a PIL image),
sample from the diffusion model and return a generated image (PIL image).
"""
mask_image_pil = None
guidance_scale = 2.0
image_display_rate = 4
    # Create noise scheduler
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_params["num_timesteps"],
        beta_start=diffusion_params["beta_start"],
        beta_end=diffusion_params["beta_end"],
    )
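    # A sketch of what a linear schedule typically precomputes (assumption --
    # the actual implementation lives in model.py):
    #   betas          = torch.linspace(beta_start, beta_end, num_timesteps)
    #   alphas         = 1.0 - betas
    #   alpha_cum_prod = torch.cumprod(alphas, dim=0)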
    # Get conditioning config from ldm_params
    condition_config = ldm_params.get("condition_config", None)
    condition_types = (
        condition_config.get("condition_types", [])
        if condition_config is not None
        else []
    )

    # Load the text tokenizer/model and build both the empty-prompt embedding
    # (for classifier-free guidance) and the embedding of the input prompt.
    # Guarded so a config without text conditioning does not raise a KeyError.
    empty_text_embed = None
    text_prompt_embed = None
    if "text" in condition_types:
        text_model_type = condition_config["text_condition_config"]["text_embed_model"]
        text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
        empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
        text_prompt_embed = get_text_representation(
            [text_prompt], text_tokenizer, text_model, device
        )
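    # Classifier-free guidance uses the empty-prompt embedding as the
    # "unconditional" signal: the same UNet is queried with and without the
    # real prompt, and the two predictions are blended in the sampling loop.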
    # Prepare image conditioning: if the user had uploaded a mask image (a PIL
    # image), convert it; otherwise, fall back to an all-zero mask.
    if "image" in condition_types:
        image_cond_cfg = ldm_params["condition_config"]["image_condition_config"]
        if mask_image_pil is not None:
            mask_transform = transforms.Compose(
                [
                    transforms.Resize(
                        (
                            image_cond_cfg["image_condition_h"],
                            image_cond_cfg["image_condition_w"],
                        )
                    ),
                    transforms.ToTensor(),
                ]
            )
            mask_tensor = mask_transform(mask_image_pil).unsqueeze(0).to(device)  # (1, channels, H, W)
        else:
            # Create a zero mask with the required number of channels (e.g. 18)
            ic = image_cond_cfg["image_condition_input_channels"]
            H = image_cond_cfg["image_condition_h"]
            W = image_cond_cfg["image_condition_w"]
            mask_tensor = torch.zeros((1, ic, H, W), device=device)
    else:
        mask_tensor = None
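    # The mask channels (e.g. 18) are presumably one-hot facial-region classes
    # in the style of CelebAMask-HQ segmentation maps.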
    # Build conditioning dictionaries for classifier-free guidance:
    # the unconditional pass uses the empty text embedding and a zero mask.
    uncond_input = {}
    cond_input = {}
    if "text" in condition_types:
        uncond_input["text"] = empty_text_embed
        cond_input["text"] = text_prompt_embed
    if "image" in condition_types:
        # Zeros for the unconditional pass, the provided mask for the conditional pass.
        uncond_input["image"] = torch.zeros_like(mask_tensor)
        cond_input["image"] = mask_tensor
    # Determine the latent shape from the VQ-VAE: (batch, z_channels, H_lat, W_lat).
    # For example, with image_size 256 and 3 downsamplings, H_lat = 256 // 8 = 32.
    latent_size = dataset_params["image_size"] // (
        2 ** sum(autoencoder_params["down_sample"])
    )
    batch = train_params["num_samples"]
    z_channels = autoencoder_params["z_channels"]

    # Sample the initial latent noise x_T
    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
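    # Reverse diffusion starts from pure Gaussian noise x_T ~ N(0, I) in the
    # VQ-VAE latent space; denoising at latent_size x latent_size rather than
    # full pixel resolution is what keeps the sampling loop cheap.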
    # Sampling loop (reverse diffusion)
    T = diffusion_params["num_timesteps"]
    for i in reversed(range(T)):
        t = torch.full((batch,), i, dtype=torch.long, device=device)
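        # Classifier-free guidance blend (Ho & Salimans, 2022):
        #   noise_pred = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
        # guidance_scale > 1 pushes samples toward the text condition.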
        with torch.no_grad():
            # Get the conditional noise prediction
            noise_pred_cond = unet(xt, t, cond_input)
            if guidance_scale > 1:
                noise_pred_uncond = unet(xt, t, uncond_input)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = noise_pred_cond
            xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
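            # sample_prev_timestep presumably applies the standard DDPM update
            # (assumption -- the real code is in model.py):
            #   mean    = (xt - beta_t * noise_pred / sqrt(1 - alpha_bar_t)) / sqrt(alpha_t)
            #   xt_prev = mean + sigma_t * z,  z ~ N(0, I)  (no noise at t == 0)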
            if i % image_display_rate == 0 or i == 0:
                # Decode the current latent into an image grid and stream it out
                generated = vae.decode(xt)
                generated = torch.clamp(generated, -1, 1)
                generated = (generated + 1) / 2  # scale from [-1, 1] to [0, 1]
                grid = make_grid(generated, nrow=1)
                pil_img = transforms.ToPILImage()(grid.cpu())
                yield pil_img
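
# Because sample_ddpm_inference is a generator, Gradio streams each yielded
# PIL image to the output component, so the UI shows denoising progress live.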
css_str = """
.title {
font-size: 48px;
text-align: center;
margin-top: 20px;
}
.description {
font-size: 20px;
text-align: center;
margin-bottom: 40px;
}
"""
with gr.Blocks(css=css_str) as demo:
    gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
    gr.Markdown(
        "<div class='description'>Enter a text prompt; the generated image updates as the reverse diffusion progresses.</div>"
    )
    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            lines=2,
            placeholder="E.g., 'He is a man with brown hair.'",
        )
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image", type="pil")

    generate_button.click(
        fn=sample_ddpm_inference,
        inputs=[text_input],
        outputs=[output_image],
    )
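
# Note: depending on the Gradio version, streaming generator outputs may
# require enabling the queue (demo.queue()); recent releases enable it by default.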
if __name__ == "__main__":
    demo.launch(share=True)  # share=True only matters for local runs; Spaces ignores it