Spaces:
Sleeping
Sleeping
Commit
·
21c3587
1
Parent(s):
229302a
mask dropdown added
Browse files- app.py +24 -3
- scripts/inference.py +1 -0
- soundimage/pipelines/lipsync_pipeline.py +2 -1
- soundimage/utils/image_processor.py +4 -4
- soundimage/utils/mask.png +0 -3
app.py
CHANGED
@@ -19,6 +19,7 @@ def process_video(
|
|
19 |
inference_steps,
|
20 |
seed,
|
21 |
checkpoint_file,
|
|
|
22 |
):
|
23 |
# Create the temp directory if it doesn't exist
|
24 |
output_dir = Path("./temp")
|
@@ -26,6 +27,9 @@ def process_video(
|
|
26 |
|
27 |
# Use selected checkpoint or fall back to default
|
28 |
checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH
|
|
|
|
|
|
|
29 |
|
30 |
# Convert paths to absolute Path objects and normalize them
|
31 |
video_file_path = Path(video_path)
|
@@ -48,7 +52,7 @@ def process_video(
|
|
48 |
)
|
49 |
|
50 |
# Parse the arguments
|
51 |
-
args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path)
|
52 |
|
53 |
try:
|
54 |
result = main(
|
@@ -63,7 +67,8 @@ def process_video(
|
|
63 |
|
64 |
|
65 |
def create_args(
|
66 |
-
video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int,
|
|
|
67 |
) -> argparse.Namespace:
|
68 |
parser = argparse.ArgumentParser()
|
69 |
parser.add_argument("--inference_ckpt_path", type=str, required=True)
|
@@ -72,6 +77,7 @@ def create_args(
|
|
72 |
parser.add_argument("--video_out_path", type=str, required=True)
|
73 |
parser.add_argument("--guidance_scale", type=float, default=1.0)
|
74 |
parser.add_argument("--seed", type=int, default=1247)
|
|
|
75 |
|
76 |
return parser.parse_args(
|
77 |
[
|
@@ -87,6 +93,8 @@ def create_args(
|
|
87 |
str(guidance_scale),
|
88 |
"--seed",
|
89 |
str(seed),
|
|
|
|
|
90 |
]
|
91 |
)
|
92 |
|
@@ -97,6 +105,13 @@ def get_checkpoint_files():
|
|
97 |
return []
|
98 |
return [f.name for f in unet_files_dir.glob("*.pt")]
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
# Create Gradio interface
|
101 |
with gr.Blocks(title="SoundImage") as demo:
|
102 |
gr.Markdown(
|
@@ -109,12 +124,17 @@ with gr.Blocks(title="SoundImage") as demo:
|
|
109 |
|
110 |
with gr.Row():
|
111 |
with gr.Column():
|
112 |
-
# Add checkpoint
|
113 |
checkpoint_dropdown = gr.Dropdown(
|
114 |
choices=get_checkpoint_files(),
|
115 |
label="Select Checkpoint",
|
116 |
value=get_checkpoint_files()[0] if get_checkpoint_files() else None
|
117 |
)
|
|
|
|
|
|
|
|
|
|
|
118 |
video_input = gr.Video(label="Input Video")
|
119 |
audio_input = gr.Audio(label="Input Audio", type="filepath")
|
120 |
|
@@ -156,6 +176,7 @@ with gr.Blocks(title="SoundImage") as demo:
|
|
156 |
inference_steps,
|
157 |
seed,
|
158 |
checkpoint_dropdown,
|
|
|
159 |
],
|
160 |
outputs=video_output,
|
161 |
)
|
|
|
19 |
inference_steps,
|
20 |
seed,
|
21 |
checkpoint_file,
|
22 |
+
mask_file,
|
23 |
):
|
24 |
# Create the temp directory if it doesn't exist
|
25 |
output_dir = Path("./temp")
|
|
|
27 |
|
28 |
# Use selected checkpoint or fall back to default
|
29 |
checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH
|
30 |
+
|
31 |
+
# Get mask path
|
32 |
+
mask_path = Path("masks") / mask_file if mask_file else None
|
33 |
|
34 |
# Convert paths to absolute Path objects and normalize them
|
35 |
video_file_path = Path(video_path)
|
|
|
52 |
)
|
53 |
|
54 |
# Parse the arguments
|
55 |
+
args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path, mask_path)
|
56 |
|
57 |
try:
|
58 |
result = main(
|
|
|
67 |
|
68 |
|
69 |
def create_args(
|
70 |
+
video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int,
|
71 |
+
checkpoint_path: Path, mask_path: Path
|
72 |
) -> argparse.Namespace:
|
73 |
parser = argparse.ArgumentParser()
|
74 |
parser.add_argument("--inference_ckpt_path", type=str, required=True)
|
|
|
77 |
parser.add_argument("--video_out_path", type=str, required=True)
|
78 |
parser.add_argument("--guidance_scale", type=float, default=1.0)
|
79 |
parser.add_argument("--seed", type=int, default=1247)
|
80 |
+
parser.add_argument("--mask_path", type=str, required=False)
|
81 |
|
82 |
return parser.parse_args(
|
83 |
[
|
|
|
93 |
str(guidance_scale),
|
94 |
"--seed",
|
95 |
str(seed),
|
96 |
+
"--mask_path",
|
97 |
+
mask_path.absolute().as_posix() if mask_path else "",
|
98 |
]
|
99 |
)
|
100 |
|
|
|
105 |
return []
|
106 |
return [f.name for f in unet_files_dir.glob("*.pt")]
|
107 |
|
108 |
+
# Add this function to get mask files
|
109 |
+
def get_mask_files():
|
110 |
+
masks_dir = Path("masks")
|
111 |
+
if not masks_dir.exists():
|
112 |
+
return []
|
113 |
+
return [f.name for f in masks_dir.glob("*.png")] # Assuming masks are PNG files
|
114 |
+
|
115 |
# Create Gradio interface
|
116 |
with gr.Blocks(title="SoundImage") as demo:
|
117 |
gr.Markdown(
|
|
|
124 |
|
125 |
with gr.Row():
|
126 |
with gr.Column():
|
127 |
+
# Add checkpoint and mask selectors
|
128 |
checkpoint_dropdown = gr.Dropdown(
|
129 |
choices=get_checkpoint_files(),
|
130 |
label="Select Checkpoint",
|
131 |
value=get_checkpoint_files()[0] if get_checkpoint_files() else None
|
132 |
)
|
133 |
+
mask_dropdown = gr.Dropdown( # New dropdown for masks
|
134 |
+
choices=get_mask_files(),
|
135 |
+
label="Select Mask",
|
136 |
+
value=get_mask_files()[0] if get_mask_files() else None
|
137 |
+
)
|
138 |
video_input = gr.Video(label="Input Video")
|
139 |
audio_input = gr.Audio(label="Input Audio", type="filepath")
|
140 |
|
|
|
176 |
inference_steps,
|
177 |
seed,
|
178 |
checkpoint_dropdown,
|
179 |
+
mask_dropdown, # Add mask_dropdown to inputs
|
180 |
],
|
181 |
outputs=video_output,
|
182 |
)
|
scripts/inference.py
CHANGED
@@ -84,6 +84,7 @@ def main(config, args):
|
|
84 |
weight_dtype=dtype,
|
85 |
width=config.data.resolution,
|
86 |
height=config.data.resolution,
|
|
|
87 |
)
|
88 |
|
89 |
|
|
|
84 |
weight_dtype=dtype,
|
85 |
width=config.data.resolution,
|
86 |
height=config.data.resolution,
|
87 |
+
mask_path=args.mask_path,
|
88 |
)
|
89 |
|
90 |
|
soundimage/pipelines/lipsync_pipeline.py
CHANGED
@@ -296,6 +296,7 @@ class LipsyncPipeline(DiffusionPipeline):
|
|
296 |
audio_path: str,
|
297 |
video_out_path: str,
|
298 |
video_mask_path: str = None,
|
|
|
299 |
num_frames: int = 16,
|
300 |
video_fps: int = 25,
|
301 |
audio_sample_rate: int = 16000,
|
@@ -317,7 +318,7 @@ class LipsyncPipeline(DiffusionPipeline):
|
|
317 |
# 0. Define call parameters
|
318 |
batch_size = 1
|
319 |
device = self._execution_device
|
320 |
-
self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
|
321 |
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
|
322 |
|
323 |
video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
|
|
|
296 |
audio_path: str,
|
297 |
video_out_path: str,
|
298 |
video_mask_path: str = None,
|
299 |
+
mask_path: str = None,
|
300 |
num_frames: int = 16,
|
301 |
video_fps: int = 25,
|
302 |
audio_sample_rate: int = 16000,
|
|
|
318 |
# 0. Define call parameters
|
319 |
batch_size = 1
|
320 |
device = self._execution_device
|
321 |
+
self.image_processor = ImageProcessor(height, mask=mask, device="cuda", mask_image=mask_path)
|
322 |
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
|
323 |
|
324 |
video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
|
soundimage/utils/image_processor.py
CHANGED
@@ -28,8 +28,8 @@ https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-fo
|
|
28 |
"""
|
29 |
|
30 |
|
31 |
-
def load_fixed_mask(resolution: int) -> torch.Tensor:
|
32 |
-
mask_image = cv2.imread(
|
33 |
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
|
34 |
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
|
35 |
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
|
@@ -37,7 +37,7 @@ def load_fixed_mask(resolution: int) -> torch.Tensor:
|
|
37 |
|
38 |
|
39 |
class ImageProcessor:
|
40 |
-
def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
|
41 |
self.resolution = resolution
|
42 |
self.resize = transforms.Resize(
|
43 |
(resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
|
@@ -53,7 +53,7 @@ class ImageProcessor:
|
|
53 |
self.restorer = AlignRestore()
|
54 |
|
55 |
if mask_image is None:
|
56 |
-
self.mask_image = load_fixed_mask(resolution)
|
57 |
else:
|
58 |
self.mask_image = mask_image
|
59 |
|
|
|
28 |
"""
|
29 |
|
30 |
|
31 |
+
def load_fixed_mask(resolution: int, mask_path: str) -> torch.Tensor:
|
32 |
+
mask_image = cv2.imread(mask_path)
|
33 |
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
|
34 |
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
|
35 |
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
|
|
|
37 |
|
38 |
|
39 |
class ImageProcessor:
|
40 |
+
def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None, mask_path=None):
|
41 |
self.resolution = resolution
|
42 |
self.resize = transforms.Resize(
|
43 |
(resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
|
|
|
53 |
self.restorer = AlignRestore()
|
54 |
|
55 |
if mask_image is None:
|
56 |
+
self.mask_image = load_fixed_mask(resolution, mask_path)
|
57 |
else:
|
58 |
self.mask_image = mask_image
|
59 |
|
soundimage/utils/mask.png
DELETED
Git LFS Details
|