# Gradio demo for AudioX: Diffusion Transformer for Anything-to-Audio Generation.
import gc
import logging
import os
import platform
import stat

import gradio as gr
import spaces
import torch
import torchaudio
from einops import rearrange

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
# Spectrogram helper for step previews; assumes the aeiou package used by
# stable-audio-tools' reference Gradio interface is installed.
from aeiou.viz import audio_spectrogram_image

from transformers import logging as transformers_logging

# Silence verbose transformers logging.
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

# Load the pretrained AudioX model and its configuration.
model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

# World-writable temp directories for Gradio uploads and extracted videos.
TEMP_DIR = "tmp/gradio"
os.makedirs(TEMP_DIR, exist_ok=True)
os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)


@spaces.GPU(duration=10)
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=7.0,
    steps=100,
    preview_every=0,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=0.1,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None

    # Pick the best available device: MPS (Apple Silicon), then CUDA, then CPU.
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False

    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    global model
    model = model.to(device)

    target_fps = model_config.get("video_fps", 5)
    model_type = model_config.get("model_type", "diffusion_cond")

    # Resolve the uploaded video path (gr.File may yield a dict or a file-like object).
    if video_file is not None:
        actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
    else:
        actual_video_path = None

    # Resolve the audio prompt: an uploaded file takes precedence over a manually entered path.
    if audio_prompt_file is not None:
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None

    video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)

    seconds_input = sample_size / sample_rate

    if not prompt:
        prompt = ""

    conditioning = [{
        "video_prompt": [video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }]
    else:
        negative_conditioning = None

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

    # Optional preview callback: decodes intermediate latents to spectrogram images.
    def progress_callback(callback_info):
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]
        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)

            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()

            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))

    if model_type == "diffusion_cond":
        audio = generate_diffusion_cond(
            model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=input_sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            mask_args=None,
            callback=progress_callback if preview_every is not None else None,
            scale_phi=cfg_rescale
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Collapse the batch dimension and keep only the first 10 seconds.
    audio = rearrange(audio, "b d n -> d (b n)")
    samples_10s = 10 * sample_rate
    audio = audio[:, :samples_10s]

    # Peak-normalize and convert to 16-bit PCM.
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    output_dir = "demo_result"
    os.makedirs(output_dir, exist_ok=True)

    output_audio_path = f"{output_dir}/output.wav"
    torchaudio.save(output_audio_path, audio, sample_rate)

    # If a video was provided, mux the generated audio back into it.
    if actual_video_path:
        output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
        merge_video_audio(
            actual_video_path,
            output_audio_path,
            output_video_path,
            seconds_start,
            seconds_total
        )
    else:
        output_video_path = None

    del actual_video_path
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return output_video_path, output_audio_path


with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
        **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co./HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
        """
    )

    with gr.Tab("Generation"):
        # Prompt and media inputs.
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt"
                )
                negative_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Negative prompt",
                    visible=False
                )
                video_file = gr.File(label="Upload Video File")
                audio_prompt_file = gr.File(
                    label="Upload Audio Prompt File",
                    visible=False
                )
                audio_prompt_path = gr.Textbox(
                    label="Audio Prompt Path",
                    placeholder="Enter audio file path",
                    visible=False
                )

        with gr.Row():
            with gr.Column(scale=6):
                with gr.Accordion("Video Params", open=False):
                    seconds_start = gr.Slider(
                        minimum=0, maximum=512, step=1, value=0,
                        label="Video Seconds Start"
                    )
                    seconds_total = gr.Slider(
                        minimum=0, maximum=10, step=1, value=10,
                        label="Seconds Total", interactive=False
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Sampler Params", open=False):
                    steps = gr.Slider(
                        minimum=1, maximum=500, step=1, value=100,
                        label="Steps"
                    )
                    preview_every = gr.Slider(
                        minimum=0, maximum=100, step=1, value=0,
                        label="Preview Every"
                    )
                    cfg_scale = gr.Slider(
                        minimum=0.0, maximum=25.0, step=0.1, value=7.0,
                        label="CFG Scale"
                    )
                    seed = gr.Textbox(
                        label="Seed (set to -1 for random seed)",
                        value="-1"
                    )
                    sampler_type = gr.Dropdown(
                        choices=[
                            "dpmpp-2m-sde",
                            "dpmpp-3m-sde",
                            "k-heun",
                            "k-lms",
                            "k-dpmpp-2s-ancestral",
                            "k-dpm-2",
                            "k-dpm-fast"
                        ],
                        label="Sampler Type",
                        value="dpmpp-3m-sde"
                    )
                    sigma_min = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.01, value=0.03,
                        label="Sigma Min"
                    )
                    sigma_max = gr.Slider(
                        minimum=0.0, maximum=1000.0, step=0.1, value=500,
                        label="Sigma Max"
                    )
                    cfg_rescale = gr.Slider(
                        minimum=0.0, maximum=1, step=0.01, value=0.0,
                        label="CFG Rescale Amount"
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Init Audio", open=False, visible=False):
                    init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                    init_audio_input = gr.Audio(label="Init Audio")
                    init_noise_level = gr.Slider(
                        minimum=0.1, maximum=100.0, step=0.01, value=0.1,
                        label="Init Noise Level"
                    )

        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Row():
            with gr.Column(scale=6):
                video_output = gr.Video(label="Output Video", interactive=False)
                audio_output = gr.Audio(label="Output Audio", interactive=False)

        # Input order must match the leading parameters of generate_cond.
        inputs = [
            prompt,
            negative_prompt,
            video_file,
            audio_prompt_file,
            audio_prompt_path,
            seconds_start,
            seconds_total,
            cfg_scale,
            steps,
            preview_every,
            seed,
            sampler_type,
            sigma_min,
            sigma_max,
            cfg_rescale,
            init_audio_checkbox,
            init_audio_input,
            init_noise_level
        ]

        generate_button.click(
            fn=generate_cond,
            inputs=inputs,
            outputs=[video_output, audio_output]
        )

        gr.Markdown("## Examples")

        with gr.Accordion("Click to show examples", open=False):
            with gr.Row():
                gr.Markdown("**📝 Task: Text-to-Audio**")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Typing on a keyboard*")
                    ex1 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Ocean waves crashing*")
                    ex2 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Footsteps in snow*")
                    ex3 = gr.Button("Load Example")

            with gr.Row():
                gr.Markdown("**🎶 Task: Text-to-Music**")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
                    ex4 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
                    ex5 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
                    ex6 = gr.Button("Load Example")

            # Each example button fills the input widgets with a preset prompt and a fixed seed.
            ex1.click(
                lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1],
                inputs=[], outputs=inputs
            )
            ex2.click(
                lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1],
                inputs=[], outputs=inputs
            )
            ex3.click(
                lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1],
                inputs=[], outputs=inputs
            )
            ex4.click(
                lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1],
                inputs=[], outputs=inputs
            )
            ex5.click(
                lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1],
                inputs=[], outputs=inputs
            )
            ex6.click(
                lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1],
                inputs=[], outputs=inputs
            )

interface.queue(5).launch(server_name="0.0.0.0", server_port=7860, share=True)
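
# A minimal programmatic sketch (an assumption, not part of the original app): generate_cond
# can also be called directly, mirroring the "Typing on a keyboard" example button above.
# Uncomment to run it instead of the Gradio UI (launch() above blocks while the server runs).
#
# video_out, audio_out = generate_cond(
#     prompt="Typing on a keyboard",   # text prompt from the first example button
#     seconds_start=0,
#     seconds_total=10,
#     cfg_scale=7.0,
#     steps=100,
#     seed=1225575558,                 # fixed seed used by that example
#     sampler_type="dpmpp-3m-sde",
#     sigma_min=0.03,
#     sigma_max=500,
# )
# print("Generated audio saved to:", audio_out)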