Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -110,21 +110,43 @@ controlnet_models = {
|
|
110 |
"scribble": "lllyasviel/sd-controlnet-scribble"
|
111 |
}
|
112 |
|
113 |
-
def infer(
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
with torch_inference_mode():
|
117 |
pipe = pipeline_mgr.initialize_pipeline(
|
118 |
-
|
119 |
-
controlnet_models.get(
|
120 |
)
|
121 |
|
122 |
-
if
|
123 |
raise gr.Error("ControlNet enabled but no image provided!")
|
|
|
|
|
|
|
124 |
|
125 |
prompt_emb, negative_emb = process_embeddings(
|
126 |
-
|
127 |
-
|
128 |
pipe.tokenizer,
|
129 |
pipe.text_encoder
|
130 |
)
|
@@ -132,27 +154,27 @@ def infer(**kwargs):
|
|
132 |
params = {
|
133 |
'prompt_embeds': prompt_emb,
|
134 |
'negative_prompt_embeds': negative_emb,
|
135 |
-
'guidance_scale':
|
136 |
-
'num_inference_steps':
|
137 |
-
'width':
|
138 |
-
'height':
|
139 |
'generator': generator
|
140 |
}
|
141 |
|
142 |
-
if
|
143 |
params['image'] = process_control_image(
|
144 |
-
|
145 |
-
|
146 |
pipeline_mgr.get_hed_detector()
|
147 |
)
|
148 |
-
params['controlnet_conditioning_scale'] =
|
149 |
|
150 |
-
if
|
151 |
pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
|
152 |
-
params['ip_adapter_image'] = load_image(
|
153 |
-
pipe.set_ip_adapter_scale(
|
154 |
|
155 |
-
pipe.fuse_lora(lora_scale=
|
156 |
|
157 |
return pipe(**params).images[0]
|
158 |
|
|
|
110 |
"scribble": "lllyasviel/sd-controlnet-scribble"
|
111 |
}
|
112 |
|
113 |
+
def infer(
|
114 |
+
prompt,
|
115 |
+
negative_prompt,
|
116 |
+
width=512,
|
117 |
+
height=512,
|
118 |
+
num_inference_steps=20,
|
119 |
+
model_id='CompVis/stable-diffusion-v1-4',
|
120 |
+
seed=42,
|
121 |
+
guidance_scale=7.0,
|
122 |
+
lora_scale=0.5,
|
123 |
+
cn_enable=False,
|
124 |
+
cn_strength=0.0,
|
125 |
+
cn_mode='edge_detection',
|
126 |
+
cn_image=None,
|
127 |
+
ip_enable=False,
|
128 |
+
ip_scale=0.5,
|
129 |
+
ip_image=None,
|
130 |
+
progress=gr.Progress(track_tqdm=True)
|
131 |
+
):
|
132 |
+
|
133 |
+
generator = torch.Generator(device).manual_seed(seed)
|
134 |
|
135 |
with torch_inference_mode():
|
136 |
pipe = pipeline_mgr.initialize_pipeline(
|
137 |
+
model_id,
|
138 |
+
controlnet_models.get(cn_mode, controlnet_models['edge_detection'])
|
139 |
)
|
140 |
|
141 |
+
if cn_enable and not cn_image:
|
142 |
raise gr.Error("ControlNet enabled but no image provided!")
|
143 |
+
|
144 |
+
if ip_enable and not ip_image:
|
145 |
+
raise gr.Error("IP-Adapter enabled but no image provided!")
|
146 |
|
147 |
prompt_emb, negative_emb = process_embeddings(
|
148 |
+
prompt,
|
149 |
+
negative_prompt,
|
150 |
pipe.tokenizer,
|
151 |
pipe.text_encoder
|
152 |
)
|
|
|
154 |
params = {
|
155 |
'prompt_embeds': prompt_emb,
|
156 |
'negative_prompt_embeds': negative_emb,
|
157 |
+
'guidance_scale': guidance_scale,
|
158 |
+
'num_inference_steps': num_inference_steps,
|
159 |
+
'width': width,
|
160 |
+
'height': height,
|
161 |
'generator': generator
|
162 |
}
|
163 |
|
164 |
+
if cn_enable:
|
165 |
params['image'] = process_control_image(
|
166 |
+
cn_image,
|
167 |
+
cn_mode,
|
168 |
pipeline_mgr.get_hed_detector()
|
169 |
)
|
170 |
+
params['controlnet_conditioning_scale'] = cn_strength
|
171 |
|
172 |
+
if ip_enable:
|
173 |
pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
|
174 |
+
params['ip_adapter_image'] = load_image(ip_image).convert('RGB')
|
175 |
+
pipe.set_ip_adapter_scale(ip_scale)
|
176 |
|
177 |
+
pipe.fuse_lora(lora_scale=lora_scale)
|
178 |
|
179 |
return pipe(**params).images[0]
|
180 |
|