File size: 3,871 Bytes
0b2b0ab
 
 
 
 
 
2489323
5086590
0b2b0ab
 
 
 
 
 
 
2489323
 
 
 
 
 
 
 
 
 
 
 
 
 
0b2b0ab
 
 
 
 
 
 
 
 
d51b792
 
 
 
 
 
 
 
0b2b0ab
 
d51b792
 
0b2b0ab
 
d51b792
 
0b2b0ab
d51b792
0b2b0ab
d51b792
0b2b0ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d51b792
0b2b0ab
 
d51b792
0b2b0ab
 
 
d51b792
510e898
0b2b0ab
 
2489323
18c30e2
2489323
 
 
0b2b0ab
 
 
 
 
d51b792
0b2b0ab
 
2489323
0b2b0ab
 
2489323
f3fb43b
 
 
9341531
f3fb43b
 
0b2b0ab
2489323
 
 
 
 
 
 
 
0b2b0ab
 
 
 
 
 
 
 
 
 
d51b792
 
0b2b0ab
 
f5d4495
2489323
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
from PIL import Image
from torchvision import transforms
from load_model import sample
import torch
import random
import os


def _select_device():
    """Return the torch device name to use for this process.

    Apple-silicon MPS takes precedence over CUDA when both report as
    available; CPU is the fallback when neither is.
    """
    has_cuda = torch.cuda.is_available()
    has_mps = torch.backends.mps.is_available()
    if has_mps:
        return "mps"
    return "cuda" if has_cuda else "cpu"


device = _select_device()

# Side length (pixels) of the square images the network works with.
image_size = 128


def show_example_fn():
    """Load the bundled example inputs for the demo.

    Reads ``examples/sketch.png`` and one randomly chosen ``.png`` file
    (case-insensitive extension match) from ``./examples/scribbles/``. Both
    paths are resolved relative to the current working directory.

    Returns:
        A two-item list ``[sketch, scribble]`` of fully loaded PIL images.

    Raises:
        gr.Error: If the scribbles folder contains no ``.png`` files.
    """
    scribble_folder = "./examples/scribbles/"

    png_files = [f for f in os.listdir(scribble_folder) if f.lower().endswith(".png")]
    if not png_files:
        # random.choice([]) would raise a bare IndexError; show a readable
        # message in the UI instead.
        raise gr.Error(f"No example scribbles (.png) found in {scribble_folder}")

    def _load(path):
        # Image.open is lazy and keeps the file open until the pixel data is
        # read; decode eagerly so the handle is closed when the `with` exits.
        with Image.open(path) as img:
            img.load()
        return img

    sketch = _load("examples/sketch.png")

    # get random scribble
    random_scribble = _load(os.path.join(scribble_folder, random.choice(png_files)))

    return [sketch, random_scribble]


def _to_signed_unit_range(t):
    """Rescale a tensor from ``[0, 1]`` to ``[-1, 1]`` (``2 * t - 1``)."""
    return (t * 2) - 1


# Preprocessing shared by the sketch and scribbles inputs: resize to the
# square network input size, convert to a tensor, then shift the values
# into [-1, 1].
_preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Lambda(_to_signed_unit_range),
]
transform = transforms.Compose(_preprocessing_steps)


def process_images(
    sketch,
    scribbles,
    sampling_steps,
    seed_nr,
    upscale,
    progress=gr.Progress(),
):
    """Run conditional diffusion sampling on a sketch/scribbles pair.

    Args:
        sketch: PIL image from the "Sketch" input, or ``None`` if left empty.
        scribbles: PIL image from the "Scribbles" input, or ``None`` if left
            empty.
        sampling_steps: Number of DDPM sampling steps (from the slider).
        seed_nr: Random seed from the number field; ``None`` is passed
            through to ``sample`` unchanged.
        upscale: If true, resize the result to the sketch's original size.
        progress: Gradio progress tracker, forwarded to ``sample``.

    Returns:
        The output of ``sample``, resized to the sketch's original
        ``(height, width)`` when ``upscale`` is set.

    Raises:
        gr.Error: If either input image is missing.
    """
    # Gradio delivers None for an empty Image input; fail with a readable UI
    # message instead of an AttributeError on `.size` below.
    if sketch is None or scribbles is None:
        raise gr.Error("Please provide both a sketch and a scribbles image.")

    # Capture the original size before the images are resized to tensors.
    w, h = sketch.size

    sketch_tensor = transform(sketch.convert("RGB"))
    scribbles_tensor = transform(scribbles.convert("RGB"))

    # NOTE(review): numeric widgets may deliver floats. Step counts and seeds
    # are integral, so normalise them before handing them to `sample`.
    steps = int(sampling_steps)
    seed = None if seed_nr is None else int(seed_nr)

    result = sample(sketch_tensor, scribbles_tensor, steps, seed, progress)

    if upscale:
        # Stretch the network's fixed-size output back to the sketch's size.
        result = transforms.Resize((h, w))(result)
    return result


theme = gr.themes.Monochrome()


# UI layout, top to bottom: title, input/output images, controls, the
# "Paint" button, and a notes row. Event wiring is at the end of the block.
with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
        )

    with gr.Row():
        with gr.Column():
            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
        with gr.Column():
            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)

        with gr.Column():
            output = gr.Image(type="pil", label="Output")

    with gr.Row():
        with gr.Column():
            # precision=0 makes Gradio round the value and deliver it as an
            # int, which is what a random seed should be.
            seed_slider = gr.Number(
                label="Random Seed 🎲 (if the image generated is not to your liking, simply use another seed)",
                value=5,
                precision=0,
            )

            upscale_button = gr.Checkbox(
                label=f"Stretch (If you want to stretch the downloadable output to the input size, check this box, the default output of neural networks is {image_size}x{image_size} )",
                value=False,
            )

        with gr.Column():
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps 🔄 (the higher the number of steps the higher the quality of the images)",
                value=50,
            )
            show_example = gr.Button(value="Show Example Input ")

    with gr.Row():
        generate_button = gr.Button(value="Paint 🎨 ")
    with gr.Row():
        generate_info = gr.Markdown(
            "<p style='text-align: center; font-size: 16px;'>"
            "Notes: Depending on where you run this demo, it might take a while to generate the output. For the HF space it may take up to 20 minutes for 100 sampling steps. We recommend lowering the sampling steps to 10 for the HF space. Model trained using this <a href='https://huggingface.co/datasets/pawlo2013/anime_diffusion_full'>dataset</a>."
            "</p>"
        )

    # Fill both image inputs with the bundled example sketch and a random
    # example scribble.
    show_example.click(
        show_example_fn,
        inputs=[],
        outputs=[sketch_input, scribbles_input],
        concurrency_limit=1,
        trigger_mode="once",
    )

    # Argument order must match process_images(sketch, scribbles,
    # sampling_steps, seed_nr, upscale).
    generate_button.click(
        process_images,
        inputs=[
            sketch_input,
            scribbles_input,
            sampling_slider,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
        concurrency_limit=1,
        trigger_mode="once",
    )

if __name__ == "__main__":
    demo.launch()