import gradio as gr
import requests
import os
import time
import json
from datetime import datetime
import oss2
import cv2
import uuid
from pathlib import Path
import decord
from gradio.utils import get_cache_folder

cache_version = 20250325
dashscope_api_key = os.getenv("API_KEY", "")
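
# Credentials: the DashScope key is read from the API_KEY environment variable;
# oss2's EnvironmentVariableCredentialsProvider (used in upload_to_oss below)
# expects OSS_ACCESS_KEY_ID / OSS_ACCESS_KEY_SECRET in the environment.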
class Examples(gr.helpers.Examples):
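    """gr.Examples variant that stores its cache under a caller-chosen
    directory name inside Gradio's cache folder."""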
    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()

def upload_to_oss(local_file_path, remote_file_path, expire_time=3600):
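    """Upload a local file to the OSS bucket under a date-prefixed key,
    retrying up to 5 times, and return a signed GET URL that stays valid
    for `expire_time` seconds."""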
    remote_url = "motionshop2/%s/%s" % (datetime.now().strftime("%Y%m%d"), remote_file_path)
    for i in range(5):
        try:
            from oss2.credentials import EnvironmentVariableCredentialsProvider
            auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
            bucket = oss2.Bucket(auth, 'oss-us-east-1.aliyuncs.com', 'huggingface-motionshop')
            bucket.put_object_from_file(key=remote_url, filename=local_file_path)
            break
        except Exception as e:
            if i < 4:  # Not the last retry: wait 2 seconds, then try again
                time.sleep(2)
                continue
            else:  # Last retry failed: propagate the error
                raise e
    return bucket.sign_url('GET', remote_url, expire_time)
def get_url(filepath):
    filename = os.path.basename(filepath)
    remote_file_path = "%s_%s" % (uuid.uuid4(), filename)
    return upload_to_oss(filepath, remote_file_path)

def online_detect(filepath):
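    """Call the synchronous human-detection service on an uploaded video URL
    and return the parsed JSON response (a selected frame index plus
    normalized (cx, cy, w, h) boxes)."""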
    url = "https://poc-dashscope.aliyuncs.com/api/v1/services/default/default/default"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer {}".format(dashscope_api_key)
    }
    data = {
        "model": "pre-motionshop-detect-gradio",
        "input": {
            "video_url": filepath
        },
        "parameters": {
            "threshold": 0.4,
            "min_area_ratio": 0.001
        }
    }
    print("Call detect api, params: " + json.dumps(data))
    query_result_request = requests.post(
        url,
        json=data,
        headers=headers
    )
    print("Detect api returned: " + query_result_request.text)
    return json.loads(query_result_request.text)
def online_render(filepath, frame_id, bbox, replacement_ids, cache_url=None, model="pre-motionshop-render-gradio"):
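    """Submit an asynchronous render task that swaps the detected characters
    (given frame index and boxes) for the generated 3D models; returns the
    submission JSON whose task_id can be polled with get_async_result."""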
    url = "https://poc-dashscope.aliyuncs.com/api/v1/services/async-default/async-default/async-default"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer {}".format(dashscope_api_key),
        "X-DashScope-Async": "enable"
    }
    data = {
        "model": model,
        "input": {
            "video_url": filepath,
            "frame_index": frame_id,
            "bbox": bbox,
            "replacement_id": replacement_ids
        },
        "parameters": {
        }
    }
    if cache_url is not None:
        data["input"]["cache_url"] = cache_url
    print("Call render video api with params: " + json.dumps(data))
    query_result_request = requests.post(
        url,
        json=data,
        headers=headers
    )
    print("Render video api returned: " + query_result_request.text)
    return json.loads(query_result_request.text)
def get_async_result(task_id):
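    """Poll the DashScope task endpoint once per second until the task status
    is SUCCEEDED or FAILED, then return the full result JSON. Note there is
    no timeout, so a task that never finishes would block forever."""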
    while True:
        result = requests.post(
            "https://poc-dashscope.aliyuncs.com/api/v1/tasks/%s" % task_id,
            headers={
                "Authorization": "Bearer {}".format(dashscope_api_key),
            }
        )
        result = json.loads(result.text)
        if "output" in result and result["output"]["task_status"] in ["SUCCEEDED", "FAILED"]:
            break
        time.sleep(1)
    return result
def save_video_cv2(vid, resize_video_input, resize_h, resize_w, fps):
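    """Re-encode the decord frames into `resize_video_input` at the requested
    resolution and fps, using OpenCV's XVID codec."""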
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(resize_video_input, fourcc, fps, (resize_w, resize_h))
    for idx in range(len(vid)):
        # decord returns RGB frames; convert to BGR for OpenCV
        frame = cv2.cvtColor(vid[idx].asnumpy(), cv2.COLOR_RGB2BGR)
        frame = cv2.resize(frame, (resize_w, resize_h))
        out.write(frame)
    out.release()
def detect_human(video_input):
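    """Prepare a video for detection: downscale it so the long side is at most
    1280px, upload it to OSS, run the detection service, and return the session
    state together with a preview frame showing up to three detected characters
    outlined in red, green and blue."""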
    video_input_basename = os.path.basename(video_input)
    resize_video_input = os.path.join(
        os.path.dirname(video_input),
        video_input_basename.split(".")[0] + "_resize." + video_input_basename.split(".")[-1]
    )
    vid = decord.VideoReader(video_input)
    fps = vid.get_avg_fps()
    H, W, C = vid[0].shape
    if H > 1280 or W > 1280:
        # Downscale so the long side is at most 1280px before uploading
        if H > W:
            resize_h, resize_w = 1280, int(W * 1280 / H)
        else:
            resize_h, resize_w = int(H * 1280 / W), 1280
        save_video_cv2(vid, resize_video_input, resize_h, resize_w, fps)
        new_video_input = resize_video_input
    else:
        new_video_input = video_input
    video_url = get_url(new_video_input)
    detect_result = online_detect(video_url)
    check_result = "output" in detect_result
    if not check_result:
        # Without "output" the lookups below would raise KeyError; surface a readable error instead
        raise gr.Error("Human detection failed: " + json.dumps(detect_result))
    select_frame_index = detect_result["output"]["frame_index"]
    boxes = detect_result["output"]["bbox"][:3]
    print("Detected %d characters" % len(boxes))
    cap = cv2.VideoCapture(new_video_input)
    cap.set(cv2.CAP_PROP_POS_FRAMES, select_frame_index)
    _, box_image = cap.read()
    box_image = cv2.cvtColor(box_image, cv2.COLOR_BGR2RGB)
    width, height = box_image.shape[1], box_image.shape[0]
    # Outline up to three detections in red, green, blue (the image is RGB here)
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
    for i, box in enumerate(boxes):
        # Boxes arrive as normalized (cx, cy, w, h); convert to pixel corners
        x0 = int((box[0] - box[2] / 2) * width)
        y0 = int((box[1] - box[3] / 2) * height)
        x1 = int((box[0] + box[2] / 2) * width)
        y1 = int((box[1] + box[3] / 2) * height)
        box_image = cv2.rectangle(box_image, (x0, y0), (x1, y1), colors[i], 2)
    video_state = {
        "check_result": check_result,
        "select_frame_index": select_frame_index,
        "box": boxes,
        "replace_ids": [],
        "image_to_3d_tasks": {},
        "video_url": video_url,
        "video_path": new_video_input
    }
    return video_state, box_image, gr.update(visible=True), gr.update(visible=False)
def predict(video_state, first_image, second_image, third_image):
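    """Convert each uploaded character photo to a 3D model, then render the
    models into the detected boxes; returns (video_url_or_None, status_message)."""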
    if len(video_state["box"]) == 0:
        return None, "No human detected, please use a video with a clearly visible person"
    print("images:", first_image, second_image, third_image)
    tasks = []
    boxes = []
    if first_image is not None and len(video_state["box"]) >= 1:
        tasks.append(image_to_3d(first_image))
        boxes.append(video_state["box"][0])
    if second_image is not None and len(video_state["box"]) >= 2:
        tasks.append(image_to_3d(second_image))
        boxes.append(video_state["box"][1])
    if third_image is not None and len(video_state["box"]) >= 3:
        tasks.append(image_to_3d(third_image))
        boxes.append(video_state["box"][2])
    if len(tasks) == 0:
        return None, "Please upload at least one character photo for replacement."
    ids = []
    for t in tasks:
        try:
            image_to_3d_result = get_async_result(t)
            print("image to 3d finished", image_to_3d_result)
            ids.append(image_to_3d_result["output"]["ply_url"])
        except Exception as e:
            print(e)
            return None, "Error in 3D model generation, please check the uploaded image"
    if video_state["check_result"]:
        try:
            taskid = online_render(video_state["video_url"], video_state["select_frame_index"], boxes, ids, None)["output"]["task_id"]
            task_output = get_async_result(taskid)
            print("Video synthesis completed, api returned: " + json.dumps(task_output))
            video_url = task_output["output"]["synthesis_video_url"]
            return video_url, "Processing Success"
        except Exception as e:
            print(e)
            return None, "Error in video synthesis, please change the material and try again"
    else:
        return None, "Error in human detection, please use a video with a clearly visible person"
def online_img_to_3d(img_url):
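    """Submit an asynchronous image-to-3D task for an uploaded image URL and
    return the submission JSON."""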
    url = "https://poc-dashscope.aliyuncs.com/api/v1/services/async-default/async-default/async-default"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer {}".format(dashscope_api_key),
        "X-DashScope-Async": "enable"
    }
    data = {
        "model": "pre-image-to-3d-gradio",
        "input": {
            "image_url": img_url,
        },
        "parameters": {
        }
    }
    print("Call image to 3d api, params: " + json.dumps(data))
    query_result_request = requests.post(
        url,
        json=data,
        headers=headers
    )
    return json.loads(query_result_request.text)
def image_to_3d(image_path):
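    """Upload a local image to OSS and start image-to-3D generation; returns
    the async task id to poll with get_async_result."""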
    url = get_url(image_path)
    task_send_result = online_img_to_3d(url)
    image_to_3d_task_id = task_send_result["output"]["task_id"]
    return image_to_3d_task_id
def gradio_demo():
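    """Build and launch the Gradio interface: video upload plus detection,
    up to three replacement images, and the synthesized result."""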
    with gr.Blocks() as iface:
""" | |
state for | |
""" | |
        video_state = gr.State(
            {
                "check_result": False,
                "select_frame_index": 0,
                "box": [],
                "replace_ids": [],
                "image_to_3d_tasks": {},
                "video_url": "",
                "video_path": ""
            }
        )
        gr.HTML(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <div>
                    <h1>Motionshop2</h1>
                    <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
                        <a class="flex-item" href="https://aigc3d.github.io/motionshop-2" target="_blank">
                            <img src="https://img.shields.io/badge/Project_Page-Motionshop2-green.svg" alt="Project Page">
                        </a>
                        <a class="flex-item" href="https://lingtengqiu.github.io/LHM/" target="_blank">
                            <img src="https://img.shields.io/badge/Project_Page-LHM-green.svg" alt="Project Page">
                        </a>
                        <a class="flex-item" href="https://lixiaowen-xw.github.io/DiffuEraser-page/" target="_blank">
                            <img src="https://img.shields.io/badge/Project_Page-DiffuEraser-green.svg" alt="Project Page">
                        </a>
                    </div>
                </div>
            </div>
            """
        )
gr.Markdown("""<h4 style="color: green;"> 1. Choose or upload a video (duration<=15s, resolution<=720p)</h4>""") | |
with gr.Row(): | |
with gr.Column(): | |
gr.HTML(""" | |
<style> | |
#input_video video, #output_video video { | |
height: 480px !important; | |
object-fit: contain; | |
} | |
#template_frame img { | |
height: 480px !important; | |
object-fit: contain; | |
} | |
</style> | |
""") | |
video_input = gr.Video(elem_id="input_video") | |
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False) | |
Examples( | |
fn=detect_human, | |
examples=sorted([ | |
os.path.join("files", "example_videos", name) | |
for name in os.listdir(os.path.join("files", "example_videos")) | |
]), | |
run_on_click=True, | |
inputs=[video_input], | |
outputs=[video_state, template_frame, template_frame, video_input], | |
directory_name="examples_videos", | |
cache_examples=False, | |
) | |
gr.Markdown("""<h4 style="color: green;"> 2.Choose or upload images to replace </h4>""") | |
        with gr.Row():
            with gr.Column():
                gr.Markdown("Replace the character in the red box with...")
                with gr.Row():
                    first_image = gr.Image(type="filepath", interactive=True, elem_id="first_image", visible=True, height=480, width=270)
                    first_example = gr.Examples(
                        examples=sorted([os.path.join("files", "example_images", name) for name in os.listdir(os.path.join("files", "example_images"))]),
                        inputs=[first_image],
                        examples_per_page=6
                    )
            with gr.Column():
                gr.Markdown("Replace the character in the green box with...")
                with gr.Row():
                    second_image = gr.Image(type="filepath", interactive=True, elem_id="second_image", visible=True, height=480, width=270)
                    second_example = gr.Examples(
                        examples=sorted([os.path.join("files", "example_images", name) for name in os.listdir(os.path.join("files", "example_images"))]),
                        inputs=[second_image],
                        examples_per_page=6
                    )
            with gr.Column():
                gr.Markdown("Replace the character in the blue box with...")
                with gr.Row():
                    third_image = gr.Image(type="filepath", interactive=True, elem_id="third_image", visible=True, height=480, width=270)
                    third_example = gr.Examples(
                        examples=sorted([os.path.join("files", "example_images", name) for name in os.listdir(os.path.join("files", "example_images"))]),
                        inputs=[third_image],
                        examples_per_page=6
                    )
gr.Markdown("""<h4 style="color: green;"> 3.Click Start (each generation may take 3 minutes due to the use of SOTA video inpainting and pose estimation methods)</h4>""") | |
        with gr.Row():
            with gr.Column():
                motion_shop_predict_button = gr.Button(value="Start", variant="primary")
                video_output = gr.Video(elem_id="output_video")
                error_message = gr.Textbox(label="Processing Status", visible=True, interactive=False)

        video_input.upload(
            fn=detect_human,
            inputs=[video_input],
            outputs=[video_state, template_frame, template_frame, video_input],
        )
        motion_shop_predict_button.click(
            fn=predict,
            inputs=[video_state, first_image, second_image, third_image],
            outputs=[video_output, error_message]
        )
        # Reset the state and clear every input/output when the template frame is cleared
        template_frame.clear(
            lambda: (
                {
                    "check_result": False,
                    "select_frame_index": 0,
                    "box": [],
                    "replace_ids": [],
                    "image_to_3d_tasks": {},
                    "video_url": "",
                    "video_path": ""
                },
                None,
                None,
                None,
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value="")
            ),
            [],
            [
                video_state,
                video_output,
                template_frame,
                video_input,
                video_input,
                template_frame,
                first_image,
                second_image,
                third_image,
                error_message
            ],
            queue=False,
            show_progress=False)
# print("username:", uuid_output_field) | |
# set example | |
# gr.Markdown("## Examples") | |
# gr.Examples( | |
# examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \ | |
# "test-sample2.mp4","test-sample13.mp4"]], | |
# fn=run_example, | |
# inputs=[ | |
# e.s video_input | |
# ], | |
# outputs=[video_input], | |
# # cache_examples=True, | |
# ) | |
    iface.queue(default_concurrency_limit=200)
    iface.launch(debug=False, max_threads=10, server_name="0.0.0.0")

if __name__ == "__main__":
    gradio_demo()