File size: 6,875 Bytes
cd29368
 
 
a5a7c8e
 
cd29368
a5a7c8e
cd29368
a5a7c8e
 
cd29368
 
 
 
a5a7c8e
2bf201c
a5a7c8e
f5fe7df
a5a7c8e
 
 
d8876a7
 
 
d840f35
 
039fd33
6cad18a
d840f35
 
 
 
a5a7c8e
 
d840f35
a5a7c8e
f5fe7df
 
a5a7c8e
 
 
 
33b6ed2
 
 
 
 
 
 
a5a7c8e
 
 
d840f35
 
 
 
 
 
 
 
 
 
 
 
 
eeca189
d840f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
039fd33
d840f35
 
 
 
 
039fd33
 
d840f35
a5a7c8e
 
 
 
 
 
 
 
d840f35
 
 
 
 
 
 
 
 
cd29368
10b347e
 
 
 
 
 
 
 
 
 
 
 
cd29368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e80b38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd29368
 
d840f35
cd29368
d840f35
 
 
 
 
 
 
8e80b38
 
 
2e1d77a
8e80b38
d840f35
 
 
 
 
8e80b38
 
 
d840f35
8e80b38
d840f35
 
 
 
 
 
 
 
 
 
cd29368
89635f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
# coding: utf-8
import gradio as gr
import random
import torch
from collections import defaultdict
from diffusers import DiffusionPipeline
from functools import partial
from itertools import zip_longest
from typing import List
from PIL import Image

# Label shown on each grid radio button; selecting it fixes that image's seed.
SELECT_LABEL = "Select as seed"

MODEL_ID = "CompVis/ldm-text2im-large-256"
STEPS = 25   # while running on CPU
ETA = 0.3
GUIDANCE_SCALE = 6

# Load the pipeline once at import time (downloads weights on first run).
ldm = DiffusionPipeline.from_pretrained(MODEL_ID)

