import os
import torch
import einops
import argparse
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from depthfm import DepthFM
import matplotlib.pyplot as plt


def get_dtype_from_str(dtype_str):
    """Map a CLI precision string ("fp32"/"fp16"/"bf16") to the matching torch dtype."""
    return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]


def resize_max_res(
    img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
) -> tuple[Image.Image, tuple[int, int]]:
    """
    Resize an image to limit its maximum edge length while keeping the aspect
    ratio, snapping each edge to the nearest multiple of 64.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixels).
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize the image.

    Returns:
        `tuple[Image.Image, tuple[int, int]]`: The resized image and the
        original (width, height).
    """
    original_width, original_height = img.size
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )
    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    # Round both edges to the nearest multiple of 64.
    new_width = round(new_width / 64) * 64
    new_height = round(new_height / 64) * 64

    print(f"Resizing image from {original_width}x{original_height} to {new_width}x{new_height}")
    resized_img = img.resize((new_width, new_height), resample=resample_method)
    return resized_img, (original_width, original_height)
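

# Worked example for resize_max_res (illustrative numbers): a 1920x1080 input
# with max_edge_resolution=768 scales by min(768/1920, 768/1080) = 0.4 to
# 768x432; 432 then snaps to round(432 / 64) * 64 = 448, so the final size is
# 768x448, while the original (1920, 1080) is returned for restoring the output.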


def load_im(fp, processing_res=-1):
    """Load an RGB image and normalize it to a (1, 3, H, W) tensor in [-1, 1]."""
    assert os.path.exists(fp), f"File not found: {fp}"
    im = Image.open(fp).convert('RGB')
    if processing_res < 0:
        # Keep the original resolution (edges are still snapped to multiples of 64).
        processing_res = max(im.size)
    im, orig_res = resize_max_res(im, processing_res)
    x = np.array(im)
    x = einops.rearrange(x, 'h w c -> c h w')
    x = x / 127.5 - 1  # map uint8 pixels from [0, 255] to [-1, 1]
    x = torch.tensor(x, dtype=torch.float32)[None]  # add batch dimension
    return x, orig_res
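

# Minimal usage sketch for load_im (path and shapes are illustrative):
#   x, orig_res = load_im("assets/dog.png", processing_res=768)
#   x is a float32 tensor of shape (1, 3, H, W) with values in [-1, 1];
#   orig_res is the original (width, height) used later to restore the output size.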


def main(args):
    print(f"{'Input':<10}: {args.img}")
    print(f"{'Steps':<10}: {args.num_steps}")
    print(f"{'Ensemble':<10}: {args.ensemble_size}")

    # Load the model
    model = DepthFM(args.ckpt)
    model.cuda(args.device).eval()

    # Load an image
    im, orig_res = load_im(args.img, args.processing_res)
    im = im.cuda(args.device)

    # Generate depth. Autocast is disabled for fp32, where it has no effect
    # (CUDA autocast only supports fp16 and bf16 target dtypes).
    dtype = get_dtype_from_str(args.dtype)
    model.model.dtype = dtype
    with torch.autocast(device_type="cuda", dtype=dtype, enabled=dtype != torch.float32):
        depth = model.predict_depth(im, num_steps=args.num_steps, ensemble_size=args.ensemble_size)
    depth = depth.squeeze(0).squeeze(0).cpu().numpy()  # (h, w) in [0, 1]

    # Convert depth to a [0, 255] image: grayscale, or colorized with matplotlib's magma colormap.
    if args.no_color:
        depth = (depth * 255).astype(np.uint8)
    else:
        depth = plt.get_cmap('magma')(depth, bytes=True)[..., :3]  # drop the alpha channel

    # Save the depth map, resized back to the original resolution.
    depth_fp = args.img + '_depth.png'
    depth_img = Image.fromarray(depth)
    if depth_img.size != orig_res:
        depth_img = depth_img.resize(orig_res, Resampling.BILINEAR)
    depth_img.save(depth_fp)
    print(f"==> Saved depth map to {depth_fp}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("DepthFM Inference")
    parser.add_argument("--img", type=str, default="assets/dog.png",
                        help="Path to the input image")
    parser.add_argument("--ckpt", type=str, default="checkpoints/depthfm-v1.ckpt",
                        help="Path to the model checkpoint")
    parser.add_argument("--num_steps", type=int, default=2,
                        help="Number of steps for the ODE solver")
    parser.add_argument("--ensemble_size", type=int, default=4,
                        help="Number of ensemble members")
    parser.add_argument("--no_color", action="store_true",
                        help="If set, the depth map is saved as grayscale instead of colorized")
    parser.add_argument("--device", type=int, default=0,
                        help="Index of the GPU to use")
    parser.add_argument("--processing_res", type=int, default=-1,
                        help="The longer edge of the image is resized to this resolution. "
                             "-1 keeps the original resolution (edges are still snapped to multiples of 64).")
    parser.add_argument("--dtype", type=str, choices=["fp32", "bf16", "fp16"], default="fp16",
                        help="Precision to run inference with. Lower precision speeds up "
                             "inference at a subtle loss in accuracy.")
    args = parser.parse_args()

    main(args)
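
# Example invocation (script filename assumed; defaults mirror the flags above):
#   python inference.py --img assets/dog.png --ckpt checkpoints/depthfm-v1.ckpt \
#       --num_steps 2 --ensemble_size 4 --dtype fp16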