import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import gradio as gr
from model import (
    UNet,
    VQVAE,
    LinearNoiseScheduler,
    get_tokenizer_and_model,
    get_text_representation,
    dataset_params,
    diffusion_params,
    ldm_params,
    autoencoder_params,
    train_params,
)
from huggingface_hub import hf_hub_download
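# `spaces` provides the @spaces.GPU decorator used below to request a GPU per call on Hugging Face ZeroGPU Spaces.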
import spaces
import json


print("Gradio version:", gr.__version__)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Currently running on {device}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Download config and checkpoint files from HF Hub
config_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="config.json"
)
with open(config_path, "r") as f:
    config = json.load(f)

ldm_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/ddpm_ckpt_class_cond.pth"
)
vae_ckpt_path = hf_hub_download(
    repo_id="RishabA/celeba-cond-ddpm", filename="celebhq/vqvae_autoencoder_ckpt.pth"
)

# Instantiate and load the models
unet = UNet(config["autoencoder_params"]["z_channels"], config["ldm_params"]).to(device)
vae = VQVAE(
    config["dataset_params"]["image_channels"], config["autoencoder_params"]
).to(device)

unet_state = torch.load(ldm_ckpt_path, map_location=device)
unet.load_state_dict(unet_state["model_state_dict"])
print(f"Loaded UNet checkpoint from epoch {unet_state['epoch']}")

vae_state = torch.load(vae_ckpt_path, map_location=device)
vae.load_state_dict(vae_state["model_state_dict"])

unet.eval()
vae.eval()

print("Model and checkpoints loaded successfully!")


@spaces.GPU
def sample_ddpm_inference(text_prompt):
    """
    Given a text prompt and (optionally) an image condition (as a PIL image),
    sample from the diffusion model and return a generated image (PIL image).
    """

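    # Fixed demo settings: no mask image is wired into the UI, a classifier-free
    # guidance scale of 2.0, and a decoded preview is yielded every 4 timesteps.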
    mask_image_pil = None
    guidance_scale = 2.0
    image_display_rate = 4

    # Create noise scheduler
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_params["num_timesteps"],
        beta_start=diffusion_params["beta_start"],
        beta_end=diffusion_params["beta_end"],
    )
    # Get conditioning config from ldm_params
    condition_config = ldm_params.get("condition_config", None)
    condition_types = (
        condition_config.get("condition_types", [])
        if condition_config is not None
        else []
    )

    # Load text tokenizer/model for conditioning
    text_model_type = condition_config["text_condition_config"]["text_embed_model"]
    text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)

    # Get empty text representation for classifier-free guidance
    empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)

    # Get text representation of the input prompt
    text_prompt_embed = get_text_representation(
        [text_prompt], text_tokenizer, text_model, device
    )

    # Prepare image conditioning:
    # If the user uploaded a mask image (should be a PIL image), convert it; otherwise, use zeros.
    if "image" in condition_types:
        if mask_image_pil is not None:
            mask_transform = transforms.Compose(
                [
                    transforms.Resize(
                        (
                            ldm_params["condition_config"]["image_condition_config"][
                                "image_condition_h"
                            ],
                            ldm_params["condition_config"]["image_condition_config"][
                                "image_condition_w"
                            ],
                        )
                    ),
                    transforms.ToTensor(),
                ]
            )
            mask_tensor = (
                mask_transform(mask_image_pil).unsqueeze(0).to(device)
            )  # (1, channels, H, W)
        else:
            # Create a zero mask with the required number of channels (e.g. 18)
            ic = ldm_params["condition_config"]["image_condition_config"][
                "image_condition_input_channels"
            ]
            H = ldm_params["condition_config"]["image_condition_config"][
                "image_condition_h"
            ]
            W = ldm_params["condition_config"]["image_condition_config"][
                "image_condition_w"
            ]
            mask_tensor = torch.zeros((1, ic, H, W), device=device)
    else:
        mask_tensor = None

    # Build conditioning dictionaries for classifier-free guidance:
    # For unconditional, we use empty text and zero mask.
    uncond_input = {}
    cond_input = {}
    if "text" in condition_types:
        uncond_input["text"] = empty_text_embed
        cond_input["text"] = text_prompt_embed
    if "image" in condition_types:
        # Use zeros for unconditioning, and the provided mask for conditioning.
        uncond_input["image"] = torch.zeros_like(mask_tensor)
        cond_input["image"] = mask_tensor

    # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
    # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
    latent_size = dataset_params["image_size"] // (
        2 ** sum(autoencoder_params["down_sample"])
    )
    batch = train_params["num_samples"]
    z_channels = autoencoder_params["z_channels"]

    # Sample initial latent noise
    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)

    # Sampling loop (reverse diffusion)
    T = diffusion_params["num_timesteps"]
    for i in reversed(range(T)):
        t = torch.full((batch,), i, dtype=torch.long, device=device)

        with torch.no_grad():
            # Get conditional noise prediction
            noise_pred_cond = unet(xt, t, cond_input)
            if guidance_scale > 1:
                noise_pred_uncond = unet(xt, t, uncond_input)
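                # Classifier-free guidance: extrapolate from the unconditional
                # prediction toward the text-conditioned prediction.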
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = noise_pred_cond
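            # Take one reverse-diffusion step on the latent (the second return value is unused here).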
            xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)

            if i % image_display_rate == 0 or i == 0:
                # Decode current latent into image
                generated = vae.decode(xt)

                generated = torch.clamp(generated, -1, 1)
                generated = (generated + 1) / 2  # scale to [0,1]
                grid = make_grid(generated, nrow=1)
                pil_img = transforms.ToPILImage()(grid.cpu())

                yield pil_img


css_str = """
.title { 
    font-size: 48px; 
    text-align: center; 
    margin-top: 20px; 
}
.description { 
    font-size: 20px; 
    text-align: center; 
    margin-bottom: 40px; 
}
"""

with gr.Blocks(css=css_str) as demo:
    gr.Markdown("<div class='title'>Conditioned Latent Diffusion with CelebA</div>")
    gr.Markdown(
        "<div class='description'>Enter a text prompt; the generated image will update as the reverse diffusion progresses.</div>"
    )

    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            lines=2,
            placeholder="E.g., 'He is a man with brown hair.'",
        )

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image", type="pil")

    generate_button.click(
        fn=sample_ddpm_inference,
        inputs=[text_input],
        outputs=[output_image],
    )

if __name__ == "__main__":
    demo.launch(share=True)