naonauno committed
Commit 2bc810c (verified)
Parent(s): d4d1cf4

Update app.py

Files changed (1)
  1. app.py +78 -66
app.py CHANGED
@@ -10,13 +10,18 @@ import Amphion.models.vc.vevo.vevo_utils as vevo_utils
 from huggingface_hub import snapshot_download
 
 def load_model():
+    print("Loading model...")
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    print(f"Using device: {device}")
+
+    cache_dir = "./ckpts/Vevo"
+    os.makedirs(cache_dir, exist_ok=True)
 
     # Content Tokenizer
     local_dir = snapshot_download(
         repo_id="amphion/Vevo",
         repo_type="model",
-        cache_dir="./ckpts/Vevo",
+        cache_dir=cache_dir,
         allow_patterns=["tokenizer/vq32/*"],
     )
     content_tokenizer_ckpt_path = os.path.join(
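Note: snapshot_download fetches only the files matching allow_patterns and returns the snapshot root inside cache_dir (not cache_dir itself), which is why the code joins local_dir with the repo-relative subpath. A minimal sketch of the download pattern this commit consolidates:

    import os
    from huggingface_hub import snapshot_download

    cache_dir = "./ckpts/Vevo"
    os.makedirs(cache_dir, exist_ok=True)

    # Fetch just the vq32 tokenizer files; local_dir is the snapshot root.
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["tokenizer/vq32/*"],
    )
    ckpt_path = os.path.join(local_dir, "tokenizer/vq32")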
@@ -27,7 +32,7 @@ def load_model():
     local_dir = snapshot_download(
         repo_id="amphion/Vevo",
         repo_type="model",
-        cache_dir="./ckpts/Vevo",
+        cache_dir=cache_dir,
         allow_patterns=["tokenizer/vq8192/*"],
     )
     content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")
@@ -36,7 +41,7 @@ def load_model():
     local_dir = snapshot_download(
         repo_id="amphion/Vevo",
         repo_type="model",
-        cache_dir="./ckpts/Vevo",
+        cache_dir=cache_dir,
         allow_patterns=["contentstyle_modeling/Vq32ToVq8192/*"],
     )
     ar_cfg_path = "./config/Vq32ToVq8192.json"
@@ -46,7 +51,7 @@ def load_model():
     local_dir = snapshot_download(
         repo_id="amphion/Vevo",
         repo_type="model",
-        cache_dir="./ckpts/Vevo",
+        cache_dir=cache_dir,
         allow_patterns=["acoustic_modeling/Vq8192ToMels/*"],
     )
     fmt_cfg_path = "./config/Vq8192ToMels.json"
@@ -56,12 +61,13 @@ def load_model():
     local_dir = snapshot_download(
         repo_id="amphion/Vevo",
         repo_type="model",
-        cache_dir="./ckpts/Vevo",
+        cache_dir=cache_dir,
         allow_patterns=["acoustic_modeling/Vocoder/*"],
     )
     vocoder_cfg_path = "./Amphion/models/vc/vevo/config/Vocoder.json"
     vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")
 
+    print("Initializing pipeline...")
     pipeline = vevo_utils.VevoInferencePipeline(
         content_tokenizer_ckpt_path=content_tokenizer_ckpt_path,
         content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
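Note: the five snapshot_download calls now differ only in allow_patterns, so a follow-up could collapse them into a single call; snapshot_download accepts a list of patterns. Sketch only, not part of this commit:

    # Hypothetical consolidation: one download covering all five components.
    patterns = [
        "tokenizer/vq32/*",
        "tokenizer/vq8192/*",
        "contentstyle_modeling/Vq32ToVq8192/*",
        "acoustic_modeling/Vq8192ToMels/*",
        "acoustic_modeling/Vocoder/*",
    ]
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=patterns,
    )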
@@ -73,6 +79,7 @@ def load_model():
         vocoder_ckpt_path=vocoder_ckpt_path,
         device=device
     )
+    print("Model loaded successfully!")
     return pipeline
 
 def convert_to_wav(audio_path):
@@ -94,6 +101,10 @@ def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio,
                   src_text, ref_text, src_language, ref_language, steps,
                   progress=gr.Progress()):
     try:
+        output_dir = "outputs"
+        os.makedirs(output_dir, exist_ok=True)
+        output_path = os.path.join(output_dir, "output.wav")
+
         # Convert uploaded audio files to WAV if needed
         if content_audio:
             content_path = convert_to_wav(content_audio)
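Note: the new output path is still a fixed outputs/output.wav, so two concurrent requests can overwrite each other's result. A hedged sketch of a per-request filename (uuid4 is illustrative, not what this commit does):

    import os
    import uuid

    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)
    # A unique name per request keeps parallel users from clobbering each other.
    output_path = os.path.join(output_dir, f"{uuid.uuid4().hex}.wav")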
@@ -110,10 +121,12 @@ def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio,
         else:
             ref_timbre_path = None
 
+        progress(0.2, "Processing audio...")
+
         # Run inference based on mode
         if mode == 'voice':
             if not all([content_path, ref_style_path, ref_timbre_path]):
-                raise ValueError("Voice mode requires all audio inputs")
+                raise gr.Error("Voice mode requires all audio inputs")
 
             gen_audio = inference_pipeline.inference_ar_and_fm(
                 src_wav_path=content_path,
@@ -125,7 +138,7 @@ def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio,
 
         elif mode == 'timbre':
             if not all([content_path, ref_timbre_path]):
-                raise ValueError("Timbre mode requires source and timbre reference audio")
+                raise gr.Error("Timbre mode requires source and timbre reference audio")
 
             gen_audio = inference_pipeline.inference_fm(
                 src_wav_path=content_path,
@@ -134,8 +147,8 @@ def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio,
             )
 
         elif mode == 'tts':
-            if not all([ref_style_path, ref_timbre_path, src_text]):
-                raise ValueError("TTS mode requires style audio, timbre audio, and source text")
+            if not all([ref_style_path, ref_timbre_path]) or not src_text:
+                raise gr.Error("TTS mode requires style audio, timbre audio, and source text")
 
             gen_audio = inference_pipeline.inference_ar_and_fm(
                 src_wav_path=None,
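Note: swapping ValueError for gr.Error surfaces the message in the UI directly; the catch-all except below would have wrapped a ValueError the same way, so the change mainly skips the re-wrap. The three mode checks could also share a small helper — a sketch with a hypothetical require(), reusing the app's own names:

    def require(condition, message):
        # Raise a user-visible Gradio error when a precondition fails.
        if not condition:
            raise gr.Error(message)

    require(all([content_path, ref_timbre_path]),
            "Timbre mode requires source and timbre reference audio")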
@@ -147,18 +160,17 @@
             style_ref_wav_text_language=ref_language
         )
 
+        progress(0.8, "Saving generated audio...")
+
         # Save and return the generated audio
-        output_path = "output.wav"
         vevo_utils.save_audio(gen_audio, target_sample_rate=48000, output_path=output_path)
         return output_path
-
+
     except Exception as e:
         raise gr.Error(str(e))
 
 # Initialize the model
-print("Loading model...")
 inference_pipeline = load_model()
-print("Model loaded successfully!")
 
 # Create the Gradio interface
 with gr.Blocks(title="Vevo Voice Conversion") as demo:
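Note: the progress(0.2, ...) and progress(0.8, ...) calls pass the description positionally, which gr.Progress accepts. For loops with a known step count, the tracker can also wrap an iterable — a sketch assuming a per-step loop existed here:

    def run_steps(n_steps, progress=gr.Progress()):
        # progress.tqdm advances the bar once per iteration.
        for _ in progress.tqdm(range(n_steps), desc="Sampling steps"):
            ...  # one flow-matching step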
@@ -168,52 +180,58 @@ with gr.Blocks(title="Vevo Voice Conversion") as demo:
     mode = gr.Radio(
         choices=["voice", "timbre", "tts"],
         value="timbre",
-        label="Inference Mode"
+        label="Inference Mode",
+        interactive=True
     )
 
     with gr.Row():
         with gr.Column():
-            content_audio = gr.Audio(
-                label="Source Audio",
-                type="filepath"
-            )
-
-            ref_style_audio = gr.Audio(
-                label="Reference Style Audio",
-                type="filepath"
-            )
-
-            ref_timbre_audio = gr.Audio(
-                label="Reference Timbre Audio",
-                type="filepath"
-            )
+            with gr.Group(visible=True) as audio_inputs:
+                content_audio = gr.Audio(
+                    label="Source Audio",
+                    type="filepath",
+                    interactive=True
+                )
+
+                ref_style_audio = gr.Audio(
+                    label="Reference Style Audio",
+                    type="filepath",
+                    interactive=True
+                )
+
+                ref_timbre_audio = gr.Audio(
+                    label="Reference Timbre Audio",
+                    type="filepath",
+                    interactive=True
+                )
 
         with gr.Column():
-            src_text = gr.Textbox(
-                label="Source Text",
-                placeholder="Enter text for TTS mode",
-                visible=False
-            )
-
-            ref_text = gr.Textbox(
-                label="Reference Style Text",
-                placeholder="Optional: Enter reference text",
-                visible=False
-            )
-
-            src_language = gr.Dropdown(
-                choices=["en", "zh"],
-                value="en",
-                label="Source Language",
-                visible=False
-            )
-
-            ref_language = gr.Dropdown(
-                choices=["en", "zh"],
-                value="en",
-                label="Reference Language",
-                visible=False
-            )
+            with gr.Group(visible=False) as text_inputs:
+                src_text = gr.Textbox(
+                    label="Source Text",
+                    placeholder="Enter text for TTS mode",
+                    interactive=True
+                )
+
+                ref_text = gr.Textbox(
+                    label="Reference Style Text",
+                    placeholder="Optional: Enter reference text",
+                    interactive=True
+                )
+
+                src_language = gr.Dropdown(
+                    choices=["en", "zh"],
+                    value="en",
+                    label="Source Language",
+                    interactive=True
+                )
+
+                ref_language = gr.Dropdown(
+                    choices=["en", "zh"],
+                    value="en",
+                    label="Reference Language",
+                    interactive=True
+                )
 
     with gr.Row():
         steps = gr.Slider(
@@ -229,24 +247,18 @@ with gr.Blocks(title="Vevo Voice Conversion") as demo:
     output_audio = gr.Audio(label="Generated Audio")
 
     # Handle visibility of components based on mode
-    def update_visibility(mode):
+    def update_interface(mode):
         is_tts = mode == "tts"
-        is_voice = mode == "voice"
-        is_timbre = mode == "timbre"
-
         return {
-            content_audio: not is_tts,
-            ref_style_audio: not is_timbre,
-            src_text: is_tts,
-            ref_text: is_tts,
-            src_language: is_tts,
-            ref_language: is_tts
+            audio_inputs: not is_tts,
+            text_inputs: is_tts,
+            ref_style_audio: mode != "timbre"
         }
 
     mode.change(
-        fn=update_visibility,
+        fn=update_interface,
         inputs=[mode],
-        outputs=[content_audio, ref_style_audio, src_text, ref_text, src_language, ref_language]
+        outputs=[audio_inputs, text_inputs, ref_style_audio]
     )
 
     # Handle generation
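Reviewer note: update_interface returns bare booleans keyed by component. In Gradio, a plain value in that dict is treated as the component's new value; visibility toggles are normally expressed with gr.update(visible=...). A hedged sketch of the usual pattern:

    def update_interface(mode):
        is_tts = mode == "tts"
        return {
            audio_inputs: gr.update(visible=not is_tts),
            text_inputs: gr.update(visible=is_tts),
            ref_style_audio: gr.update(visible=mode != "timbre"),
        }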
@@ -267,4 +279,4 @@ with gr.Blocks(title="Vevo Voice Conversion") as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue().launch()
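Note: demo.queue() matters here because the gr.Progress updates added above are only streamed to the browser when the queue is enabled. The queue can also cap concurrent load — sketch, max_size value illustrative:

    if __name__ == "__main__":
        demo.queue(max_size=16).launch()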
 