dezzman committed on
Commit 7bcd902 · verified · 1 Parent(s): 3f250af

Update app.py

Files changed (1)
  1. app.py +128 -188
app.py CHANGED
@@ -1,18 +1,15 @@
 import gradio as gr
 import numpy as np
 import torch
-from diffusers.utils import load_image, make_image_grid
-from diffusers import (
-    StableDiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    ControlNetModel
-)
+from diffusers.utils import load_image
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 from peft import PeftModel, LoraConfig
 from controlnet_aux import HEDdetector
 from PIL import Image
 import cv2 as cv
 import os
-
+from functools import lru_cache
+from contextlib import contextmanager
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
@@ -23,198 +20,141 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_id_default = "CompVis/stable-diffusion-v1-4"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-hed = None
-dict_controlnet = {
-    "edge_detection": "lllyasviel/sd-controlnet-canny",
-    # "pose_estimation": "lllyasviel/sd-controlnet-openpose",
-    # "depth_map": "lllyasviel/sd-controlnet-depth",
-    "scribble": "lllyasviel/sd-controlnet-scribble",
-    # "MLSD": "lllyasviel/sd-controlnet-mlsd"
-}
-
-controlnet = ControlNetModel.from_pretrained(
-    dict_controlnet["edge_detection"],
-    cache_dir="./models_cache",
-    torch_dtype=torch_dtype,
-)
-
-
-def get_lora_sd_pipeline(
-    ckpt_dir='./lora_logos',
-    base_model_name_or_path=None,
-    dtype=torch.float16,
-    adapter_name="default",
-    controlnet=None
-):
-    unet_sub_dir = os.path.join(ckpt_dir, "unet")
-    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
-    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
-        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
-        base_model_name_or_path = config.base_model_name_or_path
-
-    if base_model_name_or_path is None:
-        raise ValueError("Please specify the base model name or path")
-
-    pipe = StableDiffusionControlNetPipeline.from_pretrained(
-        base_model_name_or_path,
-        torch_dtype=dtype,
-        controlnet=controlnet,
-    )
-
-    before_params = pipe.unet.parameters()
-    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-    pipe.unet.set_adapter(adapter_name)
-    after_params = pipe.unet.parameters()
-    print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
-
-    if os.path.exists(text_encoder_sub_dir):
-        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
-
-    if dtype in (torch.float16, torch.bfloat16):
-        pipe.unet.half()
-        pipe.text_encoder.half()
-
-    return pipe
-
-
-def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
-    tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
-    chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
-
-    with torch.no_grad():
-        embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
-
-    return torch.cat(embeds, dim=1)
-
-
-def align_embeddings(prompt_embeds, negative_prompt_embeds):
-    max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
-    return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
-           torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
-
-
-def map_edge_detection(image_path: str) -> Image:
-    source_img = load_image(image_path).convert('RGB')
-    edges = cv.Canny(np.array(source_img), 80, 160)
-    edges = np.repeat(edges[:, :, None], 3, axis=2)
-    final_image = Image.fromarray(edges)
-    return final_image
-
-
-def map_scribble(image_path: str) -> Image:
-    global hed
-    if not hed:
-        hed = HEDdetector.from_pretrained('lllyasviel/Annotators')
-
-    image = load_image(image_path).convert('RGB')
-    scribble_image = hed(image)
-    image_np = np.array(scribble_image)
-    image_np = cv.medianBlur(image_np, 3)
-    image = cv.convertScaleAbs(image_np, alpha=1.5, beta=0)
-    final_image = Image.fromarray(image)
-    return final_image
-
-
-pipe = get_lora_sd_pipeline(
-    ckpt_dir='./lora_logos',
-    base_model_name_or_path=model_id_default,
-    dtype=torch_dtype,
-    controlnet=controlnet
-).to(device)
-
-
-def infer(
-    prompt,
-    negative_prompt,
-    width=512,
-    height=512,
-    num_inference_steps=20,
-    model_id='CompVis/stable-diffusion-v1-4',
-    seed=42,
-    guidance_scale=7.0,
-    lora_scale=0.5,
-    cn_enable=False,
-    cn_strength=0.0,
-    cn_mode='edge_detection',
-    cn_image=None,
-    ip_enable=False,
-    ip_scale=0.5,
-    ip_image=None,
-    progress=gr.Progress(track_tqdm=True)
-):
-    generator = torch.Generator(device).manual_seed(seed)
-
-    global pipe
-    global controlnet
-
-    controlnet_changed = False
-
-    if cn_enable:
-        if dict_controlnet[cn_mode] != pipe.controlnet._name_or_path:
-            controlnet = ControlNetModel.from_pretrained(
-                dict_controlnet[cn_mode],
-                cache_dir="./models_cache",
-                torch_dtype=torch_dtype
-            )
-            controlnet_changed = True
-    else:
-        cn_strength = 0.0  # forcibly disable ControlNet
-
-    if model_id != pipe._name_or_path:
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            controlnet=controlnet,
-            controlnet_conditioning_scale=cn_strength,
-        ).to(device)
-    elif (model_id == pipe._name_or_path) and controlnet_changed:
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            controlnet=controlnet,
-            controlnet_conditioning_scale=cn_strength,
-        ).to(device)
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
-        pipe.fuse_lora(lora_scale=lora_scale)
-    elif (model_id == pipe._name_or_path) and not controlnet_changed:
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
-        pipe.fuse_lora(lora_scale=lora_scale)
-
-    prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
-    negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
-    prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
-
-    params = {
-        'prompt_embeds': prompt_embeds,
-        'negative_prompt_embeds': negative_prompt_embeds,
-        'guidance_scale': guidance_scale,
-        'num_inference_steps': num_inference_steps,
-        'width': width,
-        'height': height,
-        'generator': generator,
-    }
-
-    if cn_enable:
-        params['controlnet_conditioning_scale'] = cn_strength
-        if cn_mode == 'edge_detection':
-            control_image = map_edge_detection(cn_image)
-        elif cn_mode == 'scribble':
-            control_image = map_scribble(cn_image)
-        params['image'] = control_image
-
-    if ip_enable:
-        pipe.load_ip_adapter(
-            IP_ADAPTER,
-            subfolder="models",
-            weight_name=IP_ADAPTER_WEIGHT_NAME,
-        )
-        params['ip_adapter_image'] = load_image(ip_image).convert('RGB')
-        pipe.set_ip_adapter_scale(ip_scale)
-
-    return pipe(**params).images[0]
+class PipelineManager:
+    def __init__(self):
+        self.pipe = None
+        self.current_model = None
+        self.controlnet_cache = {}
+        self.hed = None
+
+    @lru_cache(maxsize=2)
+    def get_controlnet(self, model_name: str) -> ControlNetModel:
+        # Load each ControlNet once and keep it on the target device.
+        if model_name not in self.controlnet_cache:
+            self.controlnet_cache[model_name] = ControlNetModel.from_pretrained(
+                model_name,
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            ).to(device)
+        return self.controlnet_cache[model_name]
+
+    def get_hed_detector(self):
+        # Lazily instantiate the HED detector used for scribble preprocessing.
+        if self.hed is None:
+            self.hed = HEDdetector.from_pretrained('lllyasviel/Annotators')
+        return self.hed
+
+    def initialize_pipeline(self, model_id, controlnet_model):
+        # Rebuild the pipeline only when the base model changes.
+        controlnet = self.get_controlnet(controlnet_model)
+        if not self.pipe or model_id != self.current_model:
+            self.pipe = self.create_pipeline(model_id, controlnet)
+            self.current_model = model_id
+        return self.pipe
+
+    def create_pipeline(self, model_id, controlnet):
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            controlnet=controlnet,
+            cache_dir="./models_cache"
+        ).to(device)
+
+        if os.path.exists('./lora_logos'):
+            pipe = self.load_lora_adapters(pipe)
+
+        return pipe
+
+    def load_lora_adapters(self, pipe):
+        unet_dir = os.path.join('./lora_logos', "unet")
+        text_encoder_dir = os.path.join('./lora_logos', "text_encoder")
+
+        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_dir, adapter_name="default")
+        if os.path.exists(text_encoder_dir):
+            pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_dir)
+
+        return pipe.to(device)
+
+
+@contextmanager
+def torch_inference_mode():
+    # Inference-only mode with autocast on the active device.
+    with torch.inference_mode(), torch.autocast(device.type):
+        yield
+
+
+def process_embeddings(prompt, negative_prompt, tokenizer, text_encoder):
+    def process_text(text):
+        # Encode in 77-token chunks to sidestep CLIP's context limit.
+        tokens = tokenizer(text, return_tensors="pt", truncation=False).input_ids
+        chunks = [tokens[:, i:i+77].to(device) for i in range(0, tokens.size(1), 77)]
+        return torch.cat([text_encoder(chunk)[0] for chunk in chunks], dim=1)
+
+    prompt_emb = process_text(prompt)
+    negative_emb = process_text(negative_prompt)
+    max_len = max(prompt_emb.size(1), negative_emb.size(1))
+
+    # Pad both embeddings to the same sequence length.
+    return (
+        torch.nn.functional.pad(prompt_emb, (0, 0, 0, max_len - prompt_emb.size(1))),
+        torch.nn.functional.pad(negative_emb, (0, 0, 0, max_len - negative_emb.size(1)))
+    )
+
+
+def process_control_image(image_path: str, processor: str, hed_detector) -> Image:
+    image = load_image(image_path).convert('RGB')
+
+    if processor == 'edge_detection':
+        edges = cv.Canny(np.array(image), 80, 160)
+        return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))
+
+    if processor == 'scribble':
+        scribble = hed_detector(image)
+        processed = cv.medianBlur(np.array(scribble), 3)
+        return Image.fromarray(cv.convertScaleAbs(processed, alpha=1.5))
+
+
+pipeline_mgr = PipelineManager()
+controlnet_models = {
+    "edge_detection": "lllyasviel/sd-controlnet-canny",
+    "scribble": "lllyasviel/sd-controlnet-scribble"
+}
+
+
+def infer(**kwargs):
+    generator = torch.Generator(device).manual_seed(kwargs['seed'])
+
+    with torch_inference_mode():
+        pipe = pipeline_mgr.initialize_pipeline(
+            kwargs['model_id'],
+            controlnet_models.get(kwargs['cn_mode'], controlnet_models['edge_detection'])
+        )
+
+        if kwargs['cn_enable'] and not kwargs['cn_image']:
+            raise gr.Error("ControlNet enabled but no image provided!")
+
+        prompt_emb, negative_emb = process_embeddings(
+            kwargs['prompt'],
+            kwargs['negative_prompt'],
+            pipe.tokenizer,
+            pipe.text_encoder
+        )
+
+        params = {
+            'prompt_embeds': prompt_emb,
+            'negative_prompt_embeds': negative_emb,
+            'guidance_scale': kwargs['guidance_scale'],
+            'num_inference_steps': kwargs['num_inference_steps'],
+            'width': kwargs['width'],
+            'height': kwargs['height'],
+            'generator': generator
+        }
+
+        if kwargs['cn_enable']:
+            params['image'] = process_control_image(
+                kwargs['cn_image'],
+                kwargs['cn_mode'],
+                pipeline_mgr.get_hed_detector()
+            )
+            params['controlnet_conditioning_scale'] = kwargs['cn_strength']
+
+        if kwargs.get('ip_enable', False):
+            pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
+            params['ip_adapter_image'] = load_image(kwargs['ip_image']).convert('RGB')
+            pipe.set_ip_adapter_scale(kwargs.get('ip_scale', 0.5))
+
+        pipe.fuse_lora(lora_scale=kwargs.get('lora_scale', 0.5))
+
+        return pipe(**params).images[0]
 
 css = """
 #col-container {
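Note on wiring: this commit changes `infer` from an explicit parameter list to a bare `**kwargs` signature, while Gradio event handlers (`click`, `submit`) pass component values positionally. A small adapter is presumably needed where the event is bound. A minimal sketch, assuming the argument order of the old signature; `ARG_NAMES` and `infer_adapter` are illustrative and not part of this commit:

```python
# Hypothetical adapter: rebinds Gradio's positional component values
# to the keyword names that infer(**kwargs) reads.
ARG_NAMES = [
    "prompt", "negative_prompt", "width", "height", "num_inference_steps",
    "model_id", "seed", "guidance_scale", "lora_scale",
    "cn_enable", "cn_strength", "cn_mode", "cn_image",
    "ip_enable", "ip_scale", "ip_image",
]

def infer_adapter(*values):
    return infer(**dict(zip(ARG_NAMES, values)))

# e.g. run_button.click(fn=infer_adapter, inputs=[...], outputs=[result])
```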
 
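The 77-token chunking in `process_embeddings` carries over the idea of the old `process_prompt`: token IDs are sliced into windows of at most 77 (CLIP's context limit) and the per-chunk embeddings are concatenated, so long prompts are encoded rather than truncated. A toy illustration of just the slicing arithmetic:

```python
import torch

# Stand-in for tokenizer output: a batch of 100 token IDs, shape (1, 100).
tokens = torch.arange(100).unsqueeze(0)
chunks = [tokens[:, i:i + 77] for i in range(0, tokens.size(1), 77)]

# Two windows: 77 tokens, then the 23-token remainder.
assert [chunk.size(1) for chunk in chunks] == [77, 23]
```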
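A quick smoke test of the refactored path, calling `infer` with keyword arguments directly, might look like the sketch below; the prompts and the control-image path are placeholders:

```python
# Hypothetical smoke test; "sketch.png" is a placeholder file path.
result = infer(
    prompt="a minimalist fox logo, flat vector style",
    negative_prompt="blurry, low quality",
    width=512, height=512, num_inference_steps=20,
    model_id="CompVis/stable-diffusion-v1-4",
    seed=42, guidance_scale=7.0, lora_scale=0.5,
    cn_enable=True, cn_strength=0.8, cn_mode="scribble", cn_image="sketch.png",
    ip_enable=False, ip_scale=0.5, ip_image=None,
)
result.save("logo_out.png")
```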