nevreal committed on
Commit
3f2313e
·
verified ·
1 Parent(s): 15a42b1

Delete api_231006.py

Browse files
Files changed (1) hide show
  1. api_231006.py +0 -440
api_231006.py DELETED
@@ -1,440 +0,0 @@
1
- #api for 231006 release version by Xiaokai
2
- import os
3
- import sys
4
- import json
5
- import re
6
- import time
7
- import librosa
8
- import torch
9
- import numpy as np
10
- import torch.nn.functional as F
11
- import torchaudio.transforms as tat
12
- import sounddevice as sd
13
- from dotenv import load_dotenv
14
- from fastapi import FastAPI, HTTPException
15
- from pydantic import BaseModel
16
- import threading
17
- import uvicorn
18
- import logging
19
-
20
# Module-wide logging: the root logger is configured at INFO, and this
# module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; the route handlers in this module register on it.
app = FastAPI()
26
-
27
class GUIConfig:
    """Mutable runtime settings read by the audio stream.

    Every attribute is given a default here so that readers never hit an
    AttributeError, including ``n_cpu``, which is otherwise only created
    when a configuration is applied.
    """

    def __init__(self) -> None:
        # Model / index files.
        self.pth_path: str = ""
        self.index_path: str = ""
        # Conversion parameters.
        self.pitch: int = 0
        self.samplerate: int = 40000
        self.block_time: float = 1.0  # s
        self.buffer_num: int = 1
        self.threhold: int = -60  # (sic) attribute name is part of the interface
        self.crossfade_time: float = 0.05
        self.extra_time: float = 2.5
        self.I_noise_reduce: bool = False
        self.O_noise_reduce: bool = False
        self.rms_mix_rate: float = 0.0
        self.index_rate: float = 0.3
        self.n_cpu: int = 4  # same default as the ConfigData request model
        self.f0method: str = "rmvpe"
        # Audio device display names.
        self.sg_input_device: str = ""
        self.sg_output_device: str = ""
45
-
46
class ConfigData(BaseModel):
    """Request body accepted by the ``/config`` route.

    The four path/device fields are required; every other field has a
    default. Field order is kept because it determines the key order of
    the JSON written to ``configs/config.json``.
    """

    # Required: model files and audio device display names.
    pth_path: str
    index_path: str
    sg_input_device: str
    sg_output_device: str

    # Optional tuning parameters.
    threhold: int = -60  # (sic) field name is part of the HTTP interface
    pitch: int = 0
    index_rate: float = 0.3
    rms_mix_rate: float = 0.0
    block_time: float = 0.25
    crossfade_length: float = 0.05
    extra_time: float = 2.5
    n_cpu: int = 4
    I_noise_reduce: bool = False
    O_noise_reduce: bool = False
