import pathlib

import gradio as gr
import matplotlib as mpl
import numpy as np
import PIL.Image
import spaces
import torch
from gradio_imageslider import ImageSlider
from transformers import DepthProForDepthEstimation, DepthProImageProcessorFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

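# Load the DepthPro preprocessor and model once at import time so every
# request reuses the same weights.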
image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)

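# Grayscale colormap used to render the inverse-depth map.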
cmap = mpl.colormaps.get_cmap("gray")


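# `@spaces.GPU` requests a ZeroGPU slot on Hugging Face Spaces for up to 20 s
# per call; outside Spaces the decorator should be a no-op.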
@spaces.GPU(duration=20)
@torch.inference_mode()
def run(image: PIL.Image.Image) -> tuple[tuple[PIL.Image.Image, PIL.Image.Image], str, str, str, str]:
    # Preprocess, run the network, and resize the prediction back to the
    # input resolution.
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )

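    # DepthPro predicts metric depth, so the min/max values are in meters.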
    depth_raw = post_processed_output[0]["predicted_depth"]
    depth_min = depth_raw.min().item()
    depth_max = depth_raw.max().item()

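    # Visualize inverse depth (near = bright, far = dark), min-max normalized
    # to [0, 255] for an 8-bit grayscale image.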
    inverse_depth = 1 / depth_raw
    normalized_inverse_depth = (inverse_depth - inverse_depth.min()) / (inverse_depth.max() - inverse_depth.min())
    normalized_inverse_depth = normalized_inverse_depth * 255.0
    normalized_inverse_depth = normalized_inverse_depth.detach().cpu().numpy()
    normalized_inverse_depth = PIL.Image.fromarray(normalized_inverse_depth.astype("uint8"))

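    # Passing a uint8 array indexes the colormap's 256-entry lookup table
    # directly; the RGBA result is converted to RGB by dropping alpha.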
    colored_inverse_depth = PIL.Image.fromarray(
        (cmap(np.array(normalized_inverse_depth))[:, :, :3] * 255).astype(np.uint8)
    )

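    # DepthPro also estimates camera intrinsics: field of view in degrees and
    # focal length in pixels.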
    field_of_view = post_processed_output[0]["field_of_view"].item()
    focal_length = post_processed_output[0]["focal_length"].item()

    return (
        (image, colored_inverse_depth),
        f"{field_of_view:.2f}",
        f"{focal_length:.2f}",
        f"{depth_min:.2f}",
        f"{depth_max:.2f}",
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# DepthPro")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            run_button = gr.Button()
        with gr.Column():
            # Before/after slider comparing the input image with the depth map.
            output_image = ImageSlider()
            with gr.Row():
                output_field_of_view = gr.Textbox(label="Field of View")
                output_focal_length = gr.Textbox(label="Focal Length")
                output_depth_min = gr.Textbox(label="Depth Min")
                output_depth_max = gr.Textbox(label="Depth Max")

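    # Gallery of example inputs; expects a bundled `images/` directory of JPEGs
    # next to this script.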
    gr.Examples(
        examples=sorted(pathlib.Path("images").glob("*.jpg")),
        inputs=input_image,
        fn=run,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()