Spaces:
Running
Running
File size: 5,008 Bytes
6b3eee7 ca683cb 6b3eee7 ca683cb 6b3eee7 8aa40ca 6b3eee7 a1e4d11 6b3eee7 8aa40ca 6b3eee7 8aa40ca 6b3eee7 8aa40ca 9133821 6b3eee7 ba16590 6b3eee7 1466874 6b3eee7 f1db938 6b3eee7 20e055c 680482f 6b3eee7 5c48d67 6b3eee7 2c1022d 6b3eee7 2c1022d 71df048 6b3eee7 9133821 6b3eee7 8aa40ca 6b3eee7 9133821 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import gradio as gr
import os, requests
import numpy as np
from inference import setup_model, colorize_grayscale, predict_anchors
## local | remote
# RUN_MODE selects the deployment flavour. Any value other than "local"
# downloads the checkpoint and the example images at start-up.
RUN_MODE = "remote"
if RUN_MODE != "local":
    # NOTE(review): the host in these URLs is spelled with a trailing dot
    # ("huggingface.co.") -- confirm that this is intended.
    os.system("wget https://huggingface.co./menghanxia/disco/resolve/main/disco-beta.pth.rar -q")
    # os.rename raises FileNotFoundError when the destination directory is
    # missing, so make sure it exists before moving the checkpoint into it.
    os.makedirs("./checkpoints", exist_ok=True)
    os.rename("disco-beta.pth.rar", "./checkpoints/disco-beta.pth.rar")
    ## examples
    # Fetched into the working directory; the file names are the ones the
    # gr.Examples table refers to.
    for _example_name in ("01.jpg", "02.jpg", "03.jpg", "04.jpg"):
        os.system(f"wget https://huggingface.co./menghanxia/disco/resolve/main/{_example_name} -q")

## step 1: set up model
# The model objects are created once at import time and shared by every
# request handler below.
device = "cpu"
checkpt_path = "checkpoints/disco-beta.pth.rar"
colorizer, colorLabeler = setup_model(checkpt_path, device=device)
def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
    """Colorize the input image twice and return both results.

    Args:
        rgb_img: Image taken from the "Input" component.
        hint_img: Image taken from the "Anchor" component; when it is None
            the input image itself is used as the hint.
        n_anchors: Value of the "Num. of anchors" field.
        is_high_res: Value of the resolution radio. It is accepted so the
            signature matches the Gradio ``inputs`` list, but it is not used:
            both variants are always rendered.
        is_editable: Value of the "Show editable anchors" checkbox.

    Returns:
        A 2-tuple with the result of calling ``colorize_grayscale`` with the
        sixth positional argument set to True, followed by the result with it
        set to False (presumably the high-resolution flag -- confirm against
        ``inference.colorize_grayscale``).
    """
    hint = rgb_img if hint_img is None else hint_img
    # Same call order as before: the True variant first, then the False one.
    return tuple(
        colorize_grayscale(
            colorizer, colorLabeler, rgb_img, hint, n_anchors, flag, is_editable, device
        )
        for flag in (True, False)
    )
def click_predanchors(rgb_img, n_anchors, is_high_res, is_editable):
    """Run anchor prediction for the "Predict anchors" button.

    Forwards the UI values, together with the module-level ``colorizer``,
    ``colorLabeler`` and ``device``, to ``predict_anchors`` and returns its
    result unchanged.
    """
    return predict_anchors(
        colorizer,
        colorLabeler,
        rgb_img,
        n_anchors,
        is_high_res,
        is_editable,
        device,
    )
## step 2: configure interface
def switch_states(is_checked):
    """Show or hide the anchor image and the "Predict anchors" button.

    Args:
        is_checked: Current value of the "Show editable anchors" checkbox.

    Returns:
        Two update objects, one per output component, both setting
        ``visible`` to the truthiness of ``is_checked``.
    """
    # The per-component ``gr.Image.update`` / ``gr.Button.update`` helpers no
    # longer exist in Gradio 4+, which this file targets (it passes
    # ``delete_cache`` to gr.Blocks). ``gr.update`` is component-agnostic.
    visible = bool(is_checked)
    return gr.update(visible=visible), gr.update(visible=visible)
# Build the Blocks UI: input/anchor images and controls in the left column,
# the two result images in the right column, then guideline text, examples
# and footer links.
demo = gr.Blocks(title="DISCO", delete_cache=(1800, 3600),)
with demo:
    gr.Markdown(value="""
**Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors**. Check the [project page](https://menghanxia.github.io/projects/disco.html).
""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True)
                Image_anchor = gr.Image(type="numpy", label="Anchor", interactive=True, visible=True)
            with gr.Row():
                Num_anchor = gr.Number(precision=0, value=8, label="Num. of anchors (3~14)")
                # type="index" makes the callbacks receive 0 (Low) or 1 (High).
                Radio_resolution = gr.Radio(
                    type="index",
                    choices=["Low (256x256)", "High (512x512)"],
                    label="Colorization resolution (Low is more stable)",
                    value="Low (256x256)",
                )
            with gr.Row():
                Ckeckbox_editable = gr.Checkbox(value=False, label='Show editable anchors')
                Button_show_anchor = gr.Button(value="Predict anchors", visible=True)
            Button_run = gr.Button(value="Colorize")
        with gr.Column():
            # Two result slots, filled by the 2-tuple returned from click_colorize.
            Image_output = [
                gr.Image(type="numpy", label="Output", format="png"),
                gr.Image(type="numpy", label="Output", format="png"),
            ]

    # Event wiring.
    Ckeckbox_editable.change(fn=switch_states, inputs=Ckeckbox_editable, outputs=[Image_anchor, Button_show_anchor])
    Button_show_anchor.click(
        fn=click_predanchors,
        inputs=[Image_input, Num_anchor, Radio_resolution, Ckeckbox_editable],
        outputs=Image_anchor,
    )
    Button_run.click(
        fn=click_colorize,
        inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable],
        outputs=Image_output,
    )

    ## guideline
    gr.Markdown(value="""
**Guideline**
1. Upload your image or select one from the examples.
2. Set up the arguments: "Num. of anchors" and "Colorization resolution".
3. Run the colorization (two modes supported):
- Automatic mode: **Click** "Colorize" to get the automatically colorized output.
- Editable mode: **Check** "Show editable anchors"; **Click** "Predict anchors"; **Redraw** the anchor colors (only anchor region will be used); **Click** "Colorize" to get the result.
""")
    if RUN_MODE != "local":
        # Image_output is already a list of components, so it is passed
        # directly instead of being wrapped in another list.
        gr.Examples(
            examples=[
                ['01.jpg', 8, "Low (256x256)"],
                ['02.jpg', 8, "Low (256x256)"],
                ['03.jpg', 8, "Low (256x256)"],
                ['04.jpg', 8, "Low (256x256)"],
            ],
            inputs=[Image_input, Num_anchor, Radio_resolution],
            outputs=Image_output,
            label="Examples",
            cache_mode="lazy",
        )
    gr.HTML(value="""
<p style="text-align:center; color:orange"><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
""")
# Start the server. Every mode except "local" enables the request queue and
# launches with show_error=True; "local" binds to a fixed host and port.
if RUN_MODE != "local":
    demo.queue()
    demo.launch(show_error=True)
else:
    demo.launch(server_name='9.134.253.83', server_port=7788)
|