svjack committed on
Commit 2960559 · verified · 1 parent: d5d4264

Update gradio_app_with_frames.py

Files changed (1):
  gradio_app_with_frames.py (+80 -83)
gradio_app_with_frames.py CHANGED
@@ -1,20 +1,46 @@
  import os
+ import sys
  import shutil
  import uuid
  import subprocess
  import gradio as gr
- import cv2
- import sys
+ import cv2  # used to check the video frame count
  from glob import glob
- from pathlib import Path
 
- # Get the path of the current Python interpreter
- #PYTHON_EXECUTABLE = sys.executable
- PYTHON_EXECUTABLE = "python"
+ from huggingface_hub import snapshot_download, hf_hub_download
+
+ # Download models
+ os.makedirs("pretrained_weights", exist_ok=True)
+
+ # List of subdirectories to create inside "checkpoints"
+ subfolders = [
+     "stable-video-diffusion-img2vid-xt"
+ ]
+
+ # Create each subdirectory
+ for subfolder in subfolders:
+     os.makedirs(os.path.join("pretrained_weights", subfolder), exist_ok=True)
+
+ snapshot_download(
+     repo_id="stabilityai/stable-video-diffusion-img2vid",
+     local_dir="./pretrained_weights/stable-video-diffusion-img2vid-xt"
+ )
+
+ snapshot_download(
+     repo_id="Yhmeng1106/anidoc",
+     local_dir="./pretrained_weights"
+ )
+
+ hf_hub_download(
+     repo_id="facebook/cotracker",
+     filename="cotracker2.pth",
+     local_dir="./pretrained_weights"
+ )
 
  def normalize_path(path: str) -> str:
+     return path
      """Normalize the path, converting Windows paths to forward-slash form."""
-     return str(Path(path).resolve()).replace('\\', '/')
+     return os.path.abspath(path).replace('\\', '/')
 
  def check_video_frames(video_path: str) -> int:
      """Check the number of frames in the video."""
@@ -33,14 +59,14 @@ def preprocess_video(video_path: str) -> str:
          output_dir = os.path.join(temp_dir, f"processed_{unique_id}")
          output_dir = normalize_path(output_dir)
          os.makedirs(output_dir, exist_ok=True)
-
+
          print(f"Processing video: {video_path}")
          print(f"Output directory: {output_dir}")
-
-         # Call process_video_to_14frames.py to process the video
+
+         # Call the external script to process the video
          result = subprocess.run(
              [
-                 PYTHON_EXECUTABLE, "process_video_to_14frames.py",
+                 "python", "process_video_to_14frames.py",
                  "--input", video_path,
                  "--output", output_dir
              ],
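
Note on the call above: the old code routed this through the PYTHON_EXECUTABLE constant, while the new code hard-codes "python". Since the new file now imports sys, a minimal sketch of the same call via sys.executable (so the child process runs under the app's own interpreter and virtual environment) could look like the following; run_preprocess is a hypothetical helper name, not part of this commit.

import subprocess
import sys

def run_preprocess(video_path: str, output_dir: str) -> subprocess.CompletedProcess:
    # Sketch only: mirrors the subprocess.run call in this hunk, but launches
    # the script with sys.executable so it runs under the same interpreter
    # (and environment) as the Gradio app, instead of whatever "python"
    # resolves to on PATH.
    return subprocess.run(
        [
            sys.executable, "process_video_to_14frames.py",
            "--input", video_path,
            "--output", output_dir,
        ],
        check=True,
        capture_output=True,
        text=True,
    )
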
@@ -48,12 +74,12 @@ def preprocess_video(video_path: str) -> str:
              capture_output=True,
              text=True
          )
-
+
          if result.stdout:
              print(f"Preprocessing stdout: {result.stdout}")
          if result.stderr:
              print(f"Preprocessing stderr: {result.stderr}")
-
+
          # Get the path of the processed video
          processed_videos = glob(os.path.join(output_dir, "*.mp4"))
          if not processed_videos:
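
process_video_to_14frames.py itself is not included in this commit, so its behaviour can only be inferred from how it is called (an --input clip in, an .mp4 with exactly 14 frames out). A rough, hypothetical sketch of such a resampler, assuming evenly spaced frame selection with cv2:

import cv2
import numpy as np

def resample_to_14_frames(input_path: str, output_path: str, n_frames: int = 14) -> None:
    # Hypothetical sketch; the real script is not shown in this diff.
    cap = cv2.VideoCapture(input_path)
    frames = []
    ok, frame = cap.read()
    while ok:
        frames.append(frame)
        ok, frame = cap.read()
    cap.release()
    if len(frames) < n_frames:
        raise ValueError(f"Need at least {n_frames} frames, got {len(frames)}")
    # Keep n_frames evenly spaced frames across the clip.
    indices = np.linspace(0, len(frames) - 1, n_frames).astype(int)
    height, width = frames[0].shape[:2]
    # The fps value here is arbitrary; the real script may choose differently.
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), 8, (width, height))
    for i in indices:
        writer.write(frames[i])
    writer.release()
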
@@ -66,87 +92,52 @@ def preprocess_video(video_path: str) -> str:
          raise gr.Error(f"Error during video preprocessing: {str(e)}")
 
  def generate(control_sequence, ref_image):
+     control_image = control_sequence  # "data_test/sample4.mp4"
+     ref_image = ref_image  # "data_test/sample4.png"
+     unique_id = str(uuid.uuid4())
+     output_dir = f"results_{unique_id}"
+
      try:
-         # Verify that the input files exist
-         control_sequence = normalize_path(control_sequence)
-         ref_image = normalize_path(ref_image)
-
-         if not os.path.exists(control_sequence):
-             raise gr.Error(f"Control sequence file not found: {control_sequence}")
-         if not os.path.exists(ref_image):
-             raise gr.Error(f"Reference image file not found: {ref_image}")
-
-         # Create the output directory
-         output_dir = "outputs"
-         os.makedirs(output_dir, exist_ok=True)
-         unique_id = str(uuid.uuid4())
-         result_dir = os.path.join(output_dir, f"results_{unique_id}")
-         result_dir = normalize_path(result_dir)
-         os.makedirs(result_dir, exist_ok=True)
-
-         print(f"Input control sequence: {control_sequence}")
-         print(f"Input reference image: {ref_image}")
-         print(f"Output directory: {result_dir}")
-
          # Check the video frame count
-         frame_count = check_video_frames(control_sequence)
+         frame_count = check_video_frames(control_image)
          if frame_count != 14:
              print(f"Video has {frame_count} frames, preprocessing to 14 frames...")
-             control_sequence = preprocess_video(control_sequence)
-             print(f"Preprocessed video saved to: {control_sequence}")
-
+             control_image = preprocess_video(control_image)
+             print(f"Preprocessed video saved to: {control_image}")
+
          # Run the inference command
-         print(f"Running inference...")
-         result = subprocess.run(
+         subprocess.run(
              [
-                 PYTHON_EXECUTABLE, "scripts_infer/anidoc_inference.py",
+                 "python", "scripts_infer/anidoc_inference.py",
                  "--all_sketch",
                  "--matching",
                  "--tracking",
-                 "--control_image", control_sequence,
-                 "--ref_image", ref_image,
-                 "--output_dir", result_dir,
+                 "--control_image", f"{control_image}",
+                 "--ref_image", f"{ref_image}",
+                 "--output_dir", f"{output_dir}",
                  "--max_point", "10",
              ],
-             check=True,
-             capture_output=True,
-             text=True
+             check=True
          )
-
-         if result.stdout:
-             print(f"Inference stdout: {result.stdout}")
-         if result.stderr:
-             print(f"Inference stderr: {result.stderr}")
 
          # Search for the output video
-         output_video = glob(os.path.join(result_dir, "*.mp4"))
-         print(f"Found output videos: {output_video}")
-
+         output_video = glob(os.path.join(output_dir, "*.mp4"))
+         print(output_video)
+
          if output_video:
-             output_video_path = normalize_path(output_video[0])
-             print(f"Returning output video: {output_video_path}")
+             output_video_path = output_video[0]  # take the first match
          else:
-             raise gr.Error("No output video generated")
-
-         # Clean up temporary files
-         temp_dirs = glob("outputs/processed_*")
-         for temp_dir in temp_dirs:
-             if os.path.isdir(temp_dir):
-                 try:
-                     shutil.rmtree(temp_dir)
-                     print(f"Cleaned up temp directory: {temp_dir}")
-                 except Exception as e:
-                     print(f"Warning: Failed to clean up temp directory {temp_dir}: {str(e)}")
-
+             output_video_path = None
+
+         print(output_video_path)
          return output_video_path
-
+
      except subprocess.CalledProcessError as e:
-         print(f"Inference stderr: {e.stderr}")
-         raise gr.Error(f"Error during inference: {e.stderr}")
+         raise gr.Error(f"Error during inference: {str(e)}")
      except Exception as e:
          raise gr.Error(f"Error: {str(e)}")
 
- css="""
+ css = """
  div#col-container{
      margin: 0 auto;
      max-width: 982px;
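
check_video_frames, called at the top of generate, has its body outside the hunks shown in this diff; only its signature and docstring appear in the first hunk. A minimal sketch consistent with that signature, assuming the container's metadata frame count is sufficient:

import cv2

def check_video_frames(video_path: str) -> int:
    """Check the number of frames in the video (sketch; the real body is outside this diff)."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")
    # CAP_PROP_FRAME_COUNT reads the container metadata; decoding every frame
    # would be exact but slower.
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return frame_count
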
@@ -168,6 +159,12 @@ with gr.Blocks(css=css) as demo:
              <a href="https://arxiv.org/pdf/2412.14173">
                  <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
              </a>
+             <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
+                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
+             </a>
+             <a href="https://huggingface.co/fffiloni">
+                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
+             </a>
          </div>
      """)
      with gr.Row():
@@ -179,20 +176,20 @@ with gr.Blocks(css=css) as demo:
              video_result = gr.Video(label="Result")
 
      gr.Examples(
-         examples = [
+         examples=[
              ["data_test/sample5.mp4", "data_test/sample5.png"],
             ["data_test/sample1.mp4", "data_test/sample1.png"],
             ["data_test/sample2.mp4", "data_test/sample2.png"],
             ["data_test/sample3.mp4", "data_test/sample3.png"],
             ["data_test/sample4.mp4", "data_test/sample4.png"]
          ],
-         inputs = [control_sequence, ref_image]
-     )
-
-     submit_btn.click(
-         fn = generate,
-         inputs = [control_sequence, ref_image],
-         outputs = [video_result]
-     )
+         inputs=[control_sequence, ref_image]
+     )
+
+     submit_btn.click(
+         fn=generate,
+         inputs=[control_sequence, ref_image],
+         outputs=[video_result]
+     )
 
- demo.queue().launch(inbrowser=True,show_api=False, show_error=True, share = True)
+ demo.queue().launch(show_api=False, show_error=True, share=True)
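
For reference, the generate function wired up above can also be exercised without the UI, using one of the sample pairs listed in gr.Examples; this sketch assumes the model downloads at the top of the file have completed and the data_test files are present.

if __name__ == "__main__":
    # Bypass the Gradio UI and run the pipeline on one bundled example pair.
    result_path = generate("data_test/sample4.mp4", "data_test/sample4.png")
    print("Result video:", result_path)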
 