Zeyue7's picture
AudioX
8ab1cf8
import gc
import platform
import os
import subprocess as sp
import gradio as gr
import json
import torch
import torchaudio
from aeiou.viz import audio_spectrogram_image
from einops import rearrange
from safetensors.torch import load_file
from torch.nn import functional as F
from torchaudio import transforms as T
from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
from ..models.factory import create_model_from_config
from ..models.pretrained import get_pretrained_model
from ..models.utils import load_ckpt_state_dict
from ..inference.utils import prepare_audio
from ..training.utils import copy_state_dict
from ..data.utils import read_video, merge_video_audio
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
device = torch.device("cpu")
os.environ['TMPDIR'] = './tmp'
current_model_name = None
current_model = None
current_sample_rate = None
current_sample_size = None
def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
global model_configurations
if pretrained_name is not None:
print(f"Loading pretrained model {pretrained_name}")
model, model_config = get_pretrained_model(pretrained_name)
elif model_config is not None and model_ckpt_path is not None:
print(f"Creating model from config")
model = create_model_from_config(model_config)
print(f"Loading model checkpoint from {model_ckpt_path}")
copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
if pretransform_ckpt_path is not None:
print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
print(f"Done loading pretransform")
model.to(device).eval().requires_grad_(False)
if model_half:
model.to(torch.float16)
print(f"Done loading model")
return model, model_config, sample_rate, sample_size
def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
if audio_path is None:
return torch.zeros((2, int(sample_rate * seconds_total)))
audio_tensor, sr = torchaudio.load(audio_path)
start_index = int(sample_rate * seconds_start)
target_length = int(sample_rate * seconds_total)
end_index = start_index + target_length
audio_tensor = audio_tensor[:, start_index:end_index]
if audio_tensor.shape[1] < target_length:
pad_length = target_length - audio_tensor.shape[1]
audio_tensor = F.pad(audio_tensor, (pad_length, 0))
return audio_tensor
def generate_cond(
prompt,
negative_prompt=None,
video_file=None,
video_path=None,
audio_prompt_file=None,
audio_prompt_path=None,
seconds_start=0,
seconds_total=10,
cfg_scale=6.0,
steps=250,
preview_every=None,
seed=-1,
sampler_type="dpmpp-3m-sde",
sigma_min=0.03,
sigma_max=1000,
cfg_rescale=0.0,
use_init=False,
init_audio=None,
init_noise_level=1.0,
mask_cropfrom=None,
mask_pastefrom=None,
mask_pasteto=None,
mask_maskstart=None,
mask_maskend=None,
mask_softnessL=None,
mask_softnessR=None,
mask_marination=None,
batch_size=1
):
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
print(f"Prompt: {prompt}")
preview_images = []
if preview_every == 0:
preview_every = None
try:
has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
except Exception:
has_mps = False
if has_mps:
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model_name = 'default'
cfg = model_configurations[model_name]
model_config_path = cfg.get("model_config")
ckpt_path = cfg.get("ckpt_path")
pretrained_name = cfg.get("pretrained_name")
pretransform_ckpt_path = cfg.get("pretransform_ckpt_path")
model_type = cfg.get("model_type", "diffusion_cond")
if model_config_path:
with open(model_config_path) as f:
model_config = json.load(f)
else:
model_config = None
target_fps = model_config.get("video_fps", 5)
global current_model_name, current_model, current_sample_rate, current_sample_size
if current_model is None or model_name != current_model_name:
current_model, model_config, sample_rate, sample_size = load_model(
model_name=model_name,
model_config=model_config,
model_ckpt_path=ckpt_path,
pretrained_name=pretrained_name,
pretransform_ckpt_path=pretransform_ckpt_path,
device=device,
model_half=False
)
current_model_name = model_name
model = current_model
current_sample_rate = sample_rate
current_sample_size = sample_size
else:
model = current_model
sample_rate = current_sample_rate
sample_size = current_sample_size
if video_file is not None:
video_path = video_file.name
elif video_path:
video_path = video_path.strip()
else:
video_path = None
if audio_prompt_file is not None:
print(f'audio_prompt_file: {audio_prompt_file}')
audio_path = audio_prompt_file.name
elif audio_prompt_path:
audio_path = audio_prompt_path.strip()
else:
audio_path = None
Video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
audio_tensor = audio_tensor.to(device)
seconds_input = sample_size / sample_rate
print(f'video_path: {video_path}')
if not prompt:
prompt = ""
conditioning = [{
"video_prompt": [Video_tensors.unsqueeze(0)],
"text_prompt": prompt,
"audio_prompt": audio_tensor.unsqueeze(0),
"seconds_start": seconds_start,
"seconds_total": seconds_input
}] * batch_size
if negative_prompt:
negative_conditioning = [{
"video_prompt": [Video_tensors.unsqueeze(0)],
"text_prompt": negative_prompt,
"audio_prompt": audio_tensor.unsqueeze(0),
"seconds_start": seconds_start,
"seconds_total": seconds_total
}] * batch_size
else:
negative_conditioning = None
try:
device = next(model.parameters()).device
except Exception as e:
device = next(current_model.parameters()).device
seed = int(seed)
if not use_init:
init_audio = None
input_sample_size = sample_size
if init_audio is not None:
in_sr, init_audio = init_audio
init_audio = torch.from_numpy(init_audio).float().div(32767)
if init_audio.dim() == 1:
init_audio = init_audio.unsqueeze(0)
elif init_audio.dim() == 2:
init_audio = init_audio.transpose(0, 1)
if in_sr != sample_rate:
resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
init_audio = resample_tf(init_audio)
audio_length = init_audio.shape[-1]
if audio_length > sample_size:
input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
init_audio = (sample_rate, init_audio)
def progress_callback(callback_info):
nonlocal preview_images
denoised = callback_info["denoised"]
current_step = callback_info["i"]
sigma = callback_info["sigma"]
if (current_step - 1) % preview_every == 0:
if model.pretransform is not None:
denoised = model.pretransform.decode(denoised)
denoised = rearrange(denoised, "b d n -> d (b n)")
denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
if mask_cropfrom is not None:
mask_args = {
"cropfrom": mask_cropfrom,
"pastefrom": mask_pastefrom,
"pasteto": mask_pasteto,
"maskstart": mask_maskstart,
"maskend": mask_maskend,
"softnessL": mask_softnessL,
"softnessR": mask_softnessR,
"marination": mask_marination,
}
else:
mask_args = None
if model_type == "diffusion_cond":
audio = generate_diffusion_cond(
model,
conditioning=conditioning,
negative_conditioning=negative_conditioning,
steps=steps,
cfg_scale=cfg_scale,
batch_size=batch_size,
sample_size=input_sample_size,
sample_rate=sample_rate,
seed=seed,
device=device,
sampler_type=sampler_type,
sigma_min=sigma_min,
sigma_max=sigma_max,
init_audio=init_audio,
init_noise_level=init_noise_level,
mask_args=mask_args,
callback=progress_callback if preview_every is not None else None,
scale_phi=cfg_rescale
)
elif model_type == "diffusion_uncond":
audio = generate_diffusion_uncond(
model,
steps=steps,
batch_size=batch_size,
sample_size=input_sample_size,
seed=seed,
device=device,
sampler_type=sampler_type,
sigma_min=sigma_min,
sigma_max=sigma_max,
init_audio=init_audio,
init_noise_level=init_noise_level,
callback=progress_callback if preview_every is not None else None
)
else:
raise ValueError(f"Unsupported model type: {model_type}")
audio = rearrange(audio, "b d n -> d (b n)")
audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
file_name = os.path.basename(video_path) if video_path else "output"
output_dir = f"demo_result"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_video_path = f"{output_dir}/{file_name}"
torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if video_path:
merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total)
audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
del video_path
torch.cuda.empty_cache()
gc.collect()
return (output_video_path, f"{output_dir}/output.wav")
def toggle_custom_model(selected_model):
return gr.Row.update(visible=(selected_model == "Custom Model"))
def create_sampling_ui(model_config_map, inpainting=False):
with gr.Blocks() as demo:
gr.Markdown(
"""
# 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
**[Project Page](https://zeyuet.github.io/AudioX/) Β· [Huggingface](https://huggingface.co./Zeyue7/AudioX) Β· [GitHub](https://github.com/ZeyueT/AudioX)**
"""
)
with gr.Tab("Generation"):
with gr.Row():
with gr.Column():
prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt")
negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False)
video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path")
video_file = gr.File(label="Upload Video File")
audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False)
audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False)
with gr.Row():
with gr.Column(scale=6):
with gr.Accordion("Video Params", open=False):
seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start")
seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False)
with gr.Row():
with gr.Column(scale=4):
with gr.Accordion("Sampler Params", open=False):
steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale")
seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
sampler_type_dropdown = gr.Dropdown(
["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"],
label="Sampler Type",
value="dpmpp-3m-sde"
)
sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min")
sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max")
cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount")
with gr.Row():
with gr.Column(scale=4):
with gr.Accordion("Init Audio", open=False, visible=False):
init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
init_audio_input = gr.Audio(label="Init Audio")
init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level")
gr.Markdown("## Examples")
with gr.Accordion("Click to show examples", open=False):
with gr.Row():
gr.Markdown("**πŸ“ Task: Text-to-Audio**")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Typing on a keyboard*")
ex1 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Ocean waves crashing*")
ex2 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Footsteps in snow*")
ex3 = gr.Button("Load Example")
with gr.Row():
gr.Markdown("**🎢 Task: Text-to-Music**")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
ex4 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
ex5 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
ex6 = gr.Button("Load Example")
with gr.Row():
gr.Markdown("**🎬 Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*")
with gr.Column(scale=1.2):
gr.Video("example/V2A_sample-1.mp4")
ex7 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Video("example/V2A_sample-2.mp4")
ex8 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Video("example/V2A_sample-3.mp4")
ex9 = gr.Button("Load Example")
with gr.Row():
gr.Markdown("**🎡 Task: Video-to-Music**\nPrompt: *Generate music for the video*")
with gr.Column(scale=1.2):
gr.Video("example/V2M_sample-1.mp4")
ex10 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Video("example/V2M_sample-2.mp4")
ex11 = gr.Button("Load Example")
with gr.Column(scale=1.2):
gr.Video("example/V2M_sample-3.mp4")
ex12 = gr.Button("Load Example")
with gr.Row():
generate_button = gr.Button("Generate", variant='primary', scale=1)
with gr.Row():
with gr.Column(scale=6):
video_output = gr.Video(label="Output Video", interactive=False)
audio_output = gr.Audio(label="Output Audio", interactive=False)
send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False)
send_to_init_button.click(
fn=lambda audio: audio,
inputs=[audio_output],
outputs=[init_audio_input]
)
inputs = [
prompt,
negative_prompt,
video_file,
video_path,
audio_prompt_file,
audio_prompt_path,
seconds_start_slider,
seconds_total_slider,
cfg_scale_slider,
steps_slider,
preview_every_slider,
seed_textbox,
sampler_type_dropdown,
sigma_min_slider,
sigma_max_slider,
cfg_rescale_slider,
init_audio_checkbox,
init_audio_input,
init_noise_level_slider
]
generate_button.click(
fn=generate_cond,
inputs=inputs,
outputs=[
video_output,
audio_output
],
api_name="generate"
)
ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
return demo
def create_txt2audio_ui(model_config_map):
with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui:
with gr.Tab("Generation"):
create_sampling_ui(model_config_map)
return ui
def toggle_custom_model(selected_model):
return gr.Row.update(visible=(selected_model == "Custom Model"))
def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
global model_configurations
global device
try:
has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
except Exception:
has_mps = False
if has_mps:
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print("Using device:", device)
model_configurations = {
"default": {
"model_config": "./model/config.json",
"ckpt_path": "./model/model.ckpt"
}
}
ui = create_txt2audio_ui(model_configurations)
return ui
if __name__ == "__main__":
ui = create_ui(
model_config_path='./model/config.json',
share=True
)
ui.launch()