# NOTE: the three lines below are Hugging Face web-UI residue captured when this
# file was scraped from the repo's file-listing page; kept as comments so the
# module remains valid Python.
# FQiao's picture
# Upload 70 files
# 3324de2 verified
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import time
import json
from cog import BasePredictor, Input, Path
from diffusers import AutoencoderOobleck
import soundfile as sf
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from tangoflux.model import TangoFlux
from tangoflux import TangoFluxInference
# Local directory the pre-packaged weights tarball is extracted into; its
# presence is used by Predictor.setup() to skip re-downloading.
MODEL_CACHE = "model_cache"
# Tarball of the model_cache directory hosted on Replicate's weights CDN,
# fetched with pget for fast cold boots (see download_weights below).
MODEL_URL = (
"https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
)
class CachedTangoFluxInference(TangoFluxInference):
    """TangoFluxInference variant that can boot from a local weights directory.

    When ``cached_paths`` is given, weights are loaded straight from that
    directory (pre-downloaded from replicate.delivery) instead of pulling a
    fresh snapshot from the Hugging Face Hub, which speeds up cold starts.
    """

    def __init__(self, name="declare-lab/TangoFlux", device="cuda", cached_paths=None):
        # Prefer the local cache; otherwise fall back to a Hub snapshot.
        weights_dir = cached_paths if cached_paths else snapshot_download(repo_id=name)

        # Restore the Oobleck VAE from its safetensors checkpoint.
        self.vae = AutoencoderOobleck()
        self.vae.load_state_dict(load_file(f"{weights_dir}/vae.safetensors"))

        # Rebuild the TangoFlux model from its saved config, then load weights.
        with open(f"{weights_dir}/config.json", "r") as f:
            model_config = json.load(f)
        self.model = TangoFlux(model_config)
        # strict=False mirrors the upstream loader: checkpoint may omit keys.
        self.model.load_state_dict(
            load_file(f"{weights_dir}/tangoflux.safetensors"), strict=False
        )

        self.vae.to(device)
        self.model.to(device)
def download_weights(url, dest):
    """Download *url* into *dest* with ``pget -x`` (fetch + extract in one step).

    Prints the source, destination, and elapsed time so cold-boot cost shows
    up in the Replicate logs.
    """
    started = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget -x extracts the tarball as it downloads; raises on non-zero exit.
    cmd = ["pget", "-x", url, dest]
    subprocess.check_call(cmd, close_fds=False)
    print("downloading took: ", time.time() - started)
class Predictor(BasePredictor):
    """Cog predictor exposing TangoFlux text-to-audio generation."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Fetch the pre-packaged weights tarball only on the first boot.
        if not os.path.exists(MODEL_CACHE):
            print("downloading")
            download_weights(MODEL_URL, MODEL_CACHE)
        self.model = CachedTangoFluxInference(
            cached_paths=f"{MODEL_CACHE}/declare-lab/TangoFlux"
        )

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="Hammer slowly hitting the wooden table"
        ),
        duration: int = Input(
            description="Duration of the output audio in seconds", default=10
        ),
        steps: int = Input(
            description="Number of inference steps", ge=1, le=200, default=25
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        generated = self.model.generate(
            prompt,
            steps=steps,
            guidance_scale=guidance_scale,
            duration=duration,
        )
        out_path = "/tmp/out.wav"
        # Transpose before writing — soundfile wants samples along the first
        # axis; presumably generate() returns (channels, samples). TODO confirm.
        sf.write(
            out_path, generated.numpy().T, samplerate=self.model.vae.config.sampling_rate
        )
        return Path(out_path)