Update app.py

app.py CHANGED
@@ -1,18 +1,15 @@
 import gradio as gr
 import numpy as np
 import torch
-from diffusers.utils import load_image
-from diffusers import (
-    StableDiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    ControlNetModel
-)
+from diffusers.utils import load_image
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 from peft import PeftModel, LoraConfig
 from controlnet_aux import HEDdetector
 from PIL import Image
 import cv2 as cv
 import os
-
+from functools import lru_cache
+from contextlib import contextmanager
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
@@ -23,198 +20,141 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_id_default = "CompVis/stable-diffusion-v1-4"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-def get_lora_sd_pipeline(
-    ckpt_dir='./lora_logos',
-    base_model_name_or_path=None,
-    dtype=torch.float16,
-    adapter_name="default",
-    controlnet=None
-):
-
-    unet_sub_dir = os.path.join(ckpt_dir, "unet")
-    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
-    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
-        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
-        base_model_name_or_path = config.base_model_name_or_path
-
-    if base_model_name_or_path is None:
-        raise ValueError("Please specify the base model name or path")
-
-    pipe = StableDiffusionControlNetPipeline.from_pretrained(
-        base_model_name_or_path,
-        torch_dtype=dtype,
-        controlnet=controlnet,
-    )
-
-    before_params = pipe.unet.parameters()
-    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-    pipe.unet.set_adapter(adapter_name)
-    after_params = pipe.unet.parameters()
-    print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
-
-    if os.path.exists(text_encoder_sub_dir):
-        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
-
-
-
-
-
-
-
-def
-
-
-
-
-
-
-    return
-
-def align_embeddings(prompt_embeds, negative_prompt_embeds):
-    max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
-    return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
-           torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
-
-def map_edge_detection(image_path: str) -> Image:
-    source_img = load_image(image_path).convert('RGB')
-    edges = cv.Canny(np.array(source_img), 80, 160)
-    edges = np.repeat(edges[:, :, None], 3, axis=2)
-    final_image = Image.fromarray(edges)
-    return final_image
-
-def map_scribble(image_path: str) -> Image:
-    global hed
-    if not hed:
-        hed = HEDdetector.from_pretrained('lllyasviel/Annotators')
-
-    image = load_image(image_path).convert('RGB')
-    scribble_image = hed(image)
-    image_np = np.array(scribble_image)
-    image_np = cv.medianBlur(image_np, 3)
-    image = cv.convertScaleAbs(image_np, alpha=1.5, beta=0)
-    final_image = Image.fromarray(image)
-    return final_image
-
-
-
-pipe = get_lora_sd_pipeline(
-    ckpt_dir='./lora_logos',
-    base_model_name_or_path=model_id_default,
-    dtype=torch_dtype,
-    controlnet=controlnet
-).to(device)
-
-
-
-def infer(
-    prompt,
-    negative_prompt,
-    width=512,
-    height=512,
-    num_inference_steps=20,
-    model_id='CompVis/stable-diffusion-v1-4',
-    seed=42,
-    guidance_scale=7.0,
-    lora_scale=0.5,
-    cn_enable=False,
-    cn_strength=0.0,
-    cn_mode='edge_detection',
-    cn_image=None,
-    ip_enable=False,
-    ip_scale=0.5,
-    ip_image=None,
-    progress=gr.Progress(track_tqdm=True)
-):
-
-
-
-
-
-
-
-
-
-
-
-
-
-            torch_dtype=torch_dtype
-        )
-        controlnet_changed = True
-    else:
-        cn_strength = 0.0  # force-disable ControlNet
-
-    if model_id != pipe._name_or_path:
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            controlnet=controlnet,
-            controlnet_conditioning_scale=cn_strength,
-        ).to(device)
-    elif (model_id == pipe._name_or_path) and controlnet_changed:
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            controlnet=controlnet,
-            controlnet_conditioning_scale=cn_strength,
-        ).to(device)
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
-        pipe.fuse_lora(lora_scale=lora_scale)
-    elif (model_id == pipe._name_or_path) and not controlnet_changed:
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
-        pipe.fuse_lora(lora_scale=lora_scale)
-
-
-
-    prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
-
-
-
-
-
-        'num_inference_steps': num_inference_steps,
-        'width': width,
-        'height': height,
-        'generator': generator,
-    }
-
-    if cn_enable:
-        params['controlnet_conditioning_scale'] = cn_strength
-        if cn_mode == 'edge_detection':
-            control_image = map_edge_detection(cn_image)
-        elif cn_mode == 'scribble':
-            control_image = map_scribble(cn_image)
-        params['image'] = control_image
-
-    if ip_enable:
-        pipe.load_ip_adapter(
-            IP_ADAPTER,
-            subfolder="models",
-            weight_name=IP_ADAPTER_WEIGHT_NAME,
-        )
-
-
-
-
+class PipelineManager:
+    def __init__(self):
+        self.pipe = None
+        self.current_model = None
+        self.controlnet_cache = {}
+        self.hed = None
+
+    @lru_cache(maxsize=2)
+    def get_controlnet(self, model_name: str) -> ControlNetModel:
+        if model_name not in self.controlnet_cache:
+            self.controlnet_cache[model_name] = ControlNetModel.from_pretrained(
+                model_name,
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            ).to(device)
+        return self.controlnet_cache[model_name]
+
+    def get_hed_detector(self):
+        if self.hed is None:
+            self.hed = HEDdetector.from_pretrained('lllyasviel/Annotators')
+        return self.hed
+
+    def initialize_pipeline(self, model_id, controlnet_model):
+        controlnet = self.get_controlnet(controlnet_model)
+        if not self.pipe or model_id != self.current_model:
+            self.pipe = self.create_pipeline(model_id, controlnet)
+            self.current_model = model_id
+        return self.pipe
+
+    def create_pipeline(self, model_id, controlnet):
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            controlnet=controlnet,
+            cache_dir="./models_cache"
+        ).to(device)
+
+        if os.path.exists('./lora_logos'):
+            pipe = self.load_lora_adapters(pipe)
+
+        return pipe
+
+    def load_lora_adapters(self, pipe):
+        unet_dir = os.path.join('./lora_logos', "unet")
+        text_encoder_dir = os.path.join('./lora_logos', "text_encoder")
+
+        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_dir, adapter_name="default")
+        if os.path.exists(text_encoder_dir):
+            pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_dir)
+
+        return pipe.to(device)
+
+@contextmanager
+def torch_inference_mode():
+    with torch.inference_mode(), torch.autocast(device.type):
+        yield
+
+def process_embeddings(prompt, negative_prompt, tokenizer, text_encoder):
+    def process_text(text):
+        tokens = tokenizer(text, return_tensors="pt", truncation=False).input_ids
+        chunks = [tokens[:, i:i+77].to(device) for i in range(0, tokens.size(1), 77)]
+        return torch.cat([text_encoder(chunk)[0] for chunk in chunks], dim=1)
+
+    prompt_emb = process_text(prompt)
+    negative_emb = process_text(negative_prompt)
+    max_len = max(prompt_emb.size(1), negative_emb.size(1))
+
+    return (
+        torch.nn.functional.pad(prompt_emb, (0, 0, 0, max_len - prompt_emb.size(1))),
+        torch.nn.functional.pad(negative_emb, (0, 0, 0, max_len - negative_emb.size(1)))
+    )
+
+def process_control_image(image_path: str, processor: str, hed_detector) -> Image:
+    image = load_image(image_path).convert('RGB')
+
+    if processor == 'edge_detection':
+        edges = cv.Canny(np.array(image), 80, 160)
+        return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))
+
+    if processor == 'scribble':
+        scribble = hed_detector(image)
+        processed = cv.medianBlur(np.array(scribble), 3)
+        return Image.fromarray(cv.convertScaleAbs(processed, alpha=1.5))
+
+pipeline_mgr = PipelineManager()
+controlnet_models = {
+    "edge_detection": "lllyasviel/sd-controlnet-canny",
+    "scribble": "lllyasviel/sd-controlnet-scribble"
+}
+
+def infer(**kwargs):
+    generator = torch.Generator(device).manual_seed(kwargs['seed'])
+
+    with torch_inference_mode():
+        pipe = pipeline_mgr.initialize_pipeline(
+            kwargs['model_id'],
+            controlnet_models.get(kwargs['cn_mode'], controlnet_models['edge_detection'])
+        )
+
+        if kwargs['cn_enable'] and not kwargs['cn_image']:
+            raise gr.Error("ControlNet enabled but no image provided!")
+
+        prompt_emb, negative_emb = process_embeddings(
+            kwargs['prompt'],
+            kwargs['negative_prompt'],
+            pipe.tokenizer,
+            pipe.text_encoder
+        )
+
+        params = {
+            'prompt_embeds': prompt_emb,
+            'negative_prompt_embeds': negative_emb,
+            'guidance_scale': kwargs['guidance_scale'],
+            'num_inference_steps': kwargs['num_inference_steps'],
+            'width': kwargs['width'],
+            'height': kwargs['height'],
+            'generator': generator
+        }
+
+        if kwargs['cn_enable']:
+            params['image'] = process_control_image(
+                kwargs['cn_image'],
+                kwargs['cn_mode'],
+                pipeline_mgr.get_hed_detector()
+            )
+            params['controlnet_conditioning_scale'] = kwargs['cn_strength']
+
+        if kwargs.get('ip_enable', False):
+            pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
+            params['ip_adapter_image'] = load_image(kwargs['ip_image']).convert('RGB')
+            pipe.set_ip_adapter_scale(kwargs.get('ip_scale', 0.5))
+
+        pipe.fuse_lora(lora_scale=kwargs.get('lora_scale', 0.5))
+
+        return pipe(**params).images[0]
 
 css = """
 #col-container {
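A few review notes on the new version.

First, `def infer(**kwargs)` accepts keyword arguments only, but Gradio passes the values wired through `inputs=[...]` to a click handler positionally, so this signature fails at call time. A minimal shim (parameter names and order assumed from the old `infer` signature, so adjust to match the actual UI wiring) would be:

    # Hypothetical adapter: Gradio supplies inputs positionally; map them
    # onto the keyword-only infer(). Names/order assumed from the old signature.
    def infer_ui(prompt, negative_prompt, width, height, num_inference_steps,
                 model_id, seed, guidance_scale, lora_scale,
                 cn_enable, cn_strength, cn_mode, cn_image,
                 ip_enable, ip_scale, ip_image,
                 progress=gr.Progress(track_tqdm=True)):
        kwargs = {k: v for k, v in locals().items() if k != 'progress'}
        return infer(**kwargs)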
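Second, `@lru_cache(maxsize=2)` on `get_controlnet` is redundant with the manual `controlnet_cache` dict, and `lru_cache` on a bound method also keeps a strong reference to `self` and makes the instance part of the cache key. Keeping only the dict is simpler (sketch):

    # Sketch: the dict already memoizes per model name; drop the decorator.
    def get_controlnet(self, model_name: str) -> ControlNetModel:
        if model_name not in self.controlnet_cache:
            self.controlnet_cache[model_name] = ControlNetModel.from_pretrained(
                model_name,
                cache_dir="./models_cache",
                torch_dtype=torch_dtype,
            ).to(device)
        return self.controlnet_cache[model_name]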
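Third, `process_control_image` falls through and returns `None` when `processor` matches neither branch, which only surfaces later as an opaque pipeline error. An explicit guard at the end of the function keeps the failure local (sketch):

    # Sketch: fail fast instead of implicitly returning None.
    raise ValueError(f"Unknown control image processor: {processor!r}")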
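Fourth, `pipe.fuse_lora(...)` runs on every `infer` call, and because the pipeline object is cached across calls, repeated fusions compound the LoRA weights. Fusing for a single generation and restoring afterwards avoids that; diffusers pipelines expose `unfuse_lora` for this (sketch):

    # Sketch: fuse for this generation only, then restore the base weights.
    pipe.fuse_lora(lora_scale=kwargs.get('lora_scale', 0.5))
    try:
        return pipe(**params).images[0]
    finally:
        pipe.unfuse_lora()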
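Finally, for context: `process_embeddings` works around CLIP's 77-token limit by tokenizing without truncation, encoding the ids in 77-token chunks, concatenating the chunk embeddings, and padding the prompt and negative embeddings to a common length. A usage sketch (prompt text illustrative, pipeline assumed loaded):

    # Sketch: precompute embeddings for a long prompt, then call the pipeline.
    emb, neg = process_embeddings(
        "a flat minimal fox logo, vector style, centered",
        "blurry, low quality",
        pipe.tokenizer,
        pipe.text_encoder,
    )
    image = pipe(prompt_embeds=emb, negative_prompt_embeds=neg,
                 num_inference_steps=20).images[0]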