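"""Gradio demo for the Semanticist tokenizer (DiffuseSlot).

Downloads the pretrained semanticist_tok_XL checkpoint from the Hugging Face
Hub and reconstructs an uploaded image from a variable number of slots/tokens.
"""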
import gradio as gr
import numpy as np
import torch
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

from paintmind.stage1.diffuse_slot import DiffuseSlot
from paintmind.utils.datasets import vae_transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the pretrained Semanticist tokenizer checkpoint and load its config.
ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename='semanticist_tok_XL.pkl')
config_path = 'configs/tokenizer_config.yaml'
cfg = OmegaConf.load(config_path)
ckpt = torch.load(ckpt_path, map_location='cpu')

# Evaluation-time preprocessing pipeline (no training augmentations).
transform = vae_transforms('test')
print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
if device == 'cuda':
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Tesla T4


def norm_ip(img, low, high):
    """Clamp img to [low, high] in place, then rescale it to [0, 1]."""
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, value_range):
    """Normalize t to [0, 1], using value_range if given, else t's own min/max."""
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))
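# (These two helpers mirror the in-place normalization code used inside
# torchvision.utils.make_grid.)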

def convert_np(img):
    """Map a CHW float tensor in [0, 1] to an HWC uint8 numpy array."""
    return img.mul(255).add_(0.5).clamp_(0, 255)\
              .permute(1, 2, 0).to("cpu", torch.uint8).numpy()

def convert_PIL(img):
    """Map a CHW float tensor in [0, 1] to a PIL image."""
    return Image.fromarray(convert_np(img))
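# Both converters assume the input lives in [0, 1] (hence the *255 rescale);
# the model's sampler is expected to produce outputs in that range.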

# Strip the '._orig_mod' infix that torch.compile inserts into parameter names,
# so the checkpoint keys match the uncompiled model.
ckpt = {k.replace('._orig_mod', ''): v for k, v in ckpt.items()}

model = DiffuseSlot(**cfg['trainer']['params']['model']['params'])
msg = model.load_state_dict(ckpt, strict=False)
print(f"load_state_dict: {msg}")
model = model.to(device).eval()
# Allow inference with only the first n slots (see inference_with_n_slots below).
model.enable_nest = True

@torch.no_grad()
def viz_diff_slots(model, img, nums, cfg=1.0):
    """Reconstruct img once per entry in nums, using only that many slots."""
    n_slots_inf = []
    for num_slots_to_inference in nums:
        recon_n = model(
            img, None, sample=True, cfg=cfg,
            inference_with_n_slots=num_slots_to_inference,
        )
        n_slots_inf.append(recon_n)
    return [convert_np(recon[0]) for recon in n_slots_inf]
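# Minimal usage sketch (paths and slot counts here are illustrative):
#   pil = Image.open('examples/city.jpg').convert('RGB')
#   img = transform(pil).unsqueeze(0).to(device)
#   recons = viz_diff_slots(model, img, [1, 16, 256], cfg=4.0)  # three uint8 arrays
#   Image.fromarray(recons[-1]).save('recon_256_slots.png')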


with gr.Blocks() as demo:
    with gr.Row():
        # First column - Input and configs
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            input_image = gr.Image(label="Upload an image", type="numpy")
            
            with gr.Group():
                gr.Markdown("### Configuration")
                show_gallery = gr.Checkbox(label="Show Gallery", value=False)
                # You can add more config options here
                # slider = gr.Slider(minimum=0, maximum=10, value=5, label="Processing Intensity")
                slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
                labels_input = gr.Textbox(
                    label="Gallery Labels (comma-separated)", 
                    value="1, 4, 16, 64, 256", 
                    placeholder="Enter comma-separated numbers for the number of slots to use"
                )
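                # The CFG slider sets the classifier-free guidance scale passed to
                # the diffusion sampler; the label list sets how many slots/tokens
                # are used for each reconstruction shown in the gallery.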
        
        # Second column - Output (conditionally rendered)
        with gr.Column(scale=1):
            gr.Markdown("## Output")
            
            # Container for conditional rendering
            with gr.Group(visible=False) as gallery_container:
                gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
            
            # Always visible output image
            output_image = gr.Image(label="Processed Image", type="numpy")
    
    # Handle form submission
    submit_btn = gr.Button("Process")
    
    # Define the processing logic
    def update_outputs(image, show_gallery_value, slider_value, labels_text):
        try:
            # Parse the comma-separated slot counts from the text input.
            labels = [int(label.strip()) for label in labels_text.split(",")]
        except (ValueError, AttributeError):
            # Fall back to the defaults if the input is empty or malformed.
            labels = [1, 4, 16, 64, 256]
        while len(labels) < 3:
            labels.append(256)

        if image is None:
            # Return a placeholder if no image has been uploaded.
            placeholder = np.zeros((300, 300, 3), dtype=np.uint8)
            return gr.update(visible=show_gallery_value), [], placeholder

        image = Image.fromarray(image)
        img = transform(image).unsqueeze(0).to(device)

        if not show_gallery_value:
            # Gallery hidden: show only the full 256-slot reconstruction.
            recon = viz_diff_slots(model, img, [256], cfg=slider_value)[0]
            return gr.update(visible=False), [], recon

        # Gallery shown: reconstruct once per requested slot count, with the
        # ground-truth image first.
        model_decompose = viz_diff_slots(model, img, labels, cfg=slider_value)
        gallery_images = [(image, 'GT')] + [
            (recon, f'Recon. with {label} tokens')
            for recon, label in zip(model_decompose, labels)
        ]
        return gr.update(visible=True), gallery_images, image
    
    # Connect the inputs and outputs
    submit_btn.click(
        fn=update_outputs,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image]
    )
    
    # Also update when checkbox changes
    show_gallery.change(
        fn=lambda value: gr.update(visible=value),
        inputs=[show_gallery],
        outputs=[gallery_container]
    )

    # Add examples
    examples = [
        ["examples/city.jpg", False, 4.0, "1,4,16,64,256"],
        ["examples/food.jpg", True, 4.0, "1,4,16,64,256"],
        ["examples/highland.webp", True, 4.0, "1,4,16,64,256"],
    ]
    
    gr.Examples(
        examples=examples,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image],
        fn=update_outputs,
        cache_examples=True
    )
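    # Note: cache_examples=True runs update_outputs on every example at startup
    # to precompute outputs, so the model must be loaded before the demo builds.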

# Launch the demo
if __name__ == "__main__":
    demo.launch()