Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
b812142
1
Parent(s):
36de41f
force the device
Browse files
app.py
CHANGED
@@ -119,6 +119,9 @@ class ImageGenerator:
|
|
119 |
max_length=max_length,
|
120 |
dtype=dtype,
|
121 |
)
|
|
|
|
|
|
|
122 |
|
123 |
def prepare(self, prompt, img, ref_image, ref_image_raw):
|
124 |
bs, _, h, w = img.shape
|
@@ -314,6 +317,7 @@ class ImageGenerator:
|
|
314 |
|
315 |
ref_images_raw = self.load_image(ref_images_raw)
|
316 |
ref_images_raw = ref_images_raw.to(self.device)
|
|
|
317 |
ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
|
318 |
|
319 |
seed = int(seed)
|
@@ -398,7 +402,7 @@ def prepare_infer_func():
|
|
398 |
|
399 |
return image_edit.generate_image
|
400 |
|
401 |
-
@spaces.GPU
|
402 |
def inference(prompt, ref_images, seed, size_level, infer_func=None):
|
403 |
start_time = time.time()
|
404 |
|
|
|
119 |
max_length=max_length,
|
120 |
dtype=dtype,
|
121 |
)
|
122 |
+
self.ae = self.ae.to(device=self.device, dtype=torch.float32)
|
123 |
+
self.dit = self.dit.to(device=self.device, dtype=dtype)
|
124 |
+
self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
|
125 |
|
126 |
def prepare(self, prompt, img, ref_image, ref_image_raw):
|
127 |
bs, _, h, w = img.shape
|
|
|
317 |
|
318 |
ref_images_raw = self.load_image(ref_images_raw)
|
319 |
ref_images_raw = ref_images_raw.to(self.device)
|
320 |
+
print(f'self.ae, self.dit device: {self.ae.device}, {self.dit.device}')
|
321 |
ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
|
322 |
|
323 |
seed = int(seed)
|
|
|
402 |
|
403 |
return image_edit.generate_image
|
404 |
|
405 |
+
@spaces.GPU(duration=240)
|
406 |
def inference(prompt, ref_images, seed, size_level, infer_func=None):
|
407 |
start_time = time.time()
|
408 |
|