import gradio as gr
from PIL import Image
from torchvision import transforms
from load_model import sample
import torch
import random
import os

device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer CUDA
device = "mps" if torch.backends.mps.is_available() else device  # then Apple MPS, else CPU
image_size = 128

def show_example_fn():
    sketch = Image.open("examples/sketch.png")
    scribble_folder = "./examples/scribbles/"
    png_files = [f for f in os.listdir(scribble_folder) if f.lower().endswith(".png")]
    # get random scribble
    random_scribble = Image.open(
        os.path.join(scribble_folder, random.choice(png_files))
    )
    return [sketch, random_scribble]

transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),  # map [0, 1] to [-1, 1]
    ]
)

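# For reference, a minimal sketch of the inverse of the normalization above
# (hypothetical helper, not part of the original app; `sample` from load_model
# is assumed to de-normalize its own output):
# def tensor_to_pil(t):
#     return transforms.ToPILImage()(((t + 1) / 2).clamp(0, 1))
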
def process_images(
    sketch,
    scribbles,
    sampling_steps,
    seed_nr,
    upscale,
    progress=gr.Progress(),
):
    w, h = sketch.size
    # resize both conditioning images to image_size and normalize to [-1, 1]
    sketch = transform(sketch.convert("RGB"))
    scribbles = transform(scribbles.convert("RGB"))
    if upscale:
        # stretch the generated image back to the original input size
        return transforms.Resize((h, w))(
            sample(sketch, scribbles, sampling_steps, seed_nr, progress)
        )
    return sample(sketch, scribbles, sampling_steps, seed_nr, progress)

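# Hedged usage sketch for calling process_images outside the UI (hypothetical
# scribble filename; assumes load_model.sample accepts a gr.Progress instance
# created outside a Gradio event):
# result = process_images(
#     Image.open("examples/sketch.png"),
#     Image.open("examples/scribbles/scribble_0.png"),
#     sampling_steps=10,
#     seed_nr=5,
#     upscale=False,
# )
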
theme = gr.themes.Monochrome()
with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
        )
    with gr.Row():
        with gr.Column():
            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
        with gr.Column():
            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
    with gr.Row():
        with gr.Column():
            seed_slider = gr.Number(
                label="Random Seed 🎲 (if the generated image is not to your liking, simply try another seed)",
                value=5,
            )
            upscale_button = gr.Checkbox(
                label=f"Stretch (check this box to stretch the downloadable output to the input size; the model's native output is {image_size}x{image_size})",
                value=False,
            )
        with gr.Column():
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps (the more steps, the higher the image quality)",
                value=50,
            )
            show_example = gr.Button(value="Show Example Input")
    with gr.Row():
        generate_button = gr.Button(value="Paint 🎨")
    with gr.Row():
        generate_info = gr.Markdown(
            "<p style='text-align: center; font-size: 16px;'>"
            "Notes: depending on where you run this demo, it may take a while to generate the output. On the HF Space, 100 sampling steps can take up to 20 minutes, so we recommend lowering the sampling steps to 10 there. The model was trained on this <a href='https://huggingface.co./datasets/pawlo2013/anime_diffusion_full'>dataset</a>."
            "</p>"
        )
    show_example.click(
        show_example_fn,
        inputs=[],
        outputs=[sketch_input, scribbles_input],
        concurrency_limit=1,
        trigger_mode="once",
    )
    generate_button.click(
        process_images,
        inputs=[
            sketch_input,
            scribbles_input,
            sampling_slider,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
        concurrency_limit=1,
        trigger_mode="once",
    )

if __name__ == "__main__":
    demo.launch()