# Vevo / app.py — Gradio demo for the Amphion Vevo model
import os
import sys
import importlib.util
import site
import json
import torch
import gradio as gr
import torchaudio
import numpy as np
from huggingface_hub import snapshot_download, hf_hub_download
import subprocess
import re
import shutil
def install_espeak():
    """Check for the espeak-ng dependency and install it if missing."""
    try:
        # Check whether espeak-ng is already installed
        result = subprocess.run(["which", "espeak-ng"], capture_output=True, text=True)
        if result.returncode != 0:
            print("espeak-ng not found on this system; attempting to install...")
            # Try to install espeak-ng and its data via apt-get
            subprocess.run(["apt-get", "update"], check=True)
            # Install espeak-ng together with its language data package
            subprocess.run(["apt-get", "install", "-y", "espeak-ng", "espeak-ng-data"], check=True)
            print("espeak-ng and its data package installed successfully!")
        else:
            print("espeak-ng is already installed.")
            # Even when already installed, the data can be refreshed to ensure completeness (optional, but sometimes helpful)
            # print("Refreshing espeak-ng data...")
            # subprocess.run(["apt-get", "update"], check=True)
            # subprocess.run(["apt-get", "install", "--only-upgrade", "-y", "espeak-ng-data"], check=True)
        # Verify Chinese support (optional)
        try:
            voices_result = subprocess.run(["espeak-ng", "--voices=cmn"], capture_output=True, text=True, check=True)
            if "cmn" in voices_result.stdout:
                print("espeak-ng supports the 'cmn' (Mandarin) voice.")
            else:
                print("Warning: espeak-ng is installed, but the 'cmn' voice still appears to be unavailable.")
        except Exception as e:
            print(f"Error while verifying espeak-ng Chinese support (may not affect functionality): {e}")
    except Exception as e:
        print(f"Error while installing espeak-ng: {e}")
        print("Please try running manually: apt-get update && apt-get install -y espeak-ng espeak-ng-data")
# Install espeak before anything else
install_espeak()
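# Quick manual sanity check (an assumption: a Debian/Ubuntu-based image, as the
# apt-get calls above already presume). These are standard espeak-ng CLI flags:
#     espeak-ng --voices=cmn          # should list the Mandarin voice
#     espeak-ng -v cmn "你好"          # should synthesize without a "voice not found" error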
def patch_langsegment_init():
    try:
        # Try to locate the LangSegment package
        spec = importlib.util.find_spec("LangSegment")
        if spec is None or spec.origin is None:
            print("Unable to locate the LangSegment package.")
            return
        # Build the path to __init__.py
        init_path = os.path.join(os.path.dirname(spec.origin), '__init__.py')
        if not os.path.exists(init_path):
            print(f"LangSegment __init__.py not found at: {init_path}")
            # Fall back to searching site-packages, which works in some environments
            for site_pkg_path in site.getsitepackages():
                potential_path = os.path.join(site_pkg_path, 'LangSegment', '__init__.py')
                if os.path.exists(potential_path):
                    init_path = potential_path
                    print(f"Found __init__.py in site-packages: {init_path}")
                    break
            else:  # the loop completed without a break
                print("__init__.py not found in site-packages either")
                return
        print(f"Reading LangSegment __init__.py: {init_path}")
        with open(init_path, 'r') as f:
            lines = f.readlines()
        modified = False
        new_lines = []
        target_line_prefix = "from .LangSegment import"
        for line in lines:
            stripped_line = line.strip()
            if stripped_line.startswith(target_line_prefix):
                if 'setLangfilters' in stripped_line or 'getLangfilters' in stripped_line:
                    print(f"Found a line that needs patching: {stripped_line}")
                    # Remove setLangfilters and getLangfilters
                    modified_line = stripped_line.replace(',setLangfilters', '')
                    modified_line = modified_line.replace(',getLangfilters', '')
                    # Make sure commas are handled correctly (e.g. when these are the trailing items)
                    modified_line = modified_line.replace('setLangfilters,', '')
                    modified_line = modified_line.replace('getLangfilters,', '')
                    # Drop a possibly redundant trailing comma when they were the only extra imports
                    modified_line = modified_line.rstrip(',')
                    new_lines.append(modified_line + '\n')
                    modified = True
                    print(f"Patched line: {modified_line.strip()}")
                else:
                    new_lines.append(line)  # line is fine, keep as-is
            else:
                new_lines.append(line)  # not a target line, keep as-is
        if modified:
            print(f"Writing the patched LangSegment __init__.py back to: {init_path}")
            try:
                with open(init_path, 'w') as f:
                    f.writelines(new_lines)
                print("LangSegment __init__.py patched successfully.")
                # Try to reload the module so the change takes effect (may be a no-op depending on the import chain)
                try:
                    import LangSegment
                    importlib.reload(LangSegment)
                    print("Attempted to reload the LangSegment module.")
                except Exception as reload_e:
                    print(f"Error while reloading LangSegment (may be harmless): {reload_e}")
            except PermissionError:
                print(f"Error: insufficient permissions to modify {init_path}. Consider adjusting requirements.txt instead.")
            except Exception as write_e:
                print(f"Unexpected error while writing LangSegment __init__.py: {write_e}")
        else:
            print("LangSegment __init__.py needs no changes.")
    except ImportError:
        print("LangSegment package not found; nothing to patch.")
    except Exception as e:
        print(f"Unexpected error while patching the LangSegment package: {e}")
# Apply the patch before any other imports (especially Amphion, which may trigger a LangSegment import)
patch_langsegment_init()
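# For illustration, the patch above turns an import line such as
#     from .LangSegment import LangSegment,getTexts,setLangfilters,getLangfilters
# into
#     from .LangSegment import LangSegment,getTexts
# (hypothetical symbol list; the exact names depend on the installed LangSegment version).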
# Clone the Amphion repository if it is not already present
if not os.path.exists("Amphion"):
    subprocess.run(["git", "clone", "https://github.com/open-mmlab/Amphion.git"])
    os.chdir("Amphion")
else:
    if not os.getcwd().endswith("Amphion"):
        os.chdir("Amphion")
# Add Amphion to the import path
if os.path.dirname(os.path.abspath("Amphion")) not in sys.path:
    sys.path.append(os.path.dirname(os.path.abspath("Amphion")))
# Ensure the required directories exist
os.makedirs("wav", exist_ok=True)
os.makedirs("ckpts/Vevo", exist_ok=True)
from models.vc.vevo.vevo_utils import VevoInferencePipeline, save_audio, load_wav
# Download and set up the configuration files
def setup_configs():
    config_path = "models/vc/vevo/config"
    os.makedirs(config_path, exist_ok=True)
    config_files = [
        "PhoneToVq8192.json",
        "Vocoder.json",
        "Vq32ToVq8192.json",
        "Vq8192ToMels.json",
        "hubert_large_l18_c32.yaml",
    ]
    for file in config_files:
        file_path = f"{config_path}/{file}"
        if not os.path.exists(file_path):
            try:
                file_data = hf_hub_download(
                    repo_id="amphion/Vevo",
                    filename=f"config/{file}",
                    repo_type="model",
                )
                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                # Copy the downloaded file to the target location
                shutil.copy(file_data, file_path)
            except Exception as e:
                print(f"Error while downloading config file {file}: {e}")
setup_configs()
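# Note: files already present under models/vc/vevo/config are skipped above, and
# hf_hub_download itself caches downloads (under ~/.cache/huggingface by default),
# so re-running setup_configs() after a restart is cheap.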
# Device configuration
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Cache of initialized pipelines
inference_pipelines = {}
def get_pipeline(pipeline_type):
if pipeline_type in inference_pipelines:
return inference_pipelines[pipeline_type]
    # Initialize the requested pipeline type on demand
if pipeline_type == "style" or pipeline_type == "voice":
        # Download the content tokenizer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["tokenizer/vq32/*"],
)
content_tokenizer_ckpt_path = os.path.join(
local_dir, "tokenizer/vq32/hubert_large_l18_c32.pkl"
)
        # Download the content-style tokenizer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["tokenizer/vq8192/*"],
)
content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
        # Download the autoregressive transformer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["contentstyle_modeling/Vq32ToVq8192/*"],
)
ar_cfg_path = "./models/vc/vevo/config/Vq32ToVq8192.json"
ar_ckpt_path = os.path.join(local_dir, "contentstyle_modeling/Vq32ToVq8192")
        # Download the flow-matching transformer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["acoustic_modeling/Vq8192ToMels/*"],
)
fmt_cfg_path = "./models/vc/vevo/config/Vq8192ToMels.json"
fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")
        # Download the vocoder
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["acoustic_modeling/Vocoder/*"],
)
vocoder_cfg_path = "./models/vc/vevo/config/Vocoder.json"
vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
        # Initialize the pipeline
inference_pipeline = VevoInferencePipeline(
content_tokenizer_ckpt_path=content_tokenizer_ckpt_path,
content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
ar_cfg_path=ar_cfg_path,
ar_ckpt_path=ar_ckpt_path,
fmt_cfg_path=fmt_cfg_path,
fmt_ckpt_path=fmt_ckpt_path,
vocoder_cfg_path=vocoder_cfg_path,
vocoder_ckpt_path=vocoder_ckpt_path,
device=device,
)
elif pipeline_type == "timbre":
        # Download the content-style tokenizer (the only tokenizer needed for timbre)
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["tokenizer/vq8192/*"],
)
content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
        # Download the flow-matching transformer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["acoustic_modeling/Vq8192ToMels/*"],
)
fmt_cfg_path = "./models/vc/vevo/config/Vq8192ToMels.json"
fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")
        # Download the vocoder
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["acoustic_modeling/Vocoder/*"],
)
vocoder_cfg_path = "./models/vc/vevo/config/Vocoder.json"
vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
        # Initialize the pipeline
inference_pipeline = VevoInferencePipeline(
content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
fmt_cfg_path=fmt_cfg_path,
fmt_ckpt_path=fmt_ckpt_path,
vocoder_cfg_path=vocoder_cfg_path,
vocoder_ckpt_path=vocoder_ckpt_path,
device=device,
)
elif pipeline_type == "tts":
        # Download the content-style tokenizer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["tokenizer/vq8192/*"],
)
content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
        # Download the autoregressive transformer (TTS-specific)
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["contentstyle_modeling/PhoneToVq8192/*"],
)
ar_cfg_path = "./models/vc/vevo/config/PhoneToVq8192.json"
ar_ckpt_path = os.path.join(local_dir, "contentstyle_modeling/PhoneToVq8192")
        # Download the flow-matching transformer
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["acoustic_modeling/Vq8192ToMels/*"],
)
fmt_cfg_path = "./models/vc/vevo/config/Vq8192ToMels.json"
fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")
        # Download the vocoder
local_dir = snapshot_download(
repo_id="amphion/Vevo",
repo_type="model",
cache_dir="./ckpts/Vevo",
allow_patterns=["acoustic_modeling/Vocoder/*"],
)
vocoder_cfg_path = "./models/vc/vevo/config/Vocoder.json"
vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
        # Initialize the pipeline
inference_pipeline = VevoInferencePipeline(
content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
ar_cfg_path=ar_cfg_path,
ar_ckpt_path=ar_ckpt_path,
fmt_cfg_path=fmt_cfg_path,
fmt_ckpt_path=fmt_ckpt_path,
vocoder_cfg_path=vocoder_cfg_path,
vocoder_ckpt_path=vocoder_ckpt_path,
device=device,
)
    # Cache the pipeline instance
inference_pipelines[pipeline_type] = inference_pipeline
return inference_pipeline
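# Minimal usage sketch (illustrative): the first call builds and caches the
# pipeline, later calls return the cached instance.
#     pipe = get_pipeline("timbre")
#     assert get_pipeline("timbre") is pipe
#
# The handler functions below all repeat the same preprocessing steps
# (mono mixdown, resample to 24 kHz, peak normalization). A consolidated
# helper would look like the sketch below; it is illustrative only and is
# not wired into the handlers, which keep their inline logic.
def _preprocess_audio_sketch(wav_tuple, target_sr=24000, peak=0.95):
    """Normalize a Gradio (sample_rate, data) or (data, sample_rate) tuple
    to a peak-normalized mono tensor of shape (1, samples) at target_sr."""
    if not (isinstance(wav_tuple, tuple) and len(wav_tuple) == 2):
        raise ValueError("Expected a (sample_rate, data) tuple from gr.Audio")
    # Accept either element order, as the handlers below do
    if isinstance(wav_tuple[0], np.ndarray):
        data, sr = wav_tuple
    else:
        sr, data = wav_tuple
    # Mix stereo down to mono
    if data.ndim > 1 and data.shape[1] > 1:
        data = np.mean(data, axis=1)
    tensor = torch.FloatTensor(data).unsqueeze(0)
    # Resample to the model's expected rate
    if sr != target_sr:
        tensor = torchaudio.functional.resample(tensor, sr, target_sr)
        sr = target_sr
    # Peak-normalize with a small epsilon to avoid division by zero
    tensor = tensor / (torch.max(torch.abs(tensor)) + 1e-6) * peak
    return tensor, sr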
# VEVO feature functions
def vevo_style(content_wav, style_wav):
temp_content_path = "wav/temp_content.wav"
temp_style_path = "wav/temp_style.wav"
output_path = "wav/output_vevostyle.wav"
    # Validate the input audio
    if content_wav is None or style_wav is None:
        raise ValueError("Please upload audio files")
    # Handle the content audio format
    if isinstance(content_wav, tuple) and len(content_wav) == 2:
        if isinstance(content_wav[0], np.ndarray):
            content_data, content_sr = content_wav
        else:
            content_sr, content_data = content_wav
        # Ensure mono
        if len(content_data.shape) > 1 and content_data.shape[1] > 1:
            content_data = np.mean(content_data, axis=1)
        # Resample to 24 kHz
        if content_sr != 24000:
            content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
            content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
            content_sr = 24000
        else:
            content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
        # Normalize the volume
        content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid content audio format")
    # Handle the style audio format
    if isinstance(style_wav, tuple) and len(style_wav) == 2:
        # Ensure the correct order (data, sample_rate)
        if isinstance(style_wav[0], np.ndarray):
            style_data, style_sr = style_wav
        else:
            style_sr, style_data = style_wav
        # Mix stereo down to mono so torchaudio.save receives (channels, samples)
        if len(style_data.shape) > 1 and style_data.shape[1] > 1:
            style_data = np.mean(style_data, axis=1)
        style_tensor = torch.FloatTensor(style_data)
        if style_tensor.ndim == 1:
            style_tensor = style_tensor.unsqueeze(0)  # add a channel dimension
    else:
        raise ValueError("Invalid style audio format")
    # Print debug info
    print(f"Content audio shape: {content_tensor.shape}, sample rate: {content_sr}")
    print(f"Style audio shape: {style_tensor.shape}, sample rate: {style_sr}")
    # Save the audio
    torchaudio.save(temp_content_path, content_tensor, content_sr)
    torchaudio.save(temp_style_path, style_tensor, style_sr)
try:
        # Get the pipeline
        pipeline = get_pipeline("style")
        # Inference
        gen_audio = pipeline.inference_ar_and_fm(
            src_wav_path=temp_content_path,
            src_text=None,
            style_ref_wav_path=temp_style_path,
            timbre_ref_wav_path=temp_content_path,
        )
        # Check the generated audio for NaN/Inf values
        if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
            print("Warning: Generated audio contains NaN or Inf values")
            gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
        print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
        # Save the generated audio
save_audio(gen_audio, output_path=output_path)
return output_path
except Exception as e:
print(f"Error during processing: {e}")
import traceback
traceback.print_exc()
raise e
def vevo_timbre(content_wav, reference_wav):
temp_content_path = "wav/temp_content.wav"
temp_reference_path = "wav/temp_reference.wav"
output_path = "wav/output_vevotimbre.wav"
    # Validate the input audio
    if content_wav is None or reference_wav is None:
        raise ValueError("Please upload audio files")
    # Handle the content audio format
    if isinstance(content_wav, tuple) and len(content_wav) == 2:
        if isinstance(content_wav[0], np.ndarray):
            content_data, content_sr = content_wav
        else:
            content_sr, content_data = content_wav
        # Ensure mono
        if len(content_data.shape) > 1 and content_data.shape[1] > 1:
            content_data = np.mean(content_data, axis=1)
        # Resample to 24 kHz
        if content_sr != 24000:
            content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
            content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
            content_sr = 24000
        else:
            content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
        # Normalize the volume
        content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid content audio format")
    # Handle the reference audio format
    if isinstance(reference_wav, tuple) and len(reference_wav) == 2:
        if isinstance(reference_wav[0], np.ndarray):
            reference_data, reference_sr = reference_wav
        else:
            reference_sr, reference_data = reference_wav
        # Ensure mono
        if len(reference_data.shape) > 1 and reference_data.shape[1] > 1:
            reference_data = np.mean(reference_data, axis=1)
        # Resample to 24 kHz
        if reference_sr != 24000:
            reference_tensor = torch.FloatTensor(reference_data).unsqueeze(0)
            reference_tensor = torchaudio.functional.resample(reference_tensor, reference_sr, 24000)
            reference_sr = 24000
        else:
            reference_tensor = torch.FloatTensor(reference_data).unsqueeze(0)
        # Normalize the volume
        reference_tensor = reference_tensor / (torch.max(torch.abs(reference_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid reference audio format")
    # Print debug info
    print(f"Content audio shape: {content_tensor.shape}, sample rate: {content_sr}")
    print(f"Reference audio shape: {reference_tensor.shape}, sample rate: {reference_sr}")
    # Save the uploaded audio
    torchaudio.save(temp_content_path, content_tensor, content_sr)
    torchaudio.save(temp_reference_path, reference_tensor, reference_sr)
try:
        # Get the pipeline
        pipeline = get_pipeline("timbre")
        # Inference
        gen_audio = pipeline.inference_fm(
            src_wav_path=temp_content_path,
            timbre_ref_wav_path=temp_reference_path,
            flow_matching_steps=32,
        )
        # Check the generated audio for NaN/Inf values
        if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
            print("Warning: Generated audio contains NaN or Inf values")
            gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
        print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
        # Save the generated audio
save_audio(gen_audio, output_path=output_path)
return output_path
except Exception as e:
print(f"Error during processing: {e}")
import traceback
traceback.print_exc()
raise e
def vevo_voice(content_wav, style_reference_wav, timbre_reference_wav):
temp_content_path = "wav/temp_content.wav"
temp_style_path = "wav/temp_style.wav"
temp_timbre_path = "wav/temp_timbre.wav"
output_path = "wav/output_vevovoice.wav"
    # Validate the input audio
    if content_wav is None or style_reference_wav is None or timbre_reference_wav is None:
        raise ValueError("Please upload all required audio files")
    # Handle the content audio format
    if isinstance(content_wav, tuple) and len(content_wav) == 2:
        if isinstance(content_wav[0], np.ndarray):
            content_data, content_sr = content_wav
        else:
            content_sr, content_data = content_wav
        # Ensure mono
        if len(content_data.shape) > 1 and content_data.shape[1] > 1:
            content_data = np.mean(content_data, axis=1)
        # Resample to 24 kHz
        if content_sr != 24000:
            content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
            content_tensor = torchaudio.functional.resample(content_tensor, content_sr, 24000)
            content_sr = 24000
        else:
            content_tensor = torch.FloatTensor(content_data).unsqueeze(0)
        # Normalize the volume
        content_tensor = content_tensor / (torch.max(torch.abs(content_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid content audio format")
    # Handle the style reference audio format
    if isinstance(style_reference_wav, tuple) and len(style_reference_wav) == 2:
        if isinstance(style_reference_wav[0], np.ndarray):
            style_data, style_sr = style_reference_wav
        else:
            style_sr, style_data = style_reference_wav
        # Ensure mono
        if len(style_data.shape) > 1 and style_data.shape[1] > 1:
            style_data = np.mean(style_data, axis=1)
        # Resample to 24 kHz
        if style_sr != 24000:
            style_tensor = torch.FloatTensor(style_data).unsqueeze(0)
            style_tensor = torchaudio.functional.resample(style_tensor, style_sr, 24000)
            style_sr = 24000
        else:
            style_tensor = torch.FloatTensor(style_data).unsqueeze(0)
        # Normalize the volume
        style_tensor = style_tensor / (torch.max(torch.abs(style_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid style reference audio format")
    # Handle the timbre reference audio format
    if isinstance(timbre_reference_wav, tuple) and len(timbre_reference_wav) == 2:
        if isinstance(timbre_reference_wav[0], np.ndarray):
            timbre_data, timbre_sr = timbre_reference_wav
        else:
            timbre_sr, timbre_data = timbre_reference_wav
        # Ensure mono
        if len(timbre_data.shape) > 1 and timbre_data.shape[1] > 1:
            timbre_data = np.mean(timbre_data, axis=1)
        # Resample to 24 kHz
        if timbre_sr != 24000:
            timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
            timbre_tensor = torchaudio.functional.resample(timbre_tensor, timbre_sr, 24000)
            timbre_sr = 24000
        else:
            timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
        # Normalize the volume
        timbre_tensor = timbre_tensor / (torch.max(torch.abs(timbre_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid timbre reference audio format")
    # Print debug info
    print(f"Content audio shape: {content_tensor.shape}, sample rate: {content_sr}")
    print(f"Style reference audio shape: {style_tensor.shape}, sample rate: {style_sr}")
    print(f"Timbre reference audio shape: {timbre_tensor.shape}, sample rate: {timbre_sr}")
    # Save the uploaded audio
    torchaudio.save(temp_content_path, content_tensor, content_sr)
    torchaudio.save(temp_style_path, style_tensor, style_sr)
    torchaudio.save(temp_timbre_path, timbre_tensor, timbre_sr)
try:
        # Get the pipeline
        pipeline = get_pipeline("voice")
        # Inference
        gen_audio = pipeline.inference_ar_and_fm(
            src_wav_path=temp_content_path,
            src_text=None,
            style_ref_wav_path=temp_style_path,
            timbre_ref_wav_path=temp_timbre_path,
        )
        # Check the generated audio for NaN/Inf values
        if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
            print("Warning: Generated audio contains NaN or Inf values")
            gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
        print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
        # Save the generated audio
save_audio(gen_audio, output_path=output_path)
return output_path
except Exception as e:
print(f"Error during processing: {e}")
import traceback
traceback.print_exc()
raise e
def vevo_tts(text, ref_wav, timbre_ref_wav=None, style_ref_text=None, src_language="en", style_ref_text_language="en"):
temp_ref_path = "wav/temp_ref.wav"
temp_timbre_path = "wav/temp_timbre.wav"
output_path = "wav/output_vevotts.wav"
    # Validate the input audio
    if ref_wav is None:
        raise ValueError("Please upload a reference audio file")
    # Handle the reference audio format
    if isinstance(ref_wav, tuple) and len(ref_wav) == 2:
        if isinstance(ref_wav[0], np.ndarray):
            ref_data, ref_sr = ref_wav
        else:
            ref_sr, ref_data = ref_wav
        # Ensure mono
        if len(ref_data.shape) > 1 and ref_data.shape[1] > 1:
            ref_data = np.mean(ref_data, axis=1)
        # Resample to 24 kHz
        if ref_sr != 24000:
            ref_tensor = torch.FloatTensor(ref_data).unsqueeze(0)
            ref_tensor = torchaudio.functional.resample(ref_tensor, ref_sr, 24000)
            ref_sr = 24000
        else:
            ref_tensor = torch.FloatTensor(ref_data).unsqueeze(0)
        # Normalize the volume
        ref_tensor = ref_tensor / (torch.max(torch.abs(ref_tensor)) + 1e-6) * 0.95
    else:
        raise ValueError("Invalid reference audio format")
    # Print debug info
    print(f"Reference audio shape: {ref_tensor.shape}, sample rate: {ref_sr}")
    if style_ref_text:
        print(f"Style reference text: {style_ref_text}, language: {style_ref_text_language}")
    # Save the uploaded audio
    torchaudio.save(temp_ref_path, ref_tensor, ref_sr)
    if timbre_ref_wav is not None:
        if isinstance(timbre_ref_wav, tuple) and len(timbre_ref_wav) == 2:
            if isinstance(timbre_ref_wav[0], np.ndarray):
                timbre_data, timbre_sr = timbre_ref_wav
            else:
                timbre_sr, timbre_data = timbre_ref_wav
            # Ensure mono
            if len(timbre_data.shape) > 1 and timbre_data.shape[1] > 1:
                timbre_data = np.mean(timbre_data, axis=1)
            # Resample to 24 kHz
            if timbre_sr != 24000:
                timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
                timbre_tensor = torchaudio.functional.resample(timbre_tensor, timbre_sr, 24000)
                timbre_sr = 24000
            else:
                timbre_tensor = torch.FloatTensor(timbre_data).unsqueeze(0)
            # Normalize the volume
            timbre_tensor = timbre_tensor / (torch.max(torch.abs(timbre_tensor)) + 1e-6) * 0.95
            print(f"Timbre reference audio shape: {timbre_tensor.shape}, sample rate: {timbre_sr}")
            torchaudio.save(temp_timbre_path, timbre_tensor, timbre_sr)
else:
raise ValueError("Invalid timbre reference audio format")
else:
temp_timbre_path = temp_ref_path
try:
        # Get the pipeline
        pipeline = get_pipeline("tts")
        # Inference
        gen_audio = pipeline.inference_ar_and_fm(
            src_wav_path=None,
            src_text=text,
            style_ref_wav_path=temp_ref_path,
            timbre_ref_wav_path=temp_timbre_path,
            style_ref_wav_text=style_ref_text,
            src_text_language=src_language,
            style_ref_wav_text_language=style_ref_text_language,
        )
        # Check the generated audio for NaN/Inf values
        if torch.isnan(gen_audio).any() or torch.isinf(gen_audio).any():
            print("Warning: Generated audio contains NaN or Inf values")
            gen_audio = torch.nan_to_num(gen_audio, nan=0.0, posinf=0.95, neginf=-0.95)
        print(f"Generated audio shape: {gen_audio.shape}, max: {torch.max(gen_audio)}, min: {torch.min(gen_audio)}")
        # Save the generated audio
save_audio(gen_audio, output_path=output_path)
return output_path
except Exception as e:
print(f"Error during processing: {e}")
import traceback
traceback.print_exc()
raise e
# Build the Gradio interface
with gr.Blocks(title="Vevo DEMO") as demo:
gr.Markdown("# Vevo DEMO")
    # Row of badge links
with gr.Row(elem_id="links_row"):
gr.HTML("""
<div style="display: flex; justify-content: flex-start; gap: 8px; margin: 0 0; padding-left: 0px;">
<a href="https://arxiv.org/abs/2502.07243" target="_blank" style="text-decoration: none;">
<img alt="arXiv Paper" src="https://img.shields.io/badge/arXiv-Paper-red">
</a>
<a href="https://openreview.net/pdf?id=anQDiQZhDP" target="_blank" style="text-decoration: none;">
<img alt="ICLR Paper" src="https://img.shields.io/badge/ICLR-Paper-64b63a">
</a>
<a href="https://huggingface.co./amphion/Vevo" target="_blank" style="text-decoration: none;">
<img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-yellow">
</a>
<a href="https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo" target="_blank" style="text-decoration: none;">
<img alt="GitHub Repo" src="https://img.shields.io/badge/GitHub-Repo-blue">
</a>
</div>
""")
with gr.Tab("Vevo-Timbre"):
gr.Markdown("### Vevo-Timbre: Maintain style but transfer timbre")
with gr.Row():
with gr.Column():
timbre_content = gr.Audio(label="Source Audio", type="numpy")
timbre_reference = gr.Audio(label="Timbre Reference", type="numpy")
timbre_button = gr.Button("Generate")
with gr.Column():
timbre_output = gr.Audio(label="Result")
timbre_button.click(vevo_timbre, inputs=[timbre_content, timbre_reference], outputs=timbre_output)
with gr.Tab("Vevo-Style"):
gr.Markdown("### Vevo-Style: Maintain timbre but transfer style (accent, emotion, etc.)")
with gr.Row():
with gr.Column():
style_content = gr.Audio(label="Source Audio", type="numpy")
style_reference = gr.Audio(label="Style Reference", type="numpy")
style_button = gr.Button("Generate")
with gr.Column():
style_output = gr.Audio(label="Result")
style_button.click(vevo_style, inputs=[style_content, style_reference], outputs=style_output)
with gr.Tab("Vevo-Voice"):
gr.Markdown("### Vevo-Voice: Transfers both style and timbre with separate references")
with gr.Row():
with gr.Column():
voice_content = gr.Audio(label="Source Audio", type="numpy")
voice_style_reference = gr.Audio(label="Style Reference", type="numpy")
voice_timbre_reference = gr.Audio(label="Timbre Reference", type="numpy")
voice_button = gr.Button("Generate")
with gr.Column():
voice_output = gr.Audio(label="Result")
voice_button.click(vevo_voice, inputs=[voice_content, voice_style_reference, voice_timbre_reference], outputs=voice_output)
with gr.Tab("Vevo-TTS"):
gr.Markdown("### Vevo-TTS: Text-to-speech with separate style and timbre references")
with gr.Row():
with gr.Column():
tts_text = gr.Textbox(label="Target Text", placeholder="Enter text to synthesize...", lines=3)
tts_src_language = gr.Dropdown(["en", "zh", "de", "fr", "ja", "ko"], label="Text Language", value="en")
tts_reference = gr.Audio(label="Style Reference", type="numpy")
tts_style_ref_text = gr.Textbox(label="Style Reference Text", placeholder="Enter style reference text...", lines=3)
tts_style_ref_text_language = gr.Dropdown(["en", "zh", "de", "fr", "ja", "ko"], label="Style Reference Text Language", value="en")
tts_timbre_reference = gr.Audio(label="Timbre Reference", type="numpy")
tts_button = gr.Button("Generate")
with gr.Column():
tts_output = gr.Audio(label="Result")
tts_button.click(
vevo_tts,
inputs=[tts_text, tts_reference, tts_timbre_reference, tts_style_ref_text, tts_src_language, tts_style_ref_text_language],
outputs=tts_output
)
gr.Markdown("""
## About VEVO
VEVO is a versatile voice synthesis and conversion model that offers four main functionalities:
1. **Vevo-Style**: Maintains timbre but transfers style (accent, emotion, etc.)
2. **Vevo-Timbre**: Maintains style but transfers timbre
3. **Vevo-Voice**: Transfers both style and timbre with separate references
4. **Vevo-TTS**: Text-to-speech with separate style and timbre references
    For more information, visit the [Amphion project](https://github.com/open-mmlab/Amphion).
""")
# Launch the app
demo.launch()