# NOTE(review): `torch` was imported a second time here; the duplicate import
# has been removed (torch is already imported above).
print(f"cuda: {torch.cuda.is_available()}")

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    # Per-session state: 'selected' is the index of the grid image chosen as
    # seed (-1 = none selected); 'seeds' holds the RNG seed used for each of
    # the 6 grid images so a selection can be reproduced exactly.
    state = gr.Variable({
        'selected': -1,
        'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)]
    })

    def infer_seeded_image(prompt, seed):
        """Generate one image for `prompt`, reproducibly, from a fixed `seed`."""
        print(f"Prompt: {prompt}, seed: {seed}")
        generated, _ = infer_grid(prompt, n=1, seeds=[seed])
        return generated[0]

    def infer_grid(prompt, n=6, seeds=None):
        """Generate `n` images for `prompt`, seeding each one explicitly.

        Any entries in `seeds` are used in order; remaining slots get fresh
        random seeds. Returns (images, seeds) as two parallel lists so the
        caller can later reproduce any individual image.
        """
        # Fix: `seeds=[]` was a mutable default argument; use a None sentinel.
        seeds = seeds or []
        # Unfortunately we have to iterate instead of requesting all images at once,
        # because we have no way to get the intermediate generation seeds.
        result = defaultdict(list)
        for _, seed in zip_longest(range(n), seeds, fillvalue=None):
            # Slots beyond len(seeds) arrive as None: draw a fresh seed.
            seed = random.randint(0, 2**32 - 1) if seed is None else seed
            _ = torch.manual_seed(seed)
            # NOTE(review): autocast is requested for "cuda" even though the
            # STEPS comment suggests CPU execution — confirm intended device.
            with torch.autocast("cuda"):
                images = ldm(
                    [prompt],
                    num_inference_steps=STEPS,
                    eta=ETA,
                    guidance_scale=GUIDANCE_SCALE
                )["sample"]
            result["images"].append(images[0])
            result["seeds"].append(seed)
        return result["images"], result["seeds"]

    def infer(prompt, state):
        """
        Outputs:
        - Grid images (list)
        - Seeded Image (Image or None)
        - Grid Box with updated visibility
        - Seeded Box with updated visibility
        """
        grid_images = [None] * 6
        image_with_seed = None

        seed_index = state["selected"]
        if seed_index > -1:
            # A grid image was selected: regenerate from its recorded seed.
            seed = state["seeds"][seed_index]
            image_with_seed = infer_seeded_image(prompt, seed)
            visible = (False, True)
        else:
            # No selection: produce a fresh 6-image grid and remember the seeds.
            grid_images, seeds = infer_grid(prompt)
            state["seeds"] = seeds
            visible = (True, False)

        boxes = [gr.Box.update(visible=v) for v in visible]
        return grid_images + [image_with_seed] + boxes + [state]

    def update_state(selected_index: int, value, state):
        """Record which radio was selected and clear the other five radios."""
        deselected = value == ''
        # A cleared radio resets the others to None; an actual selection
        # clears the others ('') and records the selected grid index.
        others_value = None if deselected else ''
        if not deselected:
            state["selected"] = selected_index
        others = gr.Radio.update(value=others_value)
        return [others] * 5 + [state]

    def clear_seed(state):
        """Update state of Radio buttons, grid, seeded_box"""
        state["selected"] = -1
        radio_resets = [''] * 6
        box_updates = [gr.Box.update(visible=True), gr.Box.update(visible=False)]
        return radio_resets + box_updates + [state]

    def image_block():
        """Build a non-interactive image cell, rounded only on its top edge."""
        image = gr.Image(interactive=False, show_label=False)
        return image.style(
            # border = (True, True, False, True),
            rounded=(True, True, False, False),
        )

    def radio_block():
        """Build the single-choice 'Select as seed' radio shown under an image."""
        return gr.Radio(
            choices=[SELECT_LABEL], interactive=True, show_label=False,
        ).style(
            # border = (False, True, True, True),
            # rounded = (False, False, True, True)
            container=False
        )

    # Intro text explaining the select-as-seed workflow to the user.
    gr.Markdown(
        """
        <h1><center>Latent Diffusion Demo</center></h1>
        <p>Type anything to generate a few images that represent your prompt.
        Select one of the results to use as a <b>seed</b> for the next generation:
        you can try variations of your prompt starting from the same state and see how it changes.
        For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
        <i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
        If your prompts are similar, the tweaked result should also have a similar structure
        but different details or style.</p>
        """
    )
    with gr.Group():
        # Prompt row: textbox and Run button styled to join into one pill.
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    # margin=False,
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        ## Can we create a Component with these, so it can participate as an output?
        # 2x3 grid of image + radio pairs; `grid` is the enclosing Box so its
        # visibility can be toggled against the seeded view.
        with (grid := gr.Box()):
            with gr.Row():
                with gr.Box().style(border=None):
                    image1 = image_block()
                    select1 = radio_block()
                with gr.Box().style(border=None):
                    image2 = image_block()
                    select2 = radio_block()
                with gr.Box().style(border=None):
                    image3 = image_block()
                    select3 = radio_block()
            with gr.Row():
                with gr.Box().style(border=None):
                    image4 = image_block()
                    select4 = radio_block()
                with gr.Box().style(border=None):
                    image5 = image_block()
                    select5 = radio_block()
                with gr.Box().style(border=None):
                    image6 = image_block()
                    select6 = radio_block()

        images = [image1, image2, image3, image4, image5, image6]
        selectors = [select1, select2, select3, select4, select5, select6]

        # Wire each radio so selecting it records its grid index and clears
        # the other five radios (update_state outputs target `others`).
        for i, radio in enumerate(selectors):
            others = list(filter(lambda s: s != radio, selectors))
            radio.change(
                partial(update_state, i),
                inputs=[radio, state],
                outputs=others + [state]
            )

    # Seeded view: a single large image plus a button to return to the grid.
    # Hidden initially; `infer` toggles visibility between grid and this box.
    with (seeded_box := gr.Box()):
        seeded_image = image_block()
        clear_seed_button = gr.Button("Return to Grid")
    seeded_box.visible = False
    clear_seed_button.click(
        clear_seed,
        inputs=[state],
        outputs=selectors + [grid, seeded_box] + [state]
    )

    # Output order must match what `infer` returns:
    # 6 grid images, the seeded image, the two boxes, then state.
    all_images = images + [seeded_image]
    boxes = [grid, seeded_box]
    infer_outputs = all_images + boxes + [state]

    # Both Enter-in-textbox and the Run button trigger the same inference.
    text.submit(
        infer,
        inputs=[text, state],
        outputs=infer_outputs
    )
    btn.click(
        infer,
        inputs=[text, state],
        outputs=infer_outputs
    )

demo.launch(enable_queue=True)