HReynaud committed on
Commit
153ce0d
·
1 Parent(s): 308c0d9
Files changed (1) hide show
  1. demo.py +6 -8
demo.py CHANGED
@@ -306,13 +306,12 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
306
 
307
 
308
  @spaces.GPU
309
- def decode_images(latents, vae):
310
  """Decode latent representations to pixel space using a VAE.
311
 
312
  Args:
313
  latents: A numpy array of shape [B, C, H, W] for single image
314
  or [B, C, T, H, W] for sequences/animations
315
- vae: The VAE model for decoding
316
 
317
  Returns:
318
  numpy array of decoded images in [B, H, W, 3] format for single image
@@ -321,6 +320,9 @@ def decode_images(latents, vae):
321
  if latents is None:
322
  return None
323
 
 
 
 
324
  # Convert to torch tensor if needed
325
  if not isinstance(latents, torch.Tensor):
326
  latents = torch.from_numpy(latents).to(device, dtype=dtype)
@@ -365,7 +367,6 @@ def decode_images(latents, vae):
365
 
366
  def decode_latent_to_pixel(latent_image):
367
  """Decode a single latent image to pixel space"""
368
- global vae
369
  if latent_image is None:
370
  return None
371
 
@@ -373,7 +374,7 @@ def decode_latent_to_pixel(latent_image):
373
  if len(latent_image.shape) == 3:
374
  latent_image = latent_image[None, ...]
375
 
376
- decoded_image = decode_images(latent_image, vae)
377
  decoded_image = cv2.resize(
378
  decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
379
  )
@@ -493,7 +494,6 @@ def generate_animation(
493
 
494
  def decode_animation(latent_animation):
495
  """Decode a latent animation to pixel space"""
496
- global vae
497
  if latent_animation is None:
498
  return None
499
 
@@ -506,9 +506,7 @@ def decode_animation(latent_animation):
506
  latent_animation = latent_animation[None, ...] # Add batch dimension
507
 
508
  # Decode using VAE
509
- decoded = decode_images(
510
- latent_animation, vae
511
- ) # Returns B x C x T x H x W numpy array
512
 
513
  # Remove batch dimension and transpose to T x H x W x C
514
  decoded = np.transpose(decoded[0], (1, 2, 3, 0)) # [T, H, W, C]
 
306
 
307
 
308
  @spaces.GPU
309
+ def decode_images(latents):
310
  """Decode latent representations to pixel space using a VAE.
311
 
312
  Args:
313
  latents: A numpy array of shape [B, C, H, W] for single image
314
  or [B, C, T, H, W] for sequences/animations
 
315
 
316
  Returns:
317
  numpy array of decoded images in [B, H, W, 3] format for single image
 
320
  if latents is None:
321
  return None
322
 
323
+ vae = vae.to(device, dtype=dtype)
324
+ vae.eval()
325
+
326
  # Convert to torch tensor if needed
327
  if not isinstance(latents, torch.Tensor):
328
  latents = torch.from_numpy(latents).to(device, dtype=dtype)
 
367
 
368
  def decode_latent_to_pixel(latent_image):
369
  """Decode a single latent image to pixel space"""
 
370
  if latent_image is None:
371
  return None
372
 
 
374
  if len(latent_image.shape) == 3:
375
  latent_image = latent_image[None, ...]
376
 
377
+ decoded_image = decode_images(latent_image)
378
  decoded_image = cv2.resize(
379
  decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
380
  )
 
494
 
495
  def decode_animation(latent_animation):
496
  """Decode a latent animation to pixel space"""
 
497
  if latent_animation is None:
498
  return None
499
 
 
506
  latent_animation = latent_animation[None, ...] # Add batch dimension
507
 
508
  # Decode using VAE
509
+ decoded = decode_images(latent_animation) # Returns B x C x T x H x W numpy array
 
 
510
 
511
  # Remove batch dimension and transpose to T x H x W x C
512
  decoded = np.transpose(decoded[0], (1, 2, 3, 0)) # [T, H, W, C]