huntrezz commited on
Commit
40334e7
·
verified ·
1 Parent(s): 99bbe3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -3
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from transformers import DPTForDepthEstimation, DPTImageProcessor
5
  import gradio as gr
6
  import torch.nn.utils.prune as prune
 
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
@@ -39,6 +40,23 @@ def preprocess_image(image):
39
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
40
  return image / 255.0
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @torch.inference_mode()
43
  def process_frame(image):
44
  if image is None:
@@ -46,10 +64,28 @@ def process_frame(image):
46
  preprocessed = preprocess_image(image)
47
  predicted_depth = model(preprocessed).predicted_depth
48
  depth_map = predicted_depth.squeeze().cpu().numpy()
 
 
49
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
50
- depth_map = (depth_map * 255).astype(np.uint8)
51
- depth_map_colored = cv2.applyColorMap(depth_map, cv2.COLORMAP_INFERNO)
52
- return cv2.cvtColor(depth_map_colored, cv2.COLOR_BGR2RGB)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  interface = gr.Interface(
55
  fn=process_frame,
 
4
  from transformers import DPTForDepthEstimation, DPTImageProcessor
5
  import gradio as gr
6
  import torch.nn.utils.prune as prune
7
+ import open3d as o3d
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
 
40
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
41
  return image / 255.0
42
 
43
+ def create_point_cloud(depth_map, color_image):
44
+ rows, cols = depth_map.shape
45
+ c, r = np.meshgrid(np.arange(cols), np.arange(rows), sparse=True)
46
+ valid = (depth_map > 0) & (depth_map < 1000)
47
+ z = np.where(valid, depth_map, 0)
48
+ x = np.where(valid, z * (c - cols / 2) / cols, 0)
49
+ y = np.where(valid, z * (r - rows / 2) / rows, 0)
50
+
51
+ points = np.dstack((x, y, z)).reshape(-1, 3)
52
+ colors = color_image.reshape(-1, 3)
53
+
54
+ pcd = o3d.geometry.PointCloud()
55
+ pcd.points = o3d.utility.Vector3dVector(points)
56
+ pcd.colors = o3d.utility.Vector3dVector(colors / 255.0)
57
+
58
+ return pcd
59
+
60
  @torch.inference_mode()
61
  def process_frame(image):
62
  if image is None:
 
64
  preprocessed = preprocess_image(image)
65
  predicted_depth = model(preprocessed).predicted_depth
66
  depth_map = predicted_depth.squeeze().cpu().numpy()
67
+
68
+ # Normalize depth map
69
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
70
+
71
+ # Create point cloud
72
+ pcd = create_point_cloud(depth_map, image)
73
+
74
+ # Visualize point cloud
75
+ vis = o3d.visualization.Visualizer()
76
+ vis.create_window()
77
+ vis.add_geometry(pcd)
78
+ vis.poll_events()
79
+ vis.update_renderer()
80
+
81
+ # Capture the visualization as an image
82
+ image = vis.capture_screen_float_buffer(False)
83
+ vis.destroy_window()
84
+
85
+ # Convert the image to numpy array
86
+ point_cloud_image = (np.asarray(image) * 255).astype(np.uint8)
87
+
88
+ return point_cloud_image
89
 
90
  interface = gr.Interface(
91
  fn=process_frame,