Spaces: Running on Zero
import gc
import os
import subprocess
from pathlib import Path

import gradio as gr
import spaces
import torch
import torchaudio

from audio_mixer import compose_audio
from config import LOGS_DIR, OUTPUT_DIR
from GenerateAudio import GenerateAudio
from GenerateCaptions import generate_caption
from SoundMapper import SoundMapper
# Ensure required directories exist before any pipeline stage writes to them.
os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prepare the external model dir and download the DepthFM checkpoint if it is
# missing. (`Path` is already imported at the top of the file; the original
# re-imported it here redundantly.)
depthfm_ckpt = Path('external_models/depth-fm/checkpoints/depthfm-v1.ckpt')
if not depthfm_ckpt.exists():
    depthfm_ckpt.parent.mkdir(parents=True, exist_ok=True)
    # Use subprocess with an argument list instead of os.system with a shell
    # string: no shell-injection surface and no quoting pitfalls. check=False
    # mirrors the original os.system behavior (a failed download is ignored
    # here and will surface later when the checkpoint is loaded).
    subprocess.run(
        [
            "wget",
            "https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt",
            "-P",
            str(depthfm_ckpt.parent),
        ],
        check=False,
    )
# Clear CUDA cache between runs
def clear_cuda() -> None:
    """Free accumulated GPU memory between pipeline runs.

    Runs the Python garbage collector *first* so that any unreferenced
    tensors are actually released; only then asks the CUDA caching
    allocator to hand its now-unused blocks back to the driver. (The
    original order — empty_cache before gc.collect — left collector-freed
    blocks stranded in the cache until the next call.)
    """
    gc.collect()
    if torch.cuda.is_available():  # no-op guard for CPU-only hosts
        torch.cuda.empty_cache()
def _normalize_audio_channels(audio):
    """Coerce generated audio to the 2-D (channels, samples) layout torchaudio.save expects."""
    if audio.dim() == 3:
        audio = audio.squeeze(0)    # drop a leading batch dimension
    elif audio.dim() == 1:
        audio = audio.unsqueeze(0)  # promote bare samples to a single channel
    return audio


def process_images(
    image_dir: str,
    output_dir: str,
    panoramic: bool,
    view: str,
    model: str,
    location: str,
    audio_duration: int,
    cpu_only: bool
) -> None:
    """Generate soundscape audio for street-view images in ``image_dir``.

    Results are communicated entirely through files written to
    ``output_dir``: one ``sound_<view>.wav`` per processed view and, in
    panoramic mode, an additional ``panoramic_composition.wav`` mix.

    Args:
        image_dir: Directory holding the input ``<view>.jpg`` images.
        output_dir: Directory for the generated ``.wav`` files (created if needed).
        panoramic: When True, process every view returned by the captioner
            and compose the per-view tracks into one panoramic mix.
        view: View name ("front"/"back"/"left"/"right"); the target view in
            single-view mode, forwarded to the captioner in both modes.
        model: Vision-language model identifier passed to generate_caption.
        location: "lat,lon" string; whitespace around the comma is tolerated.
        audio_duration: Length of each generated clip, in seconds.
        cpu_only: Forwarded to the captioner to force CPU inference.
    """
    # Tolerate stray whitespace in user-entered "lat, lon" strings.
    lat, lon = (part.strip() for part in location.split(","))
    os.makedirs(output_dir, exist_ok=True)
    sound_mapper = SoundMapper()
    audio_generator = GenerateAudio()

    if panoramic:
        # Panoramic: generate per-view audio, then a final composition.
        view_results = generate_caption(lat, lon, view=view, model=model,
                                        cpu_only=cpu_only, panoramic=True)
        processed_maps = sound_mapper.process_depth_maps()
        image_paths = sorted(Path(image_dir).glob("*.jpg"))
        audios = {}
        for vr in view_results:
            cv = vr["view"]
            img_file = Path(image_dir) / f"{cv}.jpg"
            if not img_file.exists():
                continue
            # Depth maps are keyed by position in the sorted image list.
            idx = [i for i, p in enumerate(image_paths) if p.name == img_file.name]
            if not idx:
                continue
            depth_map = processed_maps[idx[0]]["normalization"]
            obj_depths = sound_mapper.analyze_object_depths(
                str(img_file), depth_map, lat, lon,
                caption_data=vr, all_objects=False
            )
            if not obj_depths:
                continue
            out_wav = Path(output_dir) / f"sound_{cv}.wav"
            audio, sr = audio_generator.process_and_generate_audio(
                obj_depths, duration=audio_duration
            )
            torchaudio.save(str(out_wav), _normalize_audio_channels(audio), sr)
            audios[cv] = str(out_wav)
        # Final panoramic composition: mix every per-view track at unit gain.
        comp = Path(output_dir) / "panoramic_composition.wav"
        compose_audio(list(audios.values()), [1.0] * len(audios), str(comp))
        audios['panorama'] = str(comp)
        clear_cuda()
        return

    # Single-view: generate one audio file for the requested view.
    vr = generate_caption(lat, lon, view=view, model=model,
                          cpu_only=cpu_only, panoramic=False)
    img_file = Path(image_dir) / f"{view}.jpg"
    if not img_file.exists():
        # Nothing to do when the requested view was never captured.
        clear_cuda()
        return
    processed_maps = sound_mapper.process_depth_maps()
    image_paths = sorted(Path(image_dir).glob("*.jpg"))
    idx = [i for i, p in enumerate(image_paths) if p.name == img_file.name]
    if not idx:
        # Guard the idx[0] lookup: the panoramic branch skips this case,
        # mirror it here instead of raising IndexError.
        clear_cuda()
        return
    depth_map = processed_maps[idx[0]]["normalization"]
    obj_depths = sound_mapper.analyze_object_depths(
        str(img_file), depth_map, lat, lon,
        caption_data=vr, all_objects=True
    )
    if not obj_depths:
        clear_cuda()
        return
    out_wav = Path(output_dir) / f"sound_{view}.wav"
    audio, sr = audio_generator.process_and_generate_audio(obj_depths, duration=audio_duration)
    torchaudio.save(str(out_wav), _normalize_audio_channels(audio), sr)
    clear_cuda()
# ---------------------------------------------------------------------------
# Gradio UI: input controls on top, a 2x2 grid of (image, audio) pairs — one
# per view — and the panoramic composition at the bottom.
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="Panoramic Audio Generator")
with demo:
    gr.Markdown("""
    # Panoramic Audio Generator
    Displays each view with its audio side by side.
    """
    )
    # Input controls.
    with gr.Row():
        panoramic = gr.Checkbox(label="Panoramic (multi-view)", value=False)
        view = gr.Dropdown(["front", "back", "left", "right"], value="front", label="View")
        location = gr.Textbox(value="52.3436723,4.8529625", label="Location (lat,lon)")
        model = gr.Textbox(value="intern_2_5-4B", label="Vision-Language Model")
        audio_duration = gr.Slider(1, 60, value=10, step=1, label="Audio Duration (sec)")
        cpu_only = gr.Checkbox(label="CPU Only", value=False)
    btn = gr.Button("Generate")

    # Output layout: two rows of two (image, audio) columns.
    with gr.Row():
        with gr.Column():
            img_front = gr.Image(label="Front View", type="filepath")
            aud_front = gr.Audio(label="Front Audio", type="filepath")
        with gr.Column():
            img_back = gr.Image(label="Back View", type="filepath")
            aud_back = gr.Audio(label="Back Audio", type="filepath")
    with gr.Row():
        with gr.Column():
            img_left = gr.Image(label="Left View", type="filepath")
            aud_left = gr.Audio(label="Left Audio", type="filepath")
        with gr.Column():
            img_right = gr.Image(label="Right View", type="filepath")
            aud_right = gr.Audio(label="Right Audio", type="filepath")
    # Panorama at the bottom.
    img_pan = gr.Image(label="Panorama View", type="filepath")
    aud_pan = gr.Audio(label="Panoramic Audio", type="filepath")

    def run_all(pan, vw, loc, mdl, dur, cpu):
        """Run the pipeline, then gather generated files for display.

        Returns a flat tuple of (image_path, audio_path) pairs in the same
        order as the ``outputs`` list wired to the button below; a missing
        file maps to None so its widget simply stays empty.
        """
        process_images(LOGS_DIR, OUTPUT_DIR, pan, vw, mdl, loc, dur, cpu)
        views = ["front", "back", "left", "right", "panorama"]
        paths = {}
        for v in views:
            img = Path(LOGS_DIR) / f"{v}.jpg"
            audio = Path(OUTPUT_DIR) / ("panoramic_composition.wav" if v == "panorama" else f"sound_{v}.wav")
            paths[v] = {
                'img': str(img) if img.exists() else None,
                'aud': str(audio) if audio.exists() else None
            }
        return (
            paths['front']['img'], paths['front']['aud'],
            paths['back']['img'], paths['back']['aud'],
            paths['left']['img'], paths['left']['aud'],
            paths['right']['img'], paths['right']['aud'],
            paths['panorama']['img'], paths['panorama']['aud']
        )

    btn.click(
        fn=run_all,
        inputs=[panoramic, view, location, model, audio_duration, cpu_only],
        outputs=[
            img_front, aud_front,
            img_back, aud_back,
            img_left, aud_left,
            img_right, aud_right,
            img_pan, aud_pan
        ]
    )

if __name__ == "__main__":
    demo.launch(show_api=False)