samarth-ht committed on
Commit
21c3587
·
1 Parent(s): 229302a

mask dropdown added

Browse files
app.py CHANGED
@@ -19,6 +19,7 @@ def process_video(
19
  inference_steps,
20
  seed,
21
  checkpoint_file,
 
22
  ):
23
  # Create the temp directory if it doesn't exist
24
  output_dir = Path("./temp")
@@ -26,6 +27,9 @@ def process_video(
26
 
27
  # Use selected checkpoint or fall back to default
28
  checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH
 
 
 
29
 
30
  # Convert paths to absolute Path objects and normalize them
31
  video_file_path = Path(video_path)
@@ -48,7 +52,7 @@ def process_video(
48
  )
49
 
50
  # Parse the arguments
51
- args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path)
52
 
53
  try:
54
  result = main(
@@ -63,7 +67,8 @@ def process_video(
63
 
64
 
65
  def create_args(
66
- video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int, checkpoint_path: Path
 
67
  ) -> argparse.Namespace:
68
  parser = argparse.ArgumentParser()
69
  parser.add_argument("--inference_ckpt_path", type=str, required=True)
@@ -72,6 +77,7 @@ def create_args(
72
  parser.add_argument("--video_out_path", type=str, required=True)
73
  parser.add_argument("--guidance_scale", type=float, default=1.0)
74
  parser.add_argument("--seed", type=int, default=1247)
 
75
 
76
  return parser.parse_args(
77
  [
@@ -87,6 +93,8 @@ def create_args(
87
  str(guidance_scale),
88
  "--seed",
89
  str(seed),
 
 
90
  ]
91
  )
92
 
@@ -97,6 +105,13 @@ def get_checkpoint_files():
97
  return []
98
  return [f.name for f in unet_files_dir.glob("*.pt")]
99
 
 
 
 
 
 
 
 
100
  # Create Gradio interface
101
  with gr.Blocks(title="SoundImage") as demo:
102
  gr.Markdown(
@@ -109,12 +124,17 @@ with gr.Blocks(title="SoundImage") as demo:
109
 
110
  with gr.Row():
111
  with gr.Column():
112
- # Add checkpoint selector dropdown
113
  checkpoint_dropdown = gr.Dropdown(
114
  choices=get_checkpoint_files(),
115
  label="Select Checkpoint",
116
  value=get_checkpoint_files()[0] if get_checkpoint_files() else None
117
  )
 
 
 
 
 
118
  video_input = gr.Video(label="Input Video")
119
  audio_input = gr.Audio(label="Input Audio", type="filepath")
120
 
@@ -156,6 +176,7 @@ with gr.Blocks(title="SoundImage") as demo:
156
  inference_steps,
157
  seed,
158
  checkpoint_dropdown,
 
159
  ],
160
  outputs=video_output,
161
  )
 
19
  inference_steps,
20
  seed,
21
  checkpoint_file,
22
+ mask_file,
23
  ):
24
  # Create the temp directory if it doesn't exist
25
  output_dir = Path("./temp")
 
27
 
28
  # Use selected checkpoint or fall back to default
29
  checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH
30
+
31
+ # Get mask path
32
+ mask_path = Path("masks") / mask_file if mask_file else None
33
 
34
  # Convert paths to absolute Path objects and normalize them
35
  video_file_path = Path(video_path)
 
52
  )
53
 
54
  # Parse the arguments
55
+ args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path, mask_path)
56
 
57
  try:
58
  result = main(
 
67
 
68
 
69
  def create_args(
70
+ video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int,
71
+ checkpoint_path: Path, mask_path: Path
72
  ) -> argparse.Namespace:
73
  parser = argparse.ArgumentParser()
74
  parser.add_argument("--inference_ckpt_path", type=str, required=True)
 
77
  parser.add_argument("--video_out_path", type=str, required=True)
78
  parser.add_argument("--guidance_scale", type=float, default=1.0)
79
  parser.add_argument("--seed", type=int, default=1247)
80
+ parser.add_argument("--mask_path", type=str, required=False)
81
 
82
  return parser.parse_args(
83
  [
 
93
  str(guidance_scale),
94
  "--seed",
95
  str(seed),
96
+ "--mask_path",
97
+ mask_path.absolute().as_posix() if mask_path else "",
98
  ]
99
  )
100
 
 
105
  return []
106
  return [f.name for f in unet_files_dir.glob("*.pt")]
107
 
108
+ # Add this function to get mask files
109
+ def get_mask_files():
110
+ masks_dir = Path("masks")
111
+ if not masks_dir.exists():
112
+ return []
113
+ return [f.name for f in masks_dir.glob("*.png")] # Assuming masks are PNG files
114
+
115
  # Create Gradio interface
116
  with gr.Blocks(title="SoundImage") as demo:
117
  gr.Markdown(
 
124
 
125
  with gr.Row():
126
  with gr.Column():
127
+ # Add checkpoint and mask selectors
128
  checkpoint_dropdown = gr.Dropdown(
129
  choices=get_checkpoint_files(),
130
  label="Select Checkpoint",
131
  value=get_checkpoint_files()[0] if get_checkpoint_files() else None
132
  )
133
+ mask_dropdown = gr.Dropdown( # New dropdown for masks
134
+ choices=get_mask_files(),
135
+ label="Select Mask",
136
+ value=get_mask_files()[0] if get_mask_files() else None
137
+ )
138
  video_input = gr.Video(label="Input Video")
139
  audio_input = gr.Audio(label="Input Audio", type="filepath")
140
 
 
176
  inference_steps,
177
  seed,
178
  checkpoint_dropdown,
179
+ mask_dropdown, # Add mask_dropdown to inputs
180
  ],
181
  outputs=video_output,
182
  )
scripts/inference.py CHANGED
@@ -84,6 +84,7 @@ def main(config, args):
84
  weight_dtype=dtype,
85
  width=config.data.resolution,
86
  height=config.data.resolution,
 
87
  )
88
 
89
 
 
84
  weight_dtype=dtype,
85
  width=config.data.resolution,
86
  height=config.data.resolution,
87
+ mask_path=args.mask_path,
88
  )
89
 
90
 
soundimage/pipelines/lipsync_pipeline.py CHANGED
@@ -296,6 +296,7 @@ class LipsyncPipeline(DiffusionPipeline):
296
  audio_path: str,
297
  video_out_path: str,
298
  video_mask_path: str = None,
 
299
  num_frames: int = 16,
300
  video_fps: int = 25,
301
  audio_sample_rate: int = 16000,
@@ -317,7 +318,7 @@ class LipsyncPipeline(DiffusionPipeline):
317
  # 0. Define call parameters
318
  batch_size = 1
319
  device = self._execution_device
320
- self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
321
  self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
322
 
323
  video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
 
296
  audio_path: str,
297
  video_out_path: str,
298
  video_mask_path: str = None,
299
+ mask_path: str = None,
300
  num_frames: int = 16,
301
  video_fps: int = 25,
302
  audio_sample_rate: int = 16000,
 
318
  # 0. Define call parameters
319
  batch_size = 1
320
  device = self._execution_device
321
+ self.image_processor = ImageProcessor(height, mask=mask, device="cuda", mask_image=mask_path)
322
  self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
323
 
324
  video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
soundimage/utils/image_processor.py CHANGED
@@ -28,8 +28,8 @@ https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-fo
28
  """
29
 
30
 
31
- def load_fixed_mask(resolution: int) -> torch.Tensor:
32
- mask_image = cv2.imread("soundimage/utils/mask.png")
33
  mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
34
  mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
35
  mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
@@ -37,7 +37,7 @@ def load_fixed_mask(resolution: int) -> torch.Tensor:
37
 
38
 
39
  class ImageProcessor:
40
- def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
41
  self.resolution = resolution
42
  self.resize = transforms.Resize(
43
  (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
@@ -53,7 +53,7 @@ class ImageProcessor:
53
  self.restorer = AlignRestore()
54
 
55
  if mask_image is None:
56
- self.mask_image = load_fixed_mask(resolution)
57
  else:
58
  self.mask_image = mask_image
59
 
 
28
  """
29
 
30
 
31
+ def load_fixed_mask(resolution: int, mask_path: str) -> torch.Tensor:
32
+ mask_image = cv2.imread(mask_path)
33
  mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
34
  mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
35
  mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
 
37
 
38
 
39
  class ImageProcessor:
40
+ def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None, mask_path=None):
41
  self.resolution = resolution
42
  self.resize = transforms.Resize(
43
  (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
 
53
  self.restorer = AlignRestore()
54
 
55
  if mask_image is None:
56
+ self.mask_image = load_fixed_mask(resolution, mask_path)
57
  else:
58
  self.mask_image = mask_image
59
 
soundimage/utils/mask.png DELETED

Git LFS Details

  • SHA256: aa233251b9ff5691a1565a4108f0910ab1e5e7ad79a7bb2b741ab4d92c81053c
  • Pointer size: 129 Bytes
  • Size of remote file: 1.87 kB