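# Gradio demo for the Semanticist (DiffuseSlot) tokenizer: encode an input image
# into slot tokens and visualize reconstructions from increasing numbers of slots.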
import gradio as gr
import numpy as np
from PIL import Image
import os
import torch
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from paintmind.engine.util import instantiate_from_config
from paintmind.stage1.diffuse_slot import DiffuseSlot
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename='semanticist_tok_XL.pkl')
config_path = 'configs/tokenizer_config.yaml'
cfg = OmegaConf.load(config_path)
ckpt = torch.load(ckpt_path, map_location='cpu')
from paintmind.utils.datasets import vae_transforms
transform = vae_transforms('test')
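# Note: vae_transforms('test') is assumed to handle resizing and normalization,
# producing the CHW float tensor layout the tokenizer expects.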
print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
if device == 'cuda':
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    # Tesla T4
def norm_ip(img, low, high):
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, value_range):
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))
def convert_np(img):
    ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr

def convert_PIL(img):
    ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(ndarr)
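# Checkpoints saved from a torch.compile()-wrapped model carry an '_orig_mod'
# component in their state-dict keys; strip it so the keys match the plain module.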
ckpt = {k.replace('._orig_mod', ''): v for k, v in ckpt.items()}
model = DiffuseSlot(**cfg['trainer']['params']['model']['params'])
msg = model.load_state_dict(ckpt, strict=False)
model = model.to(device)
model = model.eval()
model.enable_nest = True
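# Assumption: `enable_nest` switches the model into nested-slot mode so it can
# reconstruct from a truncated prefix of slots (see inference_with_n_slots below).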
@torch.no_grad()  # pure inference; skip autograd bookkeeping
def viz_diff_slots(model, img, nums, cfg=1.0, return_img=False):
    # Reconstruct `img` once per slot count in `nums` and return the
    # reconstructions as uint8 numpy arrays.
    n_slots_inf = []
    for num_slots_to_inference in nums:
        recon_n = model(
            img, None, sample=True, cfg=cfg,
            inference_with_n_slots=num_slots_to_inference,
        )
        n_slots_inf.append(recon_n)
    return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))]
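# Usage sketch: reconstruct one image from 1, 16, and 256 slots.
#   pil = Image.open('examples/city.jpg').convert('RGB')
#   recons = viz_diff_slots(model, transform(pil).unsqueeze(0).to(device), [1, 16, 256], cfg=4.0)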
# Removed process_image function as its functionality is now in the update_outputs function
with gr.Blocks() as demo:
    with gr.Row():
        # First column - Input and configs
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            input_image = gr.Image(label="Upload an image", type="numpy")

            with gr.Group():
                gr.Markdown("### Configuration")
                show_gallery = gr.Checkbox(label="Show Gallery", value=False)
                # You can add more config options here
                # slider = gr.Slider(minimum=0, maximum=10, value=5, label="Processing Intensity")
                slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
                labels_input = gr.Textbox(
                    label="Gallery Labels (comma-separated)",
                    value="1, 4, 16, 64, 256",
                    placeholder="Enter comma-separated numbers for the number of slots to use"
                )

        # Second column - Output (conditionally rendered)
        with gr.Column(scale=1):
            gr.Markdown("## Output")
            # Container for conditional rendering
            with gr.Group(visible=False) as gallery_container:
                gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
            # Always visible output image
            output_image = gr.Image(label="Processed Image", type="numpy")

    # Handle form submission
    submit_btn = gr.Button("Process")
    # Define the processing logic
    def update_outputs(image, show_gallery_value, slider_value, labels_text):
        # Toggle the gallery container via an update; assigning to
        # gallery_container.visible after build time has no effect.
        visibility = gr.update(visible=show_gallery_value)
        try:
            # Parse the slot counts from the text input
            if labels_text and "," in labels_text:
                labels = [int(label.strip()) for label in labels_text.split(",")]
            else:
                # Default labels if none provided or in the wrong format
                labels = [1, 4, 16, 64, 256]
        except ValueError:
            labels = [1, 4, 16, 64, 256]
        while len(labels) < 3:
            labels.append(256)

        # Process the image based on configurations
        if image is None:
            # Return a placeholder if no image is uploaded
            placeholder = np.zeros((300, 300, 3), dtype=np.uint8)
            return visibility, [], placeholder

        image = Image.fromarray(image)
        img = transform(image)
        img = img.unsqueeze(0).to(device)
        recon = viz_diff_slots(model, img, [256], cfg=slider_value)[0]

        if not show_gallery_value:
            # If only the image should be shown, return just the processed image
            return visibility, [], recon
        else:
            model_decompose = viz_diff_slots(model, img, labels, cfg=slider_value)
            # Pair the ground truth and each reconstruction with its caption
            gallery_images = [
                (image, 'GT'),
            ] + [(img, 'Recon. with ' + str(label) + ' tokens')
                 for img, label in zip(model_decompose, labels)]
            return visibility, gallery_images, image
    # Connect the inputs and outputs
    submit_btn.click(
        fn=update_outputs,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image]
    )

    # Also update when the checkbox changes
    show_gallery.change(
        fn=lambda value: gr.update(visible=value),
        inputs=[show_gallery],
        outputs=[gallery_container]
    )

    # Add examples
    examples = [
        ["examples/city.jpg", False, 4.0, "1,4,16,64,256"],
        ["examples/food.jpg", True, 4.0, "1,4,16,64,256"],
        ["examples/highland.webp", True, 4.0, "1,4,16,64,256"],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image],
        fn=update_outputs,
        cache_examples=True
    )
# Launch the demo
if __name__ == "__main__":
    demo.launch()