61
-
62
class AudioAPI:
    """Real-time voice-conversion engine behind the HTTP routes.

    Holds the runtime settings (``gui_config``), the project ``Config``
    object (assigned by the ``__main__`` block), the RVC model wrapper and
    all rolling audio buffers used by the sounddevice callback.
    """

    def __init__(self) -> None:
        self.gui_config = GUIConfig()
        self.config = None  # Initialize Config object as None
        self.flag_vc = False  # True while a stream is being started or is running
        self.function = "vc"
        self.delay_time = 0
        self.rvc = None  # Initialize RVC object as None

    @staticmethod
    def _default_device_name(names, indices, default_index):
        """Return the display name of the device whose index is ``default_index``.

        ``names`` and ``indices`` are the parallel lists returned by
        ``get_devices()``. ``default_index`` is an index into the full device
        list, so it must be matched against ``indices`` rather than used to
        subscript the filtered ``names`` list directly. Falls back to the
        first name, or ``""`` when there is no device of that kind.
        """
        for name, index in zip(names, indices):
            if index == default_index:
                return name
        return names[0] if names else ""

    def load(self):
        """Load ``configs/config.json``, recreating it with defaults on failure.

        Returns:
            dict: the loaded (or freshly written) settings. Device names that
            are not currently available are replaced by the default device.
        """
        input_devices, output_devices, input_indices, output_indices = self.get_devices()
        default_input = self._default_device_name(
            input_devices, input_indices, sd.default.device[0]
        )
        default_output = self._default_device_name(
            output_devices, output_indices, sd.default.device[1]
        )
        try:
            with open("configs/config.json", "r", encoding='utf-8') as j:
                data = json.load(j)
            data["rmvpe"] = True  # Ensure rmvpe is the only f0method
            if data["sg_input_device"] not in input_devices:
                data["sg_input_device"] = default_input
            if data["sg_output_device"] not in output_devices:
                data["sg_output_device"] = default_output
        except Exception as e:
            # Deliberately broad: a missing file, invalid JSON or missing keys
            # all lead to the same recovery, which is rewriting the defaults.
            logger.error(f"Failed to load configuration: {e}")
            data = {
                "pth_path": " ",
                "index_path": " ",
                "sg_input_device": default_input,
                "sg_output_device": default_output,
                "threhold": "-60",
                "pitch": "0",
                "index_rate": "0",
                "rms_mix_rate": "0",
                "block_time": "0.25",
                "crossfade_length": "0.05",
                "extra_time": "2.5",
                "f0method": "rmvpe",
                "use_jit": False,
            }
            data["rmvpe"] = True  # Ensure rmvpe is the only f0method
            # The directory may not exist yet on a fresh checkout.
            os.makedirs("configs", exist_ok=True)
            with open("configs/config.json", "w", encoding='utf-8') as j:
                json.dump(data, j, ensure_ascii=False)
        return data

    def set_values(self, values):
        """Validate ``values`` and copy them into ``gui_config``.

        Args:
            values: object exposing the ``ConfigData`` fields.

        Returns:
            bool: always ``True`` when no exception is raised.

        Raises:
            HTTPException: 400 when a path is blank or (via ``set_devices``)
                when a device name is not available.
        """
        logger.info(f"Setting values: {values}")
        if not values.pth_path.strip():
            raise HTTPException(status_code=400, detail="Please select a .pth file")
        if not values.index_path.strip():
            raise HTTPException(status_code=400, detail="Please select an index file")
        self.set_devices(values.sg_input_device, values.sg_output_device)
        self.config.use_jit = False
        self.gui_config.pth_path = values.pth_path
        self.gui_config.index_path = values.index_path
        self.gui_config.threhold = values.threhold
        self.gui_config.pitch = values.pitch
        self.gui_config.block_time = values.block_time
        self.gui_config.crossfade_time = values.crossfade_length
        self.gui_config.extra_time = values.extra_time
        self.gui_config.I_noise_reduce = values.I_noise_reduce
        self.gui_config.O_noise_reduce = values.O_noise_reduce
        self.gui_config.rms_mix_rate = values.rms_mix_rate
        self.gui_config.index_rate = values.index_rate
        self.gui_config.n_cpu = values.n_cpu
        self.gui_config.f0method = "rmvpe"
        return True

    def start_vc(self):
        """Build the model and buffers, then start the audio thread.

        ``flag_vc`` is raised first so that a second start request is rejected
        while setup is in progress, and lowered again if setup fails so that
        a later start request is not refused as "already running".
        """
        torch.cuda.empty_cache()
        self.flag_vc = True
        try:
            self._prepare_vc()
        except Exception:
            self.flag_vc = False
            raise
        thread_vc = threading.Thread(target=self.soundinput)
        thread_vc.start()

    def _frames_for(self, seconds):
        """Convert ``seconds`` to a sample count rounded to a multiple of ``zc``."""
        return int(np.round(seconds * self.gui_config.samplerate / self.zc)) * self.zc

    def _prepare_vc(self):
        """Create the RVC instance and every buffer/window the callback uses."""
        self.rvc = rvc_for_realtime.RVC(
            self.gui_config.pitch,
            self.gui_config.pth_path,
            self.gui_config.index_path,
            self.gui_config.index_rate,
            0,
            0,
            0,
            self.config,
            self.rvc if self.rvc else None,
        )
        self.gui_config.samplerate = self.rvc.tgt_sr
        # zc is 1/100 of the model sample rate; all frame sizes are multiples of it.
        self.zc = self.rvc.tgt_sr // 100
        self.block_frame = self._frames_for(self.gui_config.block_time)
        # Same block length expressed at 16 kHz (160 samples per zc unit).
        self.block_frame_16k = 160 * self.block_frame // self.zc
        self.crossfade_frame = self._frames_for(self.gui_config.crossfade_time)
        self.sola_search_frame = self.zc
        self.extra_frame = self._frames_for(self.gui_config.extra_time)

        device = self.config.device
        self.input_wav = torch.zeros(
            self.extra_frame + self.crossfade_frame + self.sola_search_frame + self.block_frame,
            device=device,
            dtype=torch.float32,
        )
        self.input_wav_res = torch.zeros(
            160 * self.input_wav.shape[0] // self.zc,
            device=device,
            dtype=torch.float32,
        )
        self.pitch = np.zeros(self.input_wav.shape[0] // self.zc, dtype="int32")
        self.pitchf = np.zeros(self.input_wav.shape[0] // self.zc, dtype="float64")
        self.sola_buffer = torch.zeros(self.crossfade_frame, device=device, dtype=torch.float32)
        self.nr_buffer = self.sola_buffer.clone()
        self.output_buffer = self.input_wav.clone()
        self.res_buffer = torch.zeros(2 * self.zc, device=device, dtype=torch.float32)
        self.valid_rate = 1 - (self.extra_frame - 1) / self.input_wav.shape[0]
        # sin^2 ramp from 0 to 1; fade_out is its complement, so they sum to 1.
        self.fade_in_window = (
            torch.sin(
                0.5
                * np.pi
                * torch.linspace(
                    0.0, 1.0, steps=self.crossfade_frame, device=device, dtype=torch.float32
                )
            )
            ** 2
        )
        self.fade_out_window = 1 - self.fade_in_window
        self.resampler = tat.Resample(
            orig_freq=self.gui_config.samplerate,
            new_freq=16000,
            dtype=torch.float32,
        ).to(device)
        self.tg = TorchGate(
            sr=self.gui_config.samplerate, n_fft=4 * self.zc, prop_decrease=0.9
        ).to(device)

    def soundinput(self):
        """Thread target: keep a duplex stream open while ``flag_vc`` is set."""
        global stream_latency
        channels = 1 if sys.platform == "darwin" else 2
        try:
            with sd.Stream(
                channels=channels,
                callback=self.audio_callback,
                blocksize=self.block_frame,
                samplerate=self.gui_config.samplerate,
                dtype="float32",
            ) as stream:
                stream_latency = stream.latency[-1]
                while self.flag_vc:
                    time.sleep(self.gui_config.block_time)
                    logger.info("Audio block passed.")
        except Exception:
            # If the stream cannot be opened or dies, allow a later restart.
            self.flag_vc = False
            raise
        logger.info("Ending VC")

    def audio_callback(self, indata: np.ndarray, outdata: np.ndarray, frames, times, status):
        """sounddevice stream callback: convert one input block into ``outdata``.

        Steps: gate quiet frames, roll the input buffers, optionally denoise,
        resample to 16 kHz, run the model, optionally denoise the output, mix
        the volume envelope, align with the previous block and crossfade.
        """
        start_time = time.perf_counter()
        indata = librosa.to_mono(indata.T)

        # Zero every zc-sized frame whose RMS level is below the threshold.
        if self.gui_config.threhold > -60:
            rms = librosa.feature.rms(y=indata, frame_length=4 * self.zc, hop_length=self.zc)
            db_threhold = (librosa.amplitude_to_db(rms, ref=1.0)[0] < self.gui_config.threhold)
            for i in np.flatnonzero(db_threhold):
                indata[i * self.zc : (i + 1) * self.zc] = 0

        # Shift the rolling buffers left by one block and append the new block.
        self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :].clone()
        self.input_wav[-self.block_frame :] = torch.from_numpy(indata).to(self.config.device)
        self.input_wav_res[: -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()

        # Input noise reduction (optional) and resampling to 16 kHz.
        if self.gui_config.I_noise_reduce and self.function == "vc":
            input_wav = self.input_wav[-self.crossfade_frame - self.block_frame - 2 * self.zc :]
            input_wav = self.tg(input_wav.unsqueeze(0), self.input_wav.unsqueeze(0))[0, 2 * self.zc :]
            input_wav[: self.crossfade_frame] *= self.fade_in_window
            input_wav[: self.crossfade_frame] += self.nr_buffer * self.fade_out_window
            self.nr_buffer[:] = input_wav[-self.crossfade_frame :]
            input_wav = torch.cat((self.res_buffer[:], input_wav[: self.block_frame]))
            self.res_buffer[:] = input_wav[-2 * self.zc :]
            self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(input_wav)[160:]
        else:
            self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
                self.input_wav[-self.block_frame - 2 * self.zc :]
            )[160:]

        # Inference, or pass-through when not in "vc" mode.
        if self.function == "vc":
            f0_extractor_frame = self.block_frame_16k + 800
            if self.gui_config.f0method == "rmvpe":
                # Round up to a multiple of 5120, then subtract 160.
                f0_extractor_frame = (5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160)
            infer_wav = self.rvc.infer(
                self.input_wav_res,
                self.input_wav_res[-f0_extractor_frame:].cpu().numpy(),
                self.block_frame_16k,
                self.valid_rate,
                self.pitch,
                self.pitchf,
                self.gui_config.f0method,
            )
            infer_wav = infer_wav[-self.crossfade_frame - self.sola_search_frame - self.block_frame :]
        else:
            infer_wav = self.input_wav[
                -self.crossfade_frame - self.sola_search_frame - self.block_frame :
            ].clone()

        # Output noise reduction (optional).
        if (self.gui_config.O_noise_reduce and self.function == "vc") or (
            self.gui_config.I_noise_reduce and self.function == "im"
        ):
            self.output_buffer[: -self.block_frame] = self.output_buffer[self.block_frame :].clone()
            self.output_buffer[-self.block_frame :] = infer_wav[-self.block_frame :]
            infer_wav = self.tg(infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)).squeeze(0)

        # Volume-envelope mixing: scale the output by (rms_in / rms_out) ** (1 - rate).
        if self.gui_config.rms_mix_rate < 1 and self.function == "vc":
            rms1 = librosa.feature.rms(
                y=self.input_wav_res[-160 * infer_wav.shape[0] // self.zc :].cpu().numpy(),
                frame_length=640,
                hop_length=160,
            )
            rms1 = torch.from_numpy(rms1).to(self.config.device)
            rms1 = F.interpolate(
                rms1.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear", align_corners=True
            )[0, 0, :-1]
            rms2 = librosa.feature.rms(
                y=infer_wav[:].cpu().numpy(), frame_length=4 * self.zc, hop_length=self.zc
            )
            rms2 = torch.from_numpy(rms2).to(self.config.device)
            rms2 = F.interpolate(
                rms2.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear", align_corners=True
            )[0, 0, :-1]
            # Floor rms2 at 1e-3 to avoid dividing by (near) zero.
            rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
            infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.gui_config.rms_mix_rate))

        # Pick the offset whose normalised cross-correlation with the previous
        # block's tail (sola_buffer) is highest.
        conv_input = infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
        cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
        cor_den = torch.sqrt(
            F.conv1d(conv_input**2, torch.ones(1, 1, self.crossfade_frame, device=self.config.device))
            + 1e-8
        )
        if sys.platform == "darwin":
            # torch.max only returns (values, indices) when ``dim`` is given;
            # without it the two-name unpack fails on the 0-d result.
            _, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0], dim=0)
            sola_offset = sola_offset.item()
        else:
            sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
        logger.info(f"sola_offset = {sola_offset}")

        # Crossfade the head of this block with the stored tail of the last one.
        infer_wav = infer_wav[sola_offset : sola_offset + self.block_frame + self.crossfade_frame]
        infer_wav[: self.crossfade_frame] *= self.fade_in_window
        infer_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
        self.sola_buffer[:] = infer_wav[-self.crossfade_frame :]

        # One output channel on darwin, two identical channels elsewhere
        # (matches the channel count chosen in soundinput).
        if sys.platform == "darwin":
            outdata[:] = infer_wav[: -self.crossfade_frame].cpu().numpy()[:, np.newaxis]
        else:
            outdata[:] = infer_wav[: -self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
        total_time = time.perf_counter() - start_time
        logger.info(f"Infer time: {total_time:.2f}")

    def get_devices(self, update: bool = True):
        """List the available audio devices.

        Args:
            update: when true, re-initialise sounddevice first so the device
                list is refreshed.

        Returns:
            tuple: ``(input_names, output_names, input_indices,
            output_indices)``. Names have the form ``"<name> (<host api>)"``
            and the index lists are parallel to the name lists.
        """
        if update:
            sd._terminate()
            sd._initialize()
        devices = sd.query_devices()
        hostapis = sd.query_hostapis()
        for hostapi in hostapis:
            for device_idx in hostapi["devices"]:
                devices[device_idx]["hostapi_name"] = hostapi["name"]
        input_devices = [
            f"{d['name']} ({d['hostapi_name']})"
            for d in devices
            if d["max_input_channels"] > 0
        ]
        output_devices = [
            f"{d['name']} ({d['hostapi_name']})"
            for d in devices
            if d["max_output_channels"] > 0
        ]
        input_devices_indices = [
            d["index"] if "index" in d else d["name"]
            for d in devices
            if d["max_input_channels"] > 0
        ]
        output_devices_indices = [
            d["index"] if "index" in d else d["name"]
            for d in devices
            if d["max_output_channels"] > 0
        ]
        return (
            input_devices,
            output_devices,
            input_devices_indices,
            output_devices_indices,
        )

    def set_devices(self, input_device, output_device):
        """Select the sounddevice default input/output by display name.

        Raises:
            HTTPException: 400 when either name is not currently available.
        """
        (
            input_devices,
            output_devices,
            input_device_indices,
            output_device_indices,
        ) = self.get_devices()
        logger.debug(f"Available input devices: {input_devices}")
        logger.debug(f"Available output devices: {output_devices}")
        logger.debug(f"Selected input device: {input_device}")
        logger.debug(f"Selected output device: {output_device}")

        if input_device not in input_devices:
            logger.error(f"Input device '{input_device}' is not in the list of available devices")
            raise HTTPException(status_code=400, detail=f"Input device '{input_device}' is not available")

        if output_device not in output_devices:
            logger.error(f"Output device '{output_device}' is not in the list of available devices")
            raise HTTPException(status_code=400, detail=f"Output device '{output_device}' is not available")

        sd.default.device[0] = input_device_indices[input_devices.index(input_device)]
        sd.default.device[1] = output_device_indices[output_devices.index(output_device)]
        logger.info(f"Input device set to {sd.default.device[0]}: {input_device}")
        logger.info(f"Output device set to {sd.default.device[1]}: {output_device}")
354
-
355
# Single engine instance shared by every route handler below.
audio_api = AudioAPI()


@app.get("/inputDevices", response_model=list)
def get_input_devices():
    """Return the display names of all input-capable audio devices.

    Any failure while querying devices is logged and reported as HTTP 500.
    """
    try:
        names = audio_api.get_devices()[0]
    except Exception as exc:
        logger.error(f"Failed to get input devices: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get input devices")
    return names
365
-
366
@app.get("/outputDevices", response_model=list)
def get_output_devices():
    """Return the display names of all output-capable audio devices.

    Any failure while querying devices is logged and reported as HTTP 500.
    """
    try:
        names = audio_api.get_devices()[1]
    except Exception as exc:
        logger.error(f"Failed to get output devices: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get output devices")
    return names
374
-
375
@app.post("/config")
def configure_audio(config_data: ConfigData):
    """Apply ``config_data`` to the engine and persist it to ``configs/config.json``.

    HTTP errors raised during validation are logged and passed through
    unchanged; any other failure is reported as HTTP 400.
    """
    try:
        logger.info(f"Configuring audio with data: {config_data}")
        if not audio_api.set_values(config_data):
            return None
        # Persist the request fields plus the two fixed engine settings.
        settings = {**config_data.dict(), "use_jit": False, "f0method": "rmvpe"}
        with open("configs/config.json", "w", encoding='utf-8') as j:
            json.dump(settings, j, ensure_ascii=False)
        logger.info("Configuration set successfully")
        return {"message": "Configuration set successfully"}
    except HTTPException as http_err:
        logger.error(f"Configuration error: {http_err.detail}")
        raise
    except Exception as exc:
        logger.error(f"Configuration failed: {exc}")
        raise HTTPException(status_code=400, detail=f"Configuration failed: {exc}")
393
-
394
@app.post("/start")
def start_conversion():
    """Start real-time conversion unless it is already running.

    Responds 400 when a conversion is already active and 500 when starting
    fails for any other reason.
    """
    try:
        # Guard clause: refuse a second start while one is active.
        if audio_api.flag_vc:
            logger.warning("Audio conversion already running")
            raise HTTPException(status_code=400, detail="Audio conversion already running")
        audio_api.start_vc()
        return {"message": "Audio conversion started"}
    except HTTPException as http_err:
        logger.error(f"Start conversion error: {http_err.detail}")
        raise
    except Exception as exc:
        logger.error(f"Failed to start conversion: {exc}")
        raise HTTPException(status_code=500, detail=f"Failed to start conversion: {exc}")
409
-
410
@app.post("/stop")
def stop_conversion():
    """Ask the audio thread to stop by clearing the engine's running flag.

    Responds 400 when no conversion is active and 500 on any other failure.
    The module-level ``stream_latency`` is set to -1 on a successful stop.
    """
    global stream_latency
    try:
        # Guard clause: nothing to stop.
        if not audio_api.flag_vc:
            logger.warning("Audio conversion not running")
            raise HTTPException(status_code=400, detail="Audio conversion not running")
        audio_api.flag_vc = False
        stream_latency = -1
        return {"message": "Audio conversion stopped"}
    except HTTPException as http_err:
        logger.error(f"Stop conversion error: {http_err.detail}")
        raise
    except Exception as exc:
        logger.error(f"Failed to stop conversion: {exc}")
        raise HTTPException(status_code=500, detail=f"Failed to stop conversion: {exc}")
427
-
428
if __name__ == "__main__":
    # Script entry point. Statement order matters: the environment is set up
    # before the project modules are imported.
    if sys.platform == "win32":
        from multiprocessing import freeze_support

        freeze_support()

    load_dotenv()
    os.environ.update({"OMP_NUM_THREADS": "4"})
    if sys.platform == "darwin":
        os.environ.update({"PYTORCH_ENABLE_MPS_FALLBACK": "1"})

    # These imports stay at module level on purpose: AudioAPI refers to
    # TorchGate and rvc_for_realtime as module globals.
    from tools.torchgate import TorchGate
    import tools.rvc_for_realtime as rvc_for_realtime
    from configs.config import Config

    audio_api.config = Config()
    # NOTE(review): this binds to all interfaces and no authentication is
    # visible in this file; confirm that is intended for the deployment.
    uvicorn.run(app, host="0.0.0.0", port=6242)