import os
import sys

import einops
import numpy as np
import torch
from accelerate.test_utils.testing import get_backend
from PIL import Image
from torchvision import transforms

from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR

# The DepthFM sources live under DEPTH_FM_DIR/depthfm and are imported as
# top-level modules, so that directory must be on sys.path before the
# imports below.
sys.path.append(os.path.join(DEPTH_FM_DIR, "depthfm"))

from dfm import DepthFM  # noqa: E402
from unet import UNetModel  # noqa: E402,F401  (kept: may be relied on elsewhere)


class DepthEstimator:
    """Run DepthFM depth estimation over every image in a directory.

    The model is constructed lazily on first use, moved to the backend device
    for inference, and moved back to the CPU when ``estimate_depth`` finishes
    (even if it fails) so accelerator memory is released between runs.

    Example::

        estimator = DepthEstimator()
        predictions = estimator.estimate_depth(estimator.image_dir)
        estimator.visualize(predictions)
    """

    # File extensions (compared case-insensitively) treated as input images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir=LOGS_DIR):
        """Create an estimator without loading the model yet.

        Args:
            image_dir: Directory of input images, stored as ``self.image_dir``
                for callers; ``estimate_depth`` takes its directory explicitly.
        """
        # get_backend() returns a 3-tuple; only its first element (the device
        # used for .to(...) below) is kept.
        self.device, _, _ = get_backend()
        self.image_dir = image_dir
        self.model = None  # created on first _load_model() call

    def _load_model(self):
        """Build the DepthFM model if needed and place it on ``self.device`` in eval mode."""
        if self.model is None:
            self.model = DepthFM(DEPTH_FM_CHECKPOINT).to(self.device).eval()
        else:
            self.model = self.model.to(self.device).eval()

    def _unload_model(self):
        """Move the model (if any) back to the CPU and release cached CUDA memory."""
        if self.model is not None:
            self.model = self.model.to("cpu")
        torch.cuda.empty_cache()

    @staticmethod
    def _load_image_tensor(path):
        """Read an image file into a (1, 3, H, W) float32 tensor scaled to [-1, 1]."""
        # The context manager closes the file handle; convert("RGB") guarantees
        # three channels (grayscale/palette/RGBA inputs would otherwise break
        # the 'h w c' rearrange or yield a 4-channel tensor).
        with Image.open(path) as im:
            pixels = np.array(im.convert("RGB"))
        chw = einops.rearrange(pixels, "h w c -> c h w")
        chw = chw / 127.5 - 1  # uint8 [0, 255] -> [-1, 1]
        return torch.tensor(chw, dtype=torch.float32)[None]

    def estimate_depth(self, image_path: str, *, num_steps: int = 2,
                       ensemble_size: int = 4) -> list:
        """Predict a depth map for every .jpg/.jpeg/.png file in a directory.

        Files are processed in sorted filename order so the output order (and
        the numbering used by ``visualize``) is deterministic.

        Args:
            image_path: Directory containing the input images.
            num_steps: Forwarded to ``DepthFM.predict_depth``.
            ensemble_size: Forwarded to ``DepthFM.predict_depth``.

        Returns:
            A list of ``{"depth": PIL.Image.Image}`` dicts, one per image.
        """
        print("Estimating depth...")
        predictions_list = []
        to_pil = transforms.ToPILImage()
        self._load_model()
        try:
            for name in sorted(os.listdir(image_path)):
                path = os.path.join(image_path, name)
                if not (name.lower().endswith(self._IMAGE_EXTENSIONS)
                        and os.path.isfile(path)):
                    continue
                x = self._load_image_tensor(path)
                with torch.no_grad():
                    depth = self.model.predict_depth(
                        x.to(self.device),
                        num_steps=num_steps,
                        ensemble_size=ensemble_size,
                    )
                # Tensor.cpu() returns a copy; it must be reassigned to take effect.
                depth = depth.cpu()
                predictions_list.append({"depth": to_pil(depth.squeeze())})
                del x, depth
                torch.cuda.empty_cache()
        finally:
            # Always move the model off the accelerator, even if inference failed.
            self._unload_model()
        print("Depth estimation complete.")
        return predictions_list

    def visualize(self, predictions_list: list) -> None:
        """Save each prediction's depth image as depth_<index>.png in the working directory."""
        for i, prediction in enumerate(predictions_list):
            prediction["depth"].save(f"depth_{i}.png")