import spaces
import os
import sys

os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"

import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from pathlib import Path
import argparse
import json
import trimesh
from torchvision import transforms
from typing import Dict, Optional
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from lang_sam import LangSAM

from wilor.models import load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset
from hort.models import load_hort
from hort.utils.renderer import Renderer, cam_crop_to_new
from hort.utils.img_utils import process_bbox, generate_patch_image, PerspectiveCamera

LIGHT_PURPLE = (0.25098039, 0.274117647, 0.65882353)
STEEL_BLUE = (0.2745098, 0.5098039, 0.7058824)


def install_cuda_toolkit():
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    if not os.path.exists(CUDA_TOOLKIT_FILE):
        os.system("pip install gradio==5.0.2")
        print("start to download cuda toolkit")
        os.system(f"wget -q {CUDA_TOOLKIT_URL} -O {CUDA_TOOLKIT_FILE}")
    os.system(f"chmod +x {CUDA_TOOLKIT_FILE}")
    print("start to install cuda toolkit")
    os.system(f"{CUDA_TOOLKIT_FILE} --silent --toolkit")
    os.environ["CUDA_HOME"] = "/usr/local/cuda"


# install_cuda_toolkit()
# print("start to install pointnet++")
# os.system("cd /home/user/app/hort/models/tgs/models/snowflake/pointnet2_ops_lib && python setup.py install && cd /home/user/app")

wilor_checkpoint_path = hf_hub_download(repo_id="zerchen/hort_models", filename="wilor_final.ckpt")
hort_checkpoint_path = hf_hub_download(repo_id="zerchen/hort_models", filename="hort_final.pth.tar")

# Download and load checkpoints
wilor_model, wilor_model_cfg = load_wilor(checkpoint_path=wilor_checkpoint_path, cfg_path='./pretrained_models/model_config.yaml')
hand_detector = YOLO('./pretrained_models/detector.pt')
# Setup the renderer
renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)
# Setup the SAM model
sam_model = LangSAM(sam_type="sam2.1_hiera_large")
# Setup the HORT model
hort_model = load_hort(hort_checkpoint_path)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wilor_model = wilor_model.to(device)
hand_detector = hand_detector.to(device)
hort_model = hort_model.to(device)
wilor_model.eval()
hort_model.eval()

image_transform = transforms.Compose([transforms.ToPILImage(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


def calculate_iou(box1, box2):
    x1_inter = max(box1[0], box2[0])
    y1_inter = max(box1[1], box2[1])
    x2_inter = min(box1[2], box2[2])
    y2_inter = min(box1[3], box2[3])

    # Compute intersection area
    inter_width = max(0, x2_inter - x1_inter)
    inter_height = max(0, y2_inter - y1_inter)
    intersection = inter_width * inter_height

    # Compute areas of each box
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # Compute union
    union = area_box1 + area_box2 - intersection

    # Compute IoU
    return intersection / union if union > 0 else 0.0


@spaces.GPU()
def run_model(image, conf, IoU_threshold=0.5):
    # Gradio delivers an RGB array; reverse the channels to BGR for OpenCV / the YOLO detector
    img_cv2 = image[..., ::-1]
    img_pil = Image.fromarray(image)
    # Localize the manipulated object with LangSAM and the hands with the YOLO detector
    pred_obj = sam_model.predict([img_pil], ["manipulated object"])
    bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2))
    detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]
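    # Collect each detected hand box together with its predicted class, which is
    # treated below as an is-right-hand flag (left hands are mirrored before WiLoR).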
    bboxes = []
    is_right = []
    for det in detections:
        Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
        is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())

    if len(bboxes) == 0:
        print("no hands in this image")
        raise gr.Error("No hands detected in this image.")
    elif len(bboxes) == 1:
        bbox_hand = np.array(bboxes[0]).reshape((-1, 2))
    elif len(bboxes) > 1:
        # Keep the hand whose box overlaps the object box the most
        hand_idx = None
        max_iou = -10.
        for cur_idx, cur_bbox in enumerate(bboxes):
            cur_iou = calculate_iou(cur_bbox, bbox_obj.reshape(-1).tolist())
            if cur_iou >= max_iou:
                hand_idx = cur_idx
                max_iou = cur_iou
        bbox_hand = np.array(bboxes[hand_idx]).reshape((-1, 2))
        bboxes = [bboxes[hand_idx]]
        is_right = [is_right[hand_idx]]

    # Joint hand-object bounding box with a small margin
    tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    box_size = br - tl
    bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
    ho_bbox = process_bbox(bbox)

    boxes = np.stack(bboxes)
    right = np.stack(is_right)
    if not right:
        # Mirror left hands so the models always see a right hand
        new_x1 = img_cv2.shape[1] - boxes[0][2]
        new_x2 = img_cv2.shape[1] - boxes[0][0]
        boxes[0][0] = new_x1
        boxes[0][2] = new_x2
        ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
        img_cv2 = cv2.flip(img_cv2, 1)
        right[0] = 1.

    crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)

    dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
    for batch in dataloader:
        batch = recursive_to(batch, device)
        with torch.no_grad():
            out = wilor_model(batch)

        pred_cam = out['pred_cam']
        box_center = batch["box_center"].float()
        box_size = batch["box_size"].float()
        img_size = batch["img_size"].float()
        scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
        pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy()

        batch_size = batch['img'].shape[0]
        for n in range(batch_size):
            verts = out['pred_vertices'][n].detach().cpu().numpy()
            joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
            is_right = batch['right'][n].cpu().numpy()
            palm = (verts[95] + verts[22]) / 2
            cam_t = pred_cam_t_full[n]

            img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
            camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
            cam_intr = camera.intrinsics

            # Feed the hand-object crop plus WiLoR's hand prediction into HORT
            metas = dict()
            metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
            metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
            metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
            metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                pc_results = hort_model(img_input, metas)
            objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
            pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3

            reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}

    camera_translation = cam_t.copy()
    hand_mesh = renderer.mesh(verts, camera_translation, LIGHT_PURPLE, is_right=is_right)
    obj_pcd = trimesh.PointCloud(reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'] + camera_translation, colors=[70, 130, 180, 255])
    scene = trimesh.Scene([hand_mesh, obj_pcd])
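    # Export the hand mesh together with the predicted object point cloud as a GLB
    # file so the Gradio Model3D viewer can display the reconstruction interactively.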
    scene_path = "/tmp/test.glb"
    scene.export(scene_path)

    return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions, scene_path


def render_reconstruction(image, conf=0.3, IoU_threshold=0.3):
    # The Submit button only wires the input image into this function, so conf falls back to its default.
    input_img, num_dets, reconstructions, scene_path = run_model(image, conf, IoU_threshold=0.5)

    # Render front view
    misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
    cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)

    # Overlay image
    input_img = np.concatenate([input_img, np.ones_like(input_img[:, :, :1])], axis=2)  # Add alpha channel
    input_img_overlay = input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]

    return input_img_overlay, f'{num_dets} hands detected', scene_path


header = ('''

# HORT: Monocular Hand-held Objects Reconstruction with Transformers

Zerui Chen<sup>1</sup>, Rolandos Alexandros Potamias<sup>2</sup>, Shizhe Chen<sup>1</sup>, Cordelia Schmid<sup>1</sup>

<sup>1</sup>Inria, Ecole normale supérieure, CNRS, PSL Research University &nbsp;&nbsp; <sup>2</sup>Imperial College London

''')

theme = gr.themes.Ocean()
theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)

with gr.Blocks(theme=theme, title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo:
    gr.Markdown(header)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image", type="numpy")
            submit = gr.Button("Submit", variant="primary")
            example_images = gr.Examples([
                ['/home/user/app/assets/test1.png'],
                ['/home/user/app/assets/test2.png'],
                ['/home/user/app/assets/test3.jpg'],
                ['/home/user/app/assets/test4.jpg'],
                ['/home/user/app/assets/test5.jpeg'],
                ['/home/user/app/assets/test6.jpg'],
                ['/home/user/app/assets/test7.jpg'],
                ['/home/user/app/assets/test8.jpeg']
            ], inputs=input_image)
        with gr.Column():
            reconstruction = gr.Image(label="Reconstructions", type="numpy")
            output_meshes = gr.Model3D(label="3D Models", height=300, zoom_speed=0.5, pan_speed=0.5)
            hands_detected = gr.Textbox(label="Hands Detected")

    submit.click(fn=render_reconstruction, inputs=[input_image], outputs=[reconstruction, hands_detected, output_meshes])

demo.launch(share=True)
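# Note: share=True only matters when this script is run locally; on a hosted
# Hugging Face Space the flag is ignored and the app is served at the Space URL.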