dezzman commited on
Commit
b8002e7
·
verified ·
1 Parent(s): 7bcd902

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -19
app.py CHANGED
@@ -110,21 +110,43 @@ controlnet_models = {
110
  "scribble": "lllyasviel/sd-controlnet-scribble"
111
  }
112
 
113
- def infer(**kwargs):
114
- generator = torch.Generator(device).manual_seed(kwargs['seed'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  with torch_inference_mode():
117
  pipe = pipeline_mgr.initialize_pipeline(
118
- kwargs['model_id'],
119
- controlnet_models.get(kwargs['cn_mode'], controlnet_models['edge_detection'])
120
  )
121
 
122
- if kwargs['cn_enable'] and not kwargs['cn_image']:
123
  raise gr.Error("ControlNet enabled but no image provided!")
 
 
 
124
 
125
  prompt_emb, negative_emb = process_embeddings(
126
- kwargs['prompt'],
127
- kwargs['negative_prompt'],
128
  pipe.tokenizer,
129
  pipe.text_encoder
130
  )
@@ -132,27 +154,27 @@ def infer(**kwargs):
132
  params = {
133
  'prompt_embeds': prompt_emb,
134
  'negative_prompt_embeds': negative_emb,
135
- 'guidance_scale': kwargs['guidance_scale'],
136
- 'num_inference_steps': kwargs['num_inference_steps'],
137
- 'width': kwargs['width'],
138
- 'height': kwargs['height'],
139
  'generator': generator
140
  }
141
 
142
- if kwargs['cn_enable']:
143
  params['image'] = process_control_image(
144
- kwargs['cn_image'],
145
- kwargs['cn_mode'],
146
  pipeline_mgr.get_hed_detector()
147
  )
148
- params['controlnet_conditioning_scale'] = kwargs['cn_strength']
149
 
150
- if kwargs.get('ip_enable', False):
151
  pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
152
- params['ip_adapter_image'] = load_image(kwargs['ip_image']).convert('RGB')
153
- pipe.set_ip_adapter_scale(kwargs.get('ip_scale', 0.5))
154
 
155
- pipe.fuse_lora(lora_scale=kwargs.get('lora_scale', 0.5))
156
 
157
  return pipe(**params).images[0]
158
 
 
110
  "scribble": "lllyasviel/sd-controlnet-scribble"
111
  }
112
 
113
+ def infer(
114
+ prompt,
115
+ negative_prompt,
116
+ width=512,
117
+ height=512,
118
+ num_inference_steps=20,
119
+ model_id='CompVis/stable-diffusion-v1-4',
120
+ seed=42,
121
+ guidance_scale=7.0,
122
+ lora_scale=0.5,
123
+ cn_enable=False,
124
+ cn_strength=0.0,
125
+ cn_mode='edge_detection',
126
+ cn_image=None,
127
+ ip_enable=False,
128
+ ip_scale=0.5,
129
+ ip_image=None,
130
+ progress=gr.Progress(track_tqdm=True)
131
+ ):
132
+
133
+ generator = torch.Generator(device).manual_seed(seed)
134
 
135
  with torch_inference_mode():
136
  pipe = pipeline_mgr.initialize_pipeline(
137
+ model_id,
138
+ controlnet_models.get(cn_mode, controlnet_models['edge_detection'])
139
  )
140
 
141
+ if cn_enable and not cn_image:
142
  raise gr.Error("ControlNet enabled but no image provided!")
143
+
144
+ if ip_enable and not ip_image:
145
+ raise gr.Error("IP-Adapter enabled but no image provided!")
146
 
147
  prompt_emb, negative_emb = process_embeddings(
148
+ prompt,
149
+ negative_prompt,
150
  pipe.tokenizer,
151
  pipe.text_encoder
152
  )
 
154
  params = {
155
  'prompt_embeds': prompt_emb,
156
  'negative_prompt_embeds': negative_emb,
157
+ 'guidance_scale': guidance_scale,
158
+ 'num_inference_steps': num_inference_steps,
159
+ 'width': width,
160
+ 'height': height,
161
  'generator': generator
162
  }
163
 
164
+ if cn_enable:
165
  params['image'] = process_control_image(
166
+ cn_image,
167
+ cn_mode,
168
  pipeline_mgr.get_hed_detector()
169
  )
170
+ params['controlnet_conditioning_scale'] = cn_strength
171
 
172
+ if ip_enable:
173
  pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
174
+ params['ip_adapter_image'] = load_image(ip_image).convert('RGB')
175
+ pipe.set_ip_adapter_scale(ip_scale)
176
 
177
+ pipe.fuse_lora(lora_scale=lora_scale)
178
 
179
  return pipe(**params).images[0]
180