积极的屁孩 committed on
Commit 3b944a1 · 1 Parent(s): defde46
Files changed (1)
  1. app.py +135 -135
app.py CHANGED
@@ -13,67 +13,67 @@ import re
 import spaces
 
 def install_espeak():
-    """检测并安装espeak-ng依赖"""
+    """Detect and install espeak-ng dependency"""
     try:
-        # 检查espeak-ng是否已安装
+        # Check if espeak-ng is already installed
         result = subprocess.run(["which", "espeak-ng"], capture_output=True, text=True)
         if result.returncode != 0:
-            print("检测到系统中未安装espeak-ng,正在尝试安装...")
-            # 尝试使用apt-get安装espeak-ng及其数据
+            print("espeak-ng is not installed on this system; attempting to install...")
+            # Try to install espeak-ng and its data using apt-get
             subprocess.run(["apt-get", "update"], check=True)
-            # 安装 espeak-ng 和对应的语言数据包
+            # Install espeak-ng and the corresponding language data package
             subprocess.run(["apt-get", "install", "-y", "espeak-ng", "espeak-ng-data"], check=True)
-            print("espeak-ng及其数据包安装成功!")
+            print("espeak-ng and its data packages installed successfully!")
         else:
-            print("espeak-ng已安装在系统中。")
-            # 即使已安装,也尝试更新数据确保完整性 (可选,但有时有帮助)
-            # print("尝试更新 espeak-ng 数据...")
+            print("espeak-ng is already installed on the system.")
+            # Even if already installed, try to update data to ensure integrity (optional but sometimes helpful)
+            # print("Attempting to update espeak-ng data...")
             # subprocess.run(["apt-get", "update"], check=True)
             # subprocess.run(["apt-get", "install", "--only-upgrade", "-y", "espeak-ng-data"], check=True)
 
-        # 验证中文支持 (可选)
+        # Verify Chinese support (optional)
         try:
             voices_result = subprocess.run(["espeak-ng", "--voices=cmn"], capture_output=True, text=True, check=True)
             if "cmn" in voices_result.stdout:
-                print("espeak-ng 支持 'cmn' 语言。")
+                print("espeak-ng supports the 'cmn' language.")
             else:
-                print("警告:espeak-ng 安装了,但 'cmn' 语言似乎仍不可用。")
+                print("Warning: espeak-ng is installed, but the 'cmn' language still seems unavailable.")
         except Exception as e:
-            print(f"验证 espeak-ng 中文支持时出错(可能不影响功能): {e}")
+            print(f"Error verifying espeak-ng Chinese support (may not affect functionality): {e}")
 
     except Exception as e:
-        print(f"安装espeak-ng时出错: {e}")
-        print("请尝试手动运行: apt-get update && apt-get install -y espeak-ng espeak-ng-data")
+        print(f"Error installing espeak-ng: {e}")
+        print("Please try running manually: apt-get update && apt-get install -y espeak-ng espeak-ng-data")
 
-# 在所有其他操作之前安装espeak
+# Install espeak before all other operations
 install_espeak()
 
 def patch_langsegment_init():
     try:
-        # 尝试找到 LangSegment 包的位置
+        # Try to find the location of the LangSegment package
         spec = importlib.util.find_spec("LangSegment")
         if spec is None or spec.origin is None:
-            print("无法定位 LangSegment 包。")
+            print("Unable to locate the LangSegment package.")
             return
 
-        # 构建 __init__.py 的路径
+        # Build the path to __init__.py
         init_path = os.path.join(os.path.dirname(spec.origin), '__init__.py')
 
         if not os.path.exists(init_path):
-            print(f"未找到 LangSegment __init__.py 文件于: {init_path}")
-            # 尝试在 site-packages 中查找,适用于某些环境
+            print(f"LangSegment __init__.py file not found at: {init_path}")
+            # Try to find it in site-packages (applies in some environments)
             for site_pkg_path in site.getsitepackages():
                 potential_path = os.path.join(site_pkg_path, 'LangSegment', '__init__.py')
                 if os.path.exists(potential_path):
                     init_path = potential_path
-                    print(f"在 site-packages 中找到 __init__.py: {init_path}")
+                    print(f"Found __init__.py in site-packages: {init_path}")
                     break
-            else:  # 如果循环正常结束(没有 break)
-                print(f"在 site-packages 中也未找到 __init__.py")
+            else:  # If the loop ends normally (no break)
+                print("Could not find __init__.py in site-packages either")
                 return
 
 
-        print(f"尝试读取 LangSegment __init__.py: {init_path}")
+        print(f"Attempting to read LangSegment __init__.py: {init_path}")
         with open(init_path, 'r') as f:
             lines = f.readlines()
 
@@ -85,52 +85,52 @@ def patch_langsegment_init():
             stripped_line = line.strip()
             if stripped_line.startswith(target_line_prefix):
                 if 'setLangfilters' in stripped_line or 'getLangfilters' in stripped_line:
-                    print(f"发现需要修改的行: {stripped_line}")
-                    # 移除 setLangfilters 和 getLangfilters
+                    print(f"Found line that needs modification: {stripped_line}")
+                    # Remove setLangfilters and getLangfilters
                     modified_line = stripped_line.replace(',setLangfilters', '')
                     modified_line = modified_line.replace(',getLangfilters', '')
-                    # 确保逗号处理正确 (例如,如果它们是末尾的项)
+                    # Ensure comma handling is correct (e.g., if they are the last items)
                     modified_line = modified_line.replace('setLangfilters,', '')
                     modified_line = modified_line.replace('getLangfilters,', '')
-                    # 如果它们是唯一的额外导入,移除可能多余的逗号
+                    # If they are the only extra imports, remove any redundant commas
                     modified_line = modified_line.rstrip(',')
                     new_lines.append(modified_line + '\n')
                     modified = True
-                    print(f"修改后的行: {modified_line.strip()}")
+                    print(f"Modified line: {modified_line.strip()}")
                 else:
-                    new_lines.append(line)  # 行没问题,保留原样
+                    new_lines.append(line)  # Line is fine, keep as is
             else:
-                new_lines.append(line)  # 非目标行,保留原样
+                new_lines.append(line)  # Non-target line, keep as is
 
         if modified:
-            print(f"尝试写回已修改的 LangSegment __init__.py 到: {init_path}")
+            print(f"Attempting to write the modified LangSegment __init__.py back to: {init_path}")
             try:
                 with open(init_path, 'w') as f:
                     f.writelines(new_lines)
-                print("LangSegment __init__.py 修改成功。")
-                # 尝试重新加载模块以使更改生效(可能无效,取决于导入链)
+                print("LangSegment __init__.py modified successfully.")
+                # Try to reload the module so the change takes effect (may not work, depending on the import chain)
                 try:
                     import LangSegment
                     importlib.reload(LangSegment)
-                    print("LangSegment 模块已尝试重新加载。")
+                    print("Attempted to reload the LangSegment module.")
                 except Exception as reload_e:
-                    print(f"重新加载 LangSegment 时出错(可能无影响): {reload_e}")
+                    print(f"Error reloading LangSegment (may have no impact): {reload_e}")
             except PermissionError:
-                print(f"错误:权限不足,无法修改 {init_path}。请考虑修改 requirements.txt")
+                print(f"Error: insufficient permissions to modify {init_path}. Consider modifying requirements.txt instead.")
             except Exception as write_e:
-                print(f"写入 LangSegment __init__.py 时发生其他错误: {write_e}")
+                print(f"Another error occurred while writing LangSegment __init__.py: {write_e}")
         else:
-            print("LangSegment __init__.py 无需修改。")
+            print("LangSegment __init__.py needs no modification.")
 
     except ImportError:
-        print("未找到 LangSegment 包,无法进行修复。")
+        print("LangSegment package not found; unable to apply the fix.")
     except Exception as e:
-        print(f"修复 LangSegment 包时发生意外错误: {e}")
+        print(f"Unexpected error while patching the LangSegment package: {e}")
 
-# 在所有其他导入(尤其是可能触发 LangSegment 导入的 Amphion)之前执行修复
+# Run the patch before any other imports (especially Amphion, which may trigger the LangSegment import)
 patch_langsegment_init()
 
-# 克隆Amphion仓库
+# Clone the Amphion repository
 if not os.path.exists("Amphion"):
     subprocess.run(["git", "clone", "https://github.com/open-mmlab/Amphion.git"])
     os.chdir("Amphion")
@@ -138,17 +138,17 @@ else:
     if not os.getcwd().endswith("Amphion"):
         os.chdir("Amphion")
 
-# 将Amphion加入到路径中
+# Add Amphion to the path
 if os.path.dirname(os.path.abspath("Amphion")) not in sys.path:
     sys.path.append(os.path.dirname(os.path.abspath("Amphion")))
 
-# 确保需要的目录存在
+# Ensure needed directories exist
 os.makedirs("wav", exist_ok=True)
 os.makedirs("ckpts/Vevo", exist_ok=True)
 
 from models.vc.vevo.vevo_utils import VevoInferencePipeline, save_audio, load_wav
 
-# 下载和设置配置文件
+# Download and set up config files
 def setup_configs():
     config_path = "models/vc/vevo/config"
     os.makedirs(config_path, exist_ok=True)
@@ -171,27 +171,27 @@ def setup_configs():
                 repo_type="model",
             )
             os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            # 拷贝文件到目标位置
+            # Copy file to target location
             subprocess.run(["cp", file_data, file_path])
         except Exception as e:
-            print(f"下载配置文件 {file} 时出错: {e}")
+            print(f"Error downloading config file {file}: {e}")
 
 setup_configs()
 
-# 设备配置
+# Device configuration
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-print(f"使用设备: {device}")
+print(f"Using device: {device}")
 
-# 初始化管道字典
+# Initialize pipeline dictionary
 inference_pipelines = {}
 
 def get_pipeline(pipeline_type):
     if pipeline_type in inference_pipelines:
         return inference_pipelines[pipeline_type]
 
-    # 根据需要的管道类型初始化
+    # Initialize pipeline based on the required pipeline type
     if pipeline_type == "style" or pipeline_type == "voice":
-        # 下载Content Tokenizer
+        # Download Content Tokenizer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -202,7 +202,7 @@ def get_pipeline(pipeline_type):
             local_dir, "tokenizer/vq32/hubert_large_l18_c32.pkl"
         )
 
-        # 下载Content-Style Tokenizer
+        # Download Content-Style Tokenizer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -211,7 +211,7 @@ def get_pipeline(pipeline_type):
         )
         content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
 
-        # 下载Autoregressive Transformer
+        # Download Autoregressive Transformer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -221,7 +221,7 @@ def get_pipeline(pipeline_type):
         ar_cfg_path = "./models/vc/vevo/config/Vq32ToVq8192.json"
         ar_ckpt_path = os.path.join(local_dir, "contentstyle_modeling/Vq32ToVq8192")
 
-        # 下载Flow Matching Transformer
+        # Download Flow Matching Transformer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -231,7 +231,7 @@ def get_pipeline(pipeline_type):
         fmt_cfg_path = "./models/vc/vevo/config/Vq8192ToMels.json"
         fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")
 
-        # 下载Vocoder
+        # Download Vocoder
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -241,7 +241,7 @@ def get_pipeline(pipeline_type):
         vocoder_cfg_path = "./models/vc/vevo/config/Vocoder.json"
         vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
 
-        # 初始化管道
+        # Initialize pipeline
         inference_pipeline = VevoInferencePipeline(
             content_tokenizer_ckpt_path=content_tokenizer_ckpt_path,
             content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
@@ -255,7 +255,7 @@ def get_pipeline(pipeline_type):
         )
 
     elif pipeline_type == "timbre":
-        # 下载Content-Style Tokenizer (timbre需要)
+        # Download Content-Style Tokenizer (needed for timbre)
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -264,7 +264,7 @@ def get_pipeline(pipeline_type):
         )
         content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
 
-        # 下载Flow Matching Transformer
+        # Download Flow Matching Transformer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -274,7 +274,7 @@ def get_pipeline(pipeline_type):
         fmt_cfg_path = "./models/vc/vevo/config/Vq8192ToMels.json"
         fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")
 
-        # 下载Vocoder
+        # Download Vocoder
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -284,7 +284,7 @@ def get_pipeline(pipeline_type):
         vocoder_cfg_path = "./models/vc/vevo/config/Vocoder.json"
         vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
 
-        # 初始化管道
+        # Initialize pipeline
         inference_pipeline = VevoInferencePipeline(
             content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
             fmt_cfg_path=fmt_cfg_path,
@@ -295,7 +295,7 @@ def get_pipeline(pipeline_type):
         )
 
     elif pipeline_type == "tts":
-        # 下载Content-Style Tokenizer
+        # Download Content-Style Tokenizer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -304,7 +304,7 @@ def get_pipeline(pipeline_type):
         )
         content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
 
-        # 下载Autoregressive Transformer (TTS特有)
+        # Download Autoregressive Transformer (TTS-specific)
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -314,7 +314,7 @@ def get_pipeline(pipeline_type):
         ar_cfg_path = "./models/vc/vevo/config/PhoneToVq8192.json"
         ar_ckpt_path = os.path.join(local_dir, "contentstyle_modeling/PhoneToVq8192")
 
-        # 下载Flow Matching Transformer
+        # Download Flow Matching Transformer
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -324,7 +324,7 @@ def get_pipeline(pipeline_type):
         fmt_cfg_path = "./models/vc/vevo/config/Vq8192ToMels.json"
         fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")
 
-        # 下载Vocoder
+        # Download Vocoder
         local_dir = snapshot_download(
             repo_id="amphion/Vevo",
             repo_type="model",
@@ -334,7 +334,7 @@ def get_pipeline(pipeline_type):
         vocoder_cfg_path = "./models/vc/vevo/config/Vocoder.json"
         vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
 
-        # 初始化管道
+        # Initialize pipeline
         inference_pipeline = VevoInferencePipeline(
             content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
             ar_cfg_path=ar_cfg_path,
@@ -346,33 +346,33 @@ def get_pipeline(pipeline_type):
             device=device,
         )
 
-    # 缓存管道实例
+    # Cache pipeline instance
     inference_pipelines[pipeline_type] = inference_pipeline
     return inference_pipeline
 
-# 实现VEVO功能函数
+# Implement the VEVO feature functions
 @spaces.GPU()
 def vevo_style(content_wav, style_wav):
     temp_content_path = "wav/temp_content.wav"
     temp_style_path = "wav/temp_style.wav"
     output_path = "wav/output_vevostyle.wav"
 
-    # 检查并处理音频数据
+    # Check and process audio data
     if content_wav is None or style_wav is None:
         raise ValueError("Please upload audio files")
 
-    # 处理音频格式
+    # Process audio format
     if isinstance(content_wav, tuple) and len(content_wav) == 2:
         if isinstance(content_wav[0], np.ndarray):
             content_data, content_sr = content_wav
         else:
             content_sr, content_data = content_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(content_data.shape) > 1 and content_data.shape[1] > 1:
             content_data = np.mean(content_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if content_sr != 24000:
             content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
             content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
@@ -380,7 +380,7 @@ def vevo_style(content_wav, style_wav):
         else:
             content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid content audio format")
@@ -390,11 +390,11 @@ def vevo_style(content_wav, style_wav):
         else:
             style_sr, style_data = style_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(style_data.shape) > 1 and style_data.shape[1] > 1:
             style_data = np.mean(style_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if style_sr != 24000:
             style_tensor = torch.FloatTensor(style_data).unsqueeze(0)
             style_tensor = torchaudio.functional.resample(style_tensor, style_sr, 24000)
@@ -402,22 +402,22 @@ def vevo_style(content_wav, style_wav):
         else:
             style_tensor = torch.FloatTensor(style_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         style_tensor = style_tensor / (torch.max(torch.abs(style_tensor)) + 1e-6) * 0.95
 
-    # 打印debug信息
+    # Print debug information
     print(f"Content audio shape: {content_tensor.shape}, sample rate: {content_sr}")
     print(f"Style audio shape: {style_tensor.shape}, sample rate: {style_sr}")
 
-    # 保存音频
+    # Save audio
     torchaudio.save(temp_content_path, content_tensor, content_sr)
     torchaudio.save(temp_style_path, style_tensor, style_sr)
 
     try:
-        # 获取管道
+        # Get pipeline
         pipeline = get_pipeline("style")
 
-        # 推理
+        # Inference
         gen_audio = pipeline.inference_ar_and_fm(
             src_wav_path=temp_content_path,
             src_text=None,
@@ -425,14 +425,14 @@ def vevo_style(content_wav, style_wav):
             timbre_ref_wav_path=temp_content_path,
         )
 
-        # 检查生成音频是否为数值异常
+        # Check the generated audio for numerical anomalies
         if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
             print("Warning: Generated audio contains NaN or Inf values")
             gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
 
         print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
 
-        # 保存生成的音频
+        # Save generated audio
         save_audio(gen_audio, output_path=output_path)
 
         return output_path
@@ -448,22 +448,22 @@ def vevo_timbre(content_wav, reference_wav):
     temp_reference_path = "wav/temp_reference.wav"
     output_path = "wav/output_vevotimbre.wav"
 
-    # 检查并处理音频数据
+    # Check and process audio data
     if content_wav is None or reference_wav is None:
         raise ValueError("Please upload audio files")
 
-    # 处理内容音频格式
+    # Process content audio format
     if isinstance(content_wav, tuple) and len(content_wav) == 2:
         if isinstance(content_wav[0], np.ndarray):
             content_data, content_sr = content_wav
         else:
             content_sr, content_data = content_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(content_data.shape) > 1 and content_data.shape[1] > 1:
             content_data = np.mean(content_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if content_sr != 24000:
             content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
             content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
@@ -471,23 +471,23 @@ def vevo_timbre(content_wav, reference_wav):
         else:
             content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid content audio format")
 
-    # 处理参考音频格式
+    # Process reference audio format
     if isinstance(reference_wav, tuple) and len(reference_wav) == 2:
         if isinstance(reference_wav[0], np.ndarray):
             reference_data, reference_sr = reference_wav
         else:
             reference_sr, reference_data = reference_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(reference_data.shape) > 1 and reference_data.shape[1] > 1:
             reference_data = np.mean(reference_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if reference_sr != 24000:
             reference_tensor = torch.FloatTensor(reference_data).unsqueeze(0)
             reference_tensor = torchaudio.functional.resample(reference_tensor, reference_sr, 24000)
@@ -495,38 +495,38 @@ def vevo_timbre(content_wav, reference_wav):
         else:
             reference_tensor = torch.FloatTensor(reference_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         reference_tensor = reference_tensor / (torch.max(torch.abs(reference_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid reference audio format")
 
-    # 打印debug信息
+    # Print debug information
     print(f"Content audio shape: {content_tensor.shape}, sample rate: {content_sr}")
     print(f"Reference audio shape: {reference_tensor.shape}, sample rate: {reference_sr}")
 
-    # 保存上传的音频
+    # Save uploaded audio
     torchaudio.save(temp_content_path, content_tensor, content_sr)
     torchaudio.save(temp_reference_path, reference_tensor, reference_sr)
 
     try:
-        # 获取管道
+        # Get pipeline
         pipeline = get_pipeline("timbre")
 
-        # 推理
+        # Inference
         gen_audio = pipeline.inference_fm(
             src_wav_path=temp_content_path,
             timbre_ref_wav_path=temp_reference_path,
             flow_matching_steps=32,
         )
 
-        # 检查生成音频是否为数值异常
+        # Check the generated audio for numerical anomalies
         if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
             print("Warning: Generated audio contains NaN or Inf values")
             gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
 
         print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
 
-        # 保存生成的音频
+        # Save generated audio
         save_audio(gen_audio, output_path=output_path)
 
         return output_path
@@ -543,22 +543,22 @@ def vevo_voice(content_wav, style_reference_wav, timbre_reference_wav):
     temp_timbre_path = "wav/temp_timbre.wav"
     output_path = "wav/output_vevovoice.wav"
 
-    # 检查并处理音频数据
+    # Check and process audio data
     if content_wav is None or style_reference_wav is None or timbre_reference_wav is None:
         raise ValueError("Please upload all required audio files")
 
-    # 处理内容音频格式
+    # Process content audio format
     if isinstance(content_wav, tuple) and len(content_wav) == 2:
         if isinstance(content_wav[0], np.ndarray):
             content_data, content_sr = content_wav
         else:
             content_sr, content_data = content_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(content_data.shape) > 1 and content_data.shape[1] > 1:
             content_data = np.mean(content_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if content_sr != 24000:
             content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
             content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
@@ -566,23 +566,23 @@ def vevo_voice(content_wav, style_reference_wav, timbre_reference_wav):
         else:
             content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid content audio format")
 
-    # 处理风格参考音频格式
+    # Process style reference audio format
     if isinstance(style_reference_wav, tuple) and len(style_reference_wav) == 2:
         if isinstance(style_reference_wav[0], np.ndarray):
             style_data, style_sr = style_reference_wav
         else:
             style_sr, style_data = style_reference_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(style_data.shape) > 1 and style_data.shape[1] > 1:
             style_data = np.mean(style_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if style_sr != 24000:
             style_tensor = torch.FloatTensor(style_data).unsqueeze(0)
             style_tensor = torchaudio.functional.resample(style_tensor, style_sr, 24000)
@@ -590,23 +590,23 @@ def vevo_voice(content_wav, style_reference_wav, timbre_reference_wav):
         else:
             style_tensor = torch.FloatTensor(style_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         style_tensor = style_tensor / (torch.max(torch.abs(style_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid style reference audio format")
 
-    # 处理音色参考音频格式
+    # Process timbre reference audio format
     if isinstance(timbre_reference_wav, tuple) and len(timbre_reference_wav) == 2:
         if isinstance(timbre_reference_wav[0], np.ndarray):
             timbre_data, timbre_sr = timbre_reference_wav
         else:
             timbre_sr, timbre_data = timbre_reference_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(timbre_data.shape) > 1 and timbre_data.shape[1] > 1:
             timbre_data = np.mean(timbre_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if timbre_sr != 24000:
             timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
             timbre_tensor = torchaudio.functional.resample(timbre_tensor, timbre_sr, 24000)
@@ -614,26 +614,26 @@ def vevo_voice(content_wav, style_reference_wav, timbre_reference_wav):
         else:
             timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         timbre_tensor = timbre_tensor / (torch.max(torch.abs(timbre_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid timbre reference audio format")
 
-    # 打印debug信息
+    # Print debug information
     print(f"Content audio shape: {content_tensor.shape}, sample rate: {content_sr}")
     print(f"Style reference audio shape: {style_tensor.shape}, sample rate: {style_sr}")
     print(f"Timbre reference audio shape: {timbre_tensor.shape}, sample rate: {timbre_sr}")
 
-    # 保存上传的音频
+    # Save uploaded audio
     torchaudio.save(temp_content_path, content_tensor, content_sr)
     torchaudio.save(temp_style_path, style_tensor, style_sr)
     torchaudio.save(temp_timbre_path, timbre_tensor, timbre_sr)
 
     try:
-        # 获取管道
+        # Get pipeline
         pipeline = get_pipeline("voice")
 
-        # 推理
+        # Inference
         gen_audio = pipeline.inference_ar_and_fm(
             src_wav_path=temp_content_path,
             src_text=None,
@@ -641,14 +641,14 @@ def vevo_voice(content_wav, style_reference_wav, timbre_reference_wav):
             timbre_ref_wav_path=temp_timbre_path,
         )
 
-        # 检查生成音频是否为数值异常
+        # Check the generated audio for numerical anomalies
         if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
             print("Warning: Generated audio contains NaN or Inf values")
             gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
 
         print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
 
-        # 保存生成的音频
+        # Save generated audio
         save_audio(gen_audio, output_path=output_path)
 
         return output_path
@@ -664,22 +664,22 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
     temp_timbre_path = "wav/temp_timbre.wav"
     output_path = "wav/output_vevotts.wav"
 
-    # 检查并处理音频数据
+    # Check and process audio data
     if ref_wav is None:
         raise ValueError("Please upload a reference audio file")
 
-    # 处理参考音频格式
+    # Process reference audio format
     if isinstance(ref_wav, tuple) and len(ref_wav) == 2:
         if isinstance(ref_wav[0], np.ndarray):
             ref_data, ref_sr = ref_wav
         else:
             ref_sr, ref_data = ref_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(ref_data.shape) > 1 and ref_data.shape[1] > 1:
             ref_data = np.mean(ref_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if ref_sr != 24000:
             ref_tensor = torch.FloatTensor(ref_data).unsqueeze(0)
             ref_tensor = torchaudio.functional.resample(ref_tensor, ref_sr, 24000)
@@ -687,17 +687,17 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
         else:
             ref_tensor = torch.FloatTensor(ref_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         ref_tensor = ref_tensor / (torch.max(torch.abs(ref_tensor)) + 1e-6) * 0.95
     else:
         raise ValueError("Invalid reference audio format")
 
-    # 打印debug信息
+    # Print debug information
     print(f"Reference audio shape: {ref_tensor.shape}, sample rate: {ref_sr}")
     if style_ref_text:
         print(f"Style reference text: {style_ref_text}, language: {style_ref_text_language}")
 
-    # 保存上传的音频
+    # Save uploaded audio
     torchaudio.save(temp_ref_path, ref_tensor, ref_sr)
 
     if timbre_ref_wav is not None:
@@ -707,11 +707,11 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
         else:
             timbre_sr, timbre_data = timbre_ref_wav
 
-        # 确保是单声道
+        # Ensure single channel
         if len(timbre_data.shape) > 1 and timbre_data.shape[1] > 1:
             timbre_data = np.mean(timbre_data, axis=1)
 
-        # 重采样到24kHz
+        # Resample to 24kHz
         if timbre_sr != 24000:
             timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
             timbre_tensor = torchaudio.functional.resample(timbre_tensor, timbre_sr, 24000)
@@ -719,7 +719,7 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
         else:
             timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
 
-        # 归一化音量
+        # Normalize volume
         timbre_tensor = timbre_tensor / (torch.max(torch.abs(timbre_tensor)) + 1e-6) * 0.95
 
         print(f"Timbre reference audio shape: {timbre_tensor.shape}, sample rate: {timbre_sr}")
@@ -730,10 +730,10 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
         temp_timbre_path = temp_ref_path
 
     try:
-        # 获取管道
+        # Get pipeline
         pipeline = get_pipeline("tts")
 
-        # 推理
+        # Inference
         gen_audio = pipeline.inference_ar_and_fm(
             src_wav_path=None,
             src_text=text,
@@ -744,14 +744,14 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
             style_ref_wav_text_language=style_ref_text_language,
         )
 
-        # 检查生成音频是否为数值异常
+        # Check the generated audio for numerical anomalies
        if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
            print("Warning: Generated audio contains NaN or Inf values")
            gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
 
        print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
 
-        # 保存生成的音频
+        # Save generated audio
        save_audio(gen_audio, output_path=output_path)
 
        return output_path
@@ -761,10 +761,10 @@ def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_langua
         traceback.print_exc()
         raise e
 
-# 创建Gradio界面
+# Create Gradio interface
 with gr.Blocks(title="Vevo: Controllable Zero-Shot Voice Imitation with Self-Supervised Disentanglement") as demo:
     gr.Markdown("# Vevo: Controllable Zero-Shot Voice Imitation with Self-Supervised Disentanglement")
-    # 添加链接标签行
+    # Add a row of links
     with gr.Row(elem_id="links_row"):
         gr.HTML("""
             <div style="display: flex; justify-content: flex-start; gap: 8px; margin: 0 0; padding-left: 0px;">
@@ -850,5 +850,5 @@ with gr.Blocks(title="Vevo: Controllable Zero-Shot Voice Imitation with Self-Sup
     For more information, visit the [Amphion project](https://github.com/open-mmlab/Amphion)
     """)
 
-# 启动应用
+# Launch application
 demo.launch()
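
Note: every function in this diff repeats the same preprocessing chain (downmix to mono, resample to 24 kHz, peak-normalize to 0.95). A follow-up commit could hoist it into one shared helper; below is a minimal sketch under that assumption (the helper name preprocess_audio is hypothetical and not part of this commit):

    import numpy as np
    import torch
    import torchaudio

    def preprocess_audio(wav, target_sr=24000):
        # Accept either (ndarray, sr) or (sr, ndarray) tuples as produced by gr.Audio
        if not (isinstance(wav, tuple) and len(wav) == 2):
            raise ValueError("Invalid audio format")
        if isinstance(wav[0], np.ndarray):
            data, sr = wav
        else:
            sr, data = wav
        # Downmix multi-channel input to mono
        if data.ndim > 1 and data.shape[1] > 1:
            data = np.mean(data, axis=1)
        tensor = torch.FloatTensor(data).unsqueeze(0)
        # Resample to the model's expected 24 kHz rate
        if sr != target_sr:
            tensor = torchaudio.functional.resample(tensor, sr, target_sr)
            sr = target_sr
        # Peak-normalize to 0.95, with an epsilon to avoid division by zero
        tensor = tensor / (torch.max(torch.abs(tensor)) + 1e-6) * 0.95
        return tensor, sr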
 