jsu27 commited on
Commit
72b0442
·
1 Parent(s): 1760b4e
Files changed (1) hide show
  1. app.py +38 -56
app.py CHANGED
@@ -152,58 +152,37 @@ def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='d
152
 
153
  def decompose_image_demo(im, model):
154
  sample_method = 'ddim'
155
- result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1)
156
  return result.permute(1, 2, 0).numpy()
157
 
158
 
159
  def combine_images_demo(im1, im2, model):
160
  sample_method = 'ddim'
161
- result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1)
162
  return result.permute(1, 2, 0).numpy()
163
 
164
 
 
 
165
 
 
 
 
 
 
 
166
 
167
- ckpt_path = download_model('clevr') # 'clevr_model.pt'
 
168
 
169
- model_kwargs = unet_model_defaults()
170
- # model parameters
171
- model_kwargs.update(dict(
172
- emb_dim=64,
173
- enc_channels=128
174
- ))
175
- clevr_model = create_diffusion_model(**model_kwargs)
176
- clevr_model.eval()
177
 
178
- device = 'cuda' if th.cuda.is_available() else 'cpu'
179
- clevr_model.to(device)
180
-
181
- print(f'loading from {ckpt_path}')
182
- checkpoint = th.load(ckpt_path, map_location='cpu')
183
-
184
- clevr_model.load_state_dict(checkpoint)
185
-
186
-
187
-
188
- ckpt_path = download_model('celebahq') # 'celeb_model.pt'
189
-
190
- model_kwargs = unet_model_defaults()
191
- # model parameters
192
- model_kwargs.update(dict(
193
- enc_channels=128
194
- ))
195
- celeb_model = create_diffusion_model(**model_kwargs)
196
- celeb_model.eval()
197
 
198
  device = 'cuda' if th.cuda.is_available() else 'cpu'
199
- celeb_model.to(device)
200
-
201
- print(f'loading from {ckpt_path}')
202
- checkpoint = th.load(ckpt_path, map_location='cpu')
203
-
204
- celeb_model.load_state_dict(checkpoint)
205
-
206
 
 
 
207
 
208
  MODELS = {
209
  'CLEVR': clevr_model,
@@ -222,7 +201,7 @@ with gr.Blocks() as demo:
222
 
223
  gr.Markdown(
224
  """<h4>Decomposition and reconstruction of images</h4>""")
225
- with gr.Row().style(equal_height=True):
226
  with gr.Column():
227
  with gr.Row():
228
  decomp_input = gr.Image(type='numpy', label='Input')
@@ -230,19 +209,21 @@ with gr.Blocks() as demo:
230
  decomp_model = gr.Radio(
231
  ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
232
  value='CLEVR')
 
 
 
 
 
 
 
 
 
 
233
 
234
  with gr.Column():
235
  decomp_output = gr.Image(type='numpy')
236
  decomp_button = gr.Button("Generate")
237
- with gr.Row():
238
-
239
- # image_examples = [os.path.join(os.path.dirname(__file__), 'sample_images/clevr_im_10.png'), 'CLEVR']
240
- decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
241
- ['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
242
- decomp_img_examples = gr.Examples(
243
- examples=decomp_examples,
244
- inputs=[decomp_input, decomp_model]
245
- )
246
 
247
 
248
  gr.Markdown(
@@ -260,20 +241,21 @@ with gr.Blocks() as demo:
260
  comb_model = gr.Radio(
261
  ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
262
  value='CLEVR')
263
-
 
 
 
 
 
 
 
 
264
 
265
 
266
  with gr.Column(scale=1):
267
  comb_output = gr.Image(type='numpy')
268
  comb_button = gr.Button("Generate")
269
- with gr.Row():
270
-
271
- comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
272
- ['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
273
- comb_img_examples = gr.Examples(
274
- examples=comb_examples,
275
- inputs=[comb_input1, comb_input2, comb_model]
276
- )
277
 
278
  decomp_button.click(decompose_image_demo,
279
  inputs=[decomp_input, decomp_model],
 
152
 
153
  def decompose_image_demo(im, model):
154
  sample_method = 'ddim'
155
+ result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1, device=device)
156
  return result.permute(1, 2, 0).numpy()
157
 
158
 
159
  def combine_images_demo(im1, im2, model):
160
  sample_method = 'ddim'
161
+ result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1, device=device)
162
  return result.permute(1, 2, 0).numpy()
163
 
164
 
165
+ def load_model(dataset, extra_kwargs={}, device='cuda'):
166
+ ckpt_path = download_model(dataset)
167
 
168
+ model_kwargs = unet_model_defaults()
169
+ # model parameters
170
+ model_kwargs.update(extra_kwargs)
171
+ model = create_diffusion_model(**model_kwargs)
172
+ model.eval()
173
+ model.to(device)
174
 
175
+ print(f'loading from {ckpt_path}')
176
+ checkpoint = th.load(ckpt_path, map_location='cpu')
177
 
178
+ model.load_state_dict(checkpoint)
179
+ return model
 
 
 
 
 
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  device = 'cuda' if th.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
183
 
184
+ clevr_model = load_model('clevr', extra_kwargs=dict(embed_dim=64, enc_channels=128), device=device)
185
+ celeb_model = load_model('celebahq', extra_kwargs=dict(enc_channels=128), device=device)
186
 
187
  MODELS = {
188
  'CLEVR': clevr_model,
 
201
 
202
  gr.Markdown(
203
  """<h4>Decomposition and reconstruction of images</h4>""")
204
+ with gr.Row():
205
  with gr.Column():
206
  with gr.Row():
207
  decomp_input = gr.Image(type='numpy', label='Input')
 
209
  decomp_model = gr.Radio(
210
  ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
211
  value='CLEVR')
212
+
213
+ with gr.Row():
214
+
215
+ # image_examples = [os.path.join(os.path.dirname(__file__), 'sample_images/clevr_im_10.png'), 'CLEVR']
216
+ decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
217
+ ['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
218
+ decomp_img_examples = gr.Examples(
219
+ examples=decomp_examples,
220
+ inputs=[decomp_input, decomp_model]
221
+ )
222
 
223
  with gr.Column():
224
  decomp_output = gr.Image(type='numpy')
225
  decomp_button = gr.Button("Generate")
226
+
 
 
 
 
 
 
 
 
227
 
228
 
229
  gr.Markdown(
 
241
  comb_model = gr.Radio(
242
  ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
243
  value='CLEVR')
244
+
245
+ with gr.Row():
246
+
247
+ comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
248
+ ['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
249
+ comb_img_examples = gr.Examples(
250
+ examples=comb_examples,
251
+ inputs=[comb_input1, comb_input2, comb_model]
252
+ )
253
 
254
 
255
  with gr.Column(scale=1):
256
  comb_output = gr.Image(type='numpy')
257
  comb_button = gr.Button("Generate")
258
+
 
 
 
 
 
 
 
259
 
260
  decomp_button.click(decompose_image_demo,
261
  inputs=[decomp_input, decomp_model],