Spaces:
Runtime error
Runtime error
fix memory bug
Browse files- inference/pipeline.py +10 -5
inference/pipeline.py
CHANGED
@@ -83,11 +83,16 @@ class RealCustomInferencePipeline:
|
|
83 |
vision_model_config = unet_config.pop("vision_model_config", None)
|
84 |
self.vision_model_config = vision_model_config.pop("vision_model_config", None)
|
85 |
|
86 |
-
self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
87 |
-
|
88 |
-
self.unet_model.eval().to(self.device).to(self.torch_dtype)
|
89 |
-
self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
|
90 |
-
self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
|
|
|
|
|
|
|
|
|
|
|
91 |
print("loading unet model finished.")
|
92 |
|
93 |
def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
|
|
|
83 |
vision_model_config = unet_config.pop("vision_model_config", None)
|
84 |
self.vision_model_config = vision_model_config.pop("vision_model_config", None)
|
85 |
|
86 |
+
# self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
87 |
+
|
88 |
+
# self.unet_model.eval().to(self.device).to(self.torch_dtype)
|
89 |
+
# self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
|
90 |
+
# self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
|
91 |
+
with torch.device("meta"):
|
92 |
+
self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
93 |
+
self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False, assign=True)
|
94 |
+
self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False, assign=True)
|
95 |
+
self.unet_model.eval()
|
96 |
print("loading unet model finished.")
|
97 |
|
98 |
def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
|