CoreloneH commited on
Commit
308e4c0
·
1 Parent(s): 3e5e167

fix memory bug

Browse files
Files changed (1) hide show
  1. inference/pipeline.py +10 -5
inference/pipeline.py CHANGED
@@ -83,11 +83,16 @@ class RealCustomInferencePipeline:
83
  vision_model_config = unet_config.pop("vision_model_config", None)
84
  self.vision_model_config = vision_model_config.pop("vision_model_config", None)
85
 
86
- self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
87
-
88
- self.unet_model.eval().to(self.device).to(self.torch_dtype)
89
- self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
90
- self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
 
 
 
 
 
91
  print("loading unet model finished.")
92
 
93
  def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
 
83
  vision_model_config = unet_config.pop("vision_model_config", None)
84
  self.vision_model_config = vision_model_config.pop("vision_model_config", None)
85
 
86
+ # self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
87
+
88
+ # self.unet_model.eval().to(self.device).to(self.torch_dtype)
89
+ # self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
90
+ # self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
91
+ with torch.device("meta"):
92
+ self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
93
+ self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False, assign=True)
94
+ self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False, assign=True)
95
+ self.unet_model.eval()
96
  print("loading unet model finished.")
97
 
98
  def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):