listen2you003 commited on
Commit
b812142
·
1 Parent(s): 36de41f

force the device

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -119,6 +119,9 @@ class ImageGenerator:
119
  max_length=max_length,
120
  dtype=dtype,
121
  )
 
 
 
122
 
123
  def prepare(self, prompt, img, ref_image, ref_image_raw):
124
  bs, _, h, w = img.shape
@@ -314,6 +317,7 @@ class ImageGenerator:
314
 
315
  ref_images_raw = self.load_image(ref_images_raw)
316
  ref_images_raw = ref_images_raw.to(self.device)
 
317
  ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
318
 
319
  seed = int(seed)
@@ -398,7 +402,7 @@ def prepare_infer_func():
398
 
399
  return image_edit.generate_image
400
 
401
- @spaces.GPU
402
  def inference(prompt, ref_images, seed, size_level, infer_func=None):
403
  start_time = time.time()
404
 
 
119
  max_length=max_length,
120
  dtype=dtype,
121
  )
122
+ self.ae = self.ae.to(device=self.device, dtype=torch.float32)
123
+ self.dit = self.dit.to(device=self.device, dtype=dtype)
124
+ self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
125
 
126
  def prepare(self, prompt, img, ref_image, ref_image_raw):
127
  bs, _, h, w = img.shape
 
317
 
318
  ref_images_raw = self.load_image(ref_images_raw)
319
  ref_images_raw = ref_images_raw.to(self.device)
320
+ print(f'self.ae, self.dit device: {self.ae.device}, {self.dit.device}')
321
  ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
322
 
323
  seed = int(seed)
 
402
 
403
  return image_edit.generate_image
404
 
405
+ @spaces.GPU(duration=240)
406
  def inference(prompt, ref_images, seed, size_level, infer_func=None):
407
  start_time = time.time()
408