Spaces:
Running
on
Zero
Running
on
Zero
Upload 70 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- DepthEstimator.py +67 -0
- GenerateAudio.py +184 -0
- GenerateCaptions.py +494 -0
- README.md +50 -13
- SoundMapper.py +438 -0
- app.py +182 -0
- audio_mixer.py +428 -0
- config.py +16 -0
- environment.yml +8 -0
- external_models/TangoFlux/.gitignore +175 -0
- external_models/TangoFlux/Demo.ipynb +117 -0
- external_models/TangoFlux/Inference.ipynb +0 -0
- external_models/TangoFlux/LICENSE.md +51 -0
- external_models/TangoFlux/Notice +1 -0
- external_models/TangoFlux/README.md +188 -0
- external_models/TangoFlux/STABILITY_AI_COMMUNITY_LICENSE.md +57 -0
- external_models/TangoFlux/__init__.py +4 -0
- external_models/TangoFlux/assets/tangoflux.png +3 -0
- external_models/TangoFlux/assets/tf_opener.png +3 -0
- external_models/TangoFlux/assets/tf_teaser.png +3 -0
- external_models/TangoFlux/comfyui/README.md +78 -0
- external_models/TangoFlux/comfyui/__init__.py +6 -0
- external_models/TangoFlux/comfyui/example_workflow.json +168 -0
- external_models/TangoFlux/comfyui/install.py +79 -0
- external_models/TangoFlux/comfyui/nodes.py +328 -0
- external_models/TangoFlux/comfyui/requirements.txt +9 -0
- external_models/TangoFlux/comfyui/server.py +64 -0
- external_models/TangoFlux/comfyui/teacache.py +283 -0
- external_models/TangoFlux/comfyui/web/js/playAudio.js +59 -0
- external_models/TangoFlux/configs/__init__.py +0 -0
- external_models/TangoFlux/configs/accelerator_config.yaml +17 -0
- external_models/TangoFlux/configs/tangoflux_config.yaml +36 -0
- external_models/TangoFlux/crpo.sh +2 -0
- external_models/TangoFlux/inference.py +7 -0
- external_models/TangoFlux/replicate_demo/cog.yaml +31 -0
- external_models/TangoFlux/replicate_demo/predict.py +92 -0
- external_models/TangoFlux/requirements.txt +12 -0
- external_models/TangoFlux/setup.py +30 -0
- external_models/TangoFlux/tangoflux/__init__.py +60 -0
- external_models/TangoFlux/tangoflux/cli.py +29 -0
- external_models/TangoFlux/tangoflux/demo.py +63 -0
- external_models/TangoFlux/tangoflux/generate_crpo_dataset.py +204 -0
- external_models/TangoFlux/tangoflux/label_crpo.py +153 -0
- external_models/TangoFlux/tangoflux/model.py +556 -0
- external_models/TangoFlux/tangoflux/train.py +588 -0
- external_models/TangoFlux/tangoflux/train_dpo.py +608 -0
- external_models/TangoFlux/tangoflux/utils.py +159 -0
- external_models/TangoFlux/train.sh +2 -0
- external_models/depth-fm/.gitignore +5 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
external_models/depth-fm/assets/dog.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
external_models/depth-fm/assets/figures/dfm-cover.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
external_models/depth-fm/assets/figures/radio.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
external_models/TangoFlux/assets/tangoflux.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
external_models/TangoFlux/assets/tf_opener.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
external_models/TangoFlux/assets/tf_teaser.png filter=lfs diff=lfs merge=lfs -text
|
DepthEstimator.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from accelerate.test_utils.testing import get_backend
|
3 |
+
from PIL import Image
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR
|
7 |
+
sys.path.append(DEPTH_FM_DIR + '/depthfm')
|
8 |
+
from dfm import DepthFM
|
9 |
+
from unet import UNetModel
|
10 |
+
import einops
|
11 |
+
import numpy as np
|
12 |
+
from torchvision import transforms
|
13 |
+
|
14 |
+
|
15 |
+
class DepthEstimator:
|
16 |
+
def __init__(self, image_dir = LOGS_DIR):
|
17 |
+
self.device,_,_ = get_backend()
|
18 |
+
self.image_dir = image_dir
|
19 |
+
self.model = None
|
20 |
+
|
21 |
+
def _load_model(self):
|
22 |
+
if self.model is None:
|
23 |
+
self.model = DepthFM(DEPTH_FM_CHECKPOINT).to(self.device).eval()
|
24 |
+
else:
|
25 |
+
self.model = self.model.to(self.device).eval()
|
26 |
+
|
27 |
+
def _unload_model(self):
|
28 |
+
if self.model is not None:
|
29 |
+
self.model = self.model.to("cpu")
|
30 |
+
torch.cuda.empty_cache()
|
31 |
+
|
32 |
+
|
33 |
+
def estimate_depth(self, image_path : str) -> list:
|
34 |
+
print("Estimating depth...")
|
35 |
+
predictions_list = []
|
36 |
+
self._load_model()
|
37 |
+
for img in os.listdir(image_path):
|
38 |
+
if img.endswith(".jpg") or img.endswith(".jpeg") or img.endswith(".png"):
|
39 |
+
image = Image.open(os.path.join(image_path, img))
|
40 |
+
x = np.array(image)
|
41 |
+
x = einops.rearrange(x, 'h w c -> c h w')
|
42 |
+
x = x / 127.5 - 1
|
43 |
+
x = torch.tensor(x, dtype=torch.float32)[None]
|
44 |
+
with torch.no_grad():
|
45 |
+
depth = self.model.predict_depth(x.to(self.device), num_steps=2, ensemble_size=4) # returns a tensor
|
46 |
+
depth.cpu()
|
47 |
+
to_pil = transforms.ToPILImage()
|
48 |
+
PIL_image = to_pil(depth.squeeze())
|
49 |
+
predictions_list.append({"depth": PIL_image})
|
50 |
+
del x, depth
|
51 |
+
torch.cuda.empty_cache()
|
52 |
+
self._unload_model()
|
53 |
+
print("Depth estimation complete.")
|
54 |
+
return predictions_list
|
55 |
+
|
56 |
+
def visualize(self, predictions_list : list) -> None:
|
57 |
+
for (i, prediction) in enumerate(predictions_list):
|
58 |
+
prediction["depth"].save(f"depth_{i}.png")
|
59 |
+
|
60 |
+
|
61 |
+
# Estimator = DepthEstimator()
|
62 |
+
# predictions = Estimator.estimate_depth(Estimator.image_dir)
|
63 |
+
# Estimator.visualize(predictions)
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
GenerateAudio.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
from config import TANGO_FLUX_DIR
|
6 |
+
sys.path.append(TANGO_FLUX_DIR)
|
7 |
+
from tangoflux import TangoFluxInference
|
8 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
9 |
+
from collections import Counter
|
10 |
+
|
11 |
+
class GenerateAudio():
|
12 |
+
def __init__(self):
|
13 |
+
self.device = "cuda"
|
14 |
+
self.model = None
|
15 |
+
self.text_encoder = None
|
16 |
+
|
17 |
+
# Basic categories for object classification
|
18 |
+
self.categories = {
|
19 |
+
'vehicle': ['car', 'bus', 'truck', 'motorcycle', 'bicycle', 'train', 'vehicle'],
|
20 |
+
'nature': ['tree', 'bird', 'water', 'river', 'lake', 'ocean', 'rain', 'wind', 'forest'],
|
21 |
+
'urban': ['traffic', 'building', 'street', 'signal', 'construction'],
|
22 |
+
'animal': ['dog', 'cat', 'bird', 'insect', 'frog', 'squirrel'],
|
23 |
+
'human': ['person', 'people', 'crowd', 'child', 'footstep', 'voice'],
|
24 |
+
'indoor': ['door', 'window', 'chair', 'table', 'fan', 'appliance', 'tv', 'radio']
|
25 |
+
}
|
26 |
+
|
27 |
+
# Suffixes and prefixes for pattern matching
|
28 |
+
self.suffixes = {
|
29 |
+
'tree': 'nature',
|
30 |
+
'bird': 'animal',
|
31 |
+
'car': 'vehicle',
|
32 |
+
'truck': 'vehicle',
|
33 |
+
'signal': 'urban'
|
34 |
+
}
|
35 |
+
|
36 |
+
def _load_model(self):
|
37 |
+
if self.model is None:
|
38 |
+
self.model = TangoFluxInference(name='declare-lab/TangoFlux')
|
39 |
+
if self.text_encoder is None:
|
40 |
+
self.text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").to(self.device).eval()
|
41 |
+
else:
|
42 |
+
self.text_encoder = self.text_encoder.to(self.device)
|
43 |
+
|
44 |
+
def generate_sound(self, prompt, steps=25, duration=10, guidance_scale=4.5, disable_progress=True):
|
45 |
+
self._load_model()
|
46 |
+
with torch.no_grad():
|
47 |
+
latents = self.model.model.inference_flow(
|
48 |
+
prompt,
|
49 |
+
duration=duration,
|
50 |
+
num_inference_steps=steps,
|
51 |
+
guidance_scale=guidance_scale,
|
52 |
+
disable_progress=disable_progress
|
53 |
+
)
|
54 |
+
wave = self.model.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
|
55 |
+
waveform_end = int(duration * self.model.vae.config.sampling_rate)
|
56 |
+
wave = wave[:, :waveform_end]
|
57 |
+
|
58 |
+
return wave
|
59 |
+
|
60 |
+
def _categorize_object(self, object_name):
|
61 |
+
"""Categorize an object based on keywords or patterns"""
|
62 |
+
object_lower = object_name.lower()
|
63 |
+
|
64 |
+
# Check if the object contains any category keywords
|
65 |
+
for category, keywords in self.categories.items():
|
66 |
+
for keyword in keywords:
|
67 |
+
if keyword in object_lower:
|
68 |
+
return category
|
69 |
+
|
70 |
+
# Check suffix/prefix patterns
|
71 |
+
words = object_lower.split()
|
72 |
+
for word in words:
|
73 |
+
for suffix, category in self.suffixes.items():
|
74 |
+
if word.endswith(suffix):
|
75 |
+
return category
|
76 |
+
|
77 |
+
return "unknown"
|
78 |
+
|
79 |
+
def _describe_object_sound(self, object_name, zone):
|
80 |
+
"""Generate an appropriate sound description based on object type and distance"""
|
81 |
+
category = self._categorize_object(object_name)
|
82 |
+
|
83 |
+
# Volume descriptor based on zone
|
84 |
+
volume_descriptors = {
|
85 |
+
"near": ["prominent", "clear", "loud", "distinct"],
|
86 |
+
"medium": ["moderate", "audible", "present"],
|
87 |
+
"far": ["subtle", "distant", "faint", "soft"]
|
88 |
+
}
|
89 |
+
|
90 |
+
volume = random.choice(volume_descriptors[zone])
|
91 |
+
|
92 |
+
# Sound descriptors based on category
|
93 |
+
sound_templates = {
|
94 |
+
"vehicle": [
|
95 |
+
"{volume} engine sounds from the {object}",
|
96 |
+
"{volume} mechanical noise of the {object}",
|
97 |
+
"the {object} creating {volume} road noise",
|
98 |
+
"{volume} sounds of the {object} in motion"
|
99 |
+
],
|
100 |
+
"nature": [
|
101 |
+
"{volume} rustling of the {object}",
|
102 |
+
"the {object} making {volume} natural sounds",
|
103 |
+
"{volume} environmental sounds from the {object}",
|
104 |
+
"the {object} with {volume} movement in the wind"
|
105 |
+
],
|
106 |
+
"urban": [
|
107 |
+
"{volume} urban sounds around the {object}",
|
108 |
+
"the {object} with {volume} city ambience",
|
109 |
+
"{volume} noise from the {object}",
|
110 |
+
"the {object} contributing to {volume} street sounds"
|
111 |
+
],
|
112 |
+
"animal": [
|
113 |
+
"{volume} calls from the {object}",
|
114 |
+
"the {object} making {volume} animal sounds",
|
115 |
+
"{volume} sounds of the {object}",
|
116 |
+
"the {object} with its {volume} presence"
|
117 |
+
],
|
118 |
+
"human": [
|
119 |
+
"{volume} voices from the {object}",
|
120 |
+
"the {object} creating {volume} human sounds",
|
121 |
+
"{volume} movement sounds from the {object}",
|
122 |
+
"the {object} with {volume} activity"
|
123 |
+
],
|
124 |
+
"indoor": [
|
125 |
+
"{volume} ambient sounds around the {object}",
|
126 |
+
"the {object} making {volume} indoor noises",
|
127 |
+
"{volume} mechanical sounds from the {object}",
|
128 |
+
"the {object} with its {volume} presence"
|
129 |
+
],
|
130 |
+
"unknown": [
|
131 |
+
"{volume} sounds from the {object}",
|
132 |
+
"the {object} creating {volume} audio",
|
133 |
+
"{volume} noises associated with the {object}",
|
134 |
+
"the {object} with its {volume} acoustic presence"
|
135 |
+
]
|
136 |
+
}
|
137 |
+
|
138 |
+
# Select a template for this category
|
139 |
+
templates = sound_templates.get(category, sound_templates["unknown"])
|
140 |
+
template = random.choice(templates)
|
141 |
+
|
142 |
+
# Fill in the template
|
143 |
+
description = template.format(volume=volume, object=object_name)
|
144 |
+
return description
|
145 |
+
|
146 |
+
def create_audio_prompt(self, object_depths):
|
147 |
+
if not object_depths:
|
148 |
+
return "Environmental ambient sounds."
|
149 |
+
|
150 |
+
for obj in object_depths:
|
151 |
+
if obj.get("sound_description") and len(obj["sound_description"]) > 5:
|
152 |
+
return obj["sound_description"]
|
153 |
+
return f"Sounds of {object_depths[0]['original_label']}."
|
154 |
+
|
155 |
+
def process_and_generate_audio(self, object_depths, output_path=None, duration=10, steps=25, guidance_scale=4.5):
|
156 |
+
self._load_model()
|
157 |
+
|
158 |
+
if not object_depths:
|
159 |
+
prompt = "Environmental ambient sounds."
|
160 |
+
else:
|
161 |
+
# Sort objects by depth to prioritize closer objects
|
162 |
+
sorted_objects = sorted(object_depths, key=lambda x: x["mean_depth"])
|
163 |
+
prompt = self.create_audio_prompt(sorted_objects)
|
164 |
+
|
165 |
+
print(f"Generated audio prompt: {prompt}")
|
166 |
+
|
167 |
+
wave = self.generate_sound(
|
168 |
+
prompt,
|
169 |
+
steps=steps,
|
170 |
+
duration=duration,
|
171 |
+
guidance_scale=guidance_scale
|
172 |
+
)
|
173 |
+
|
174 |
+
sample_rate = self.model.vae.config.sampling_rate
|
175 |
+
|
176 |
+
if output_path:
|
177 |
+
torchaudio.save(
|
178 |
+
output_path,
|
179 |
+
wave.unsqueeze(0),
|
180 |
+
sample_rate
|
181 |
+
)
|
182 |
+
print(f"Audio saved to: {output_path}")
|
183 |
+
|
184 |
+
return wave, sample_rate
|
GenerateCaptions.py
ADDED
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
streetsoundtext.py - A pipeline that downloads Google Street View panoramas,
|
4 |
+
extracts perspective views, and analyzes them for sound information.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import requests
|
9 |
+
import argparse
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import time
|
13 |
+
from PIL import Image
|
14 |
+
from io import BytesIO
|
15 |
+
from config import LOGS_DIR
|
16 |
+
import torchvision.transforms as T
|
17 |
+
from torchvision.transforms.functional import InterpolationMode
|
18 |
+
from transformers import AutoModel, AutoTokenizer
|
19 |
+
from utils import sample_perspective_img
|
20 |
+
import cv2
|
21 |
+
|
22 |
+
log_dir = LOGS_DIR
|
23 |
+
os.makedirs(log_dir, exist_ok=True) # Creates the directory if it doesn't exist
|
24 |
+
|
25 |
+
# soundscape_query = "<image>\nWhat can we expect to hear from the location captured in this image? Name the around five nouns. Avoid speculation and provide a concise response including sound sources visible in the image."
|
26 |
+
soundscape_query = """<image>
|
27 |
+
Identify 5 potential sound sources visible in this image. For each source, provide both the noun and a brief description of its typical sound.
|
28 |
+
|
29 |
+
Format your response exactly like these examples (do not include the word "Noun:" in your response):
|
30 |
+
Car: engine humming with occasional honking.
|
31 |
+
River: gentle flowing water with subtle splashing sounds.
|
32 |
+
Trees: rustling leaves moved by the wind.
|
33 |
+
"""
|
34 |
+
# Constants
|
35 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
36 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
37 |
+
|
38 |
+
# Model Leaderboard Paths
|
39 |
+
MODEL_LEADERBOARD = {
|
40 |
+
"intern_2_5-8B": "OpenGVLab/InternVL2_5-8B-MPO",
|
41 |
+
"intern_2_5-4B": "OpenGVLab/InternVL2_5-4B-MPO",
|
42 |
+
}
|
43 |
+
|
44 |
+
class StreetViewDownloader:
|
45 |
+
"""Downloads panoramic images from Google Street View"""
|
46 |
+
|
47 |
+
def __init__(self):
|
48 |
+
# URLs for API requests
|
49 |
+
# https://www.google.ca/maps/rpc/photo/listentityphotos?authuser=0&hl=en&gl=us&pb=!1e3!5m45!2m2!1i203!2i100!3m3!2i4!3sCAEIBAgFCAYgAQ!5b1!7m33!1m3!1e1!2b0!3e3!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e10!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e4!1m3!1e9!2b1!3e2!2b1!8m0!9b0!11m1!4b1!6m3!1sI63QZ8b4BcSli-gPvPHf-Qc!7e81!15i11021!9m2!2d-90.30324219145255!3d38.636242944711036!10d91.37627840655999
|
50 |
+
#self.panoid_req = 'https://www.google.com/maps/preview/reveal?authuser=0&hl=en&gl=us&pb=!2m9!1m3!1d82597.14038230096!2d{}!3d{}!2m0!3m2!1i1523!2i1272!4f13.1!3m2!2d{}!3d{}!4m2!1syPETZOjwLvCIptQPiJum-AQ!7e81!5m5!2m4!1i96!2i64!3i1!4i8'
|
51 |
+
self.panoid_req = 'https://www.google.ca/maps/rpc/photo/listentityphotos?authuser=0&hl=en&gl=us&pb=!1e3!5m45!2m2!1i203!2i100!3m3!2i4!3sCAEIBAgFCAYgAQ!5b1!7m33!1m3!1e1!2b0!3e3!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e10!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e4!1m3!1e9!2b1!3e2!2b1!8m0!9b0!11m1!4b1!6m3!1sI63QZ8b4BcSli-gPvPHf-Qc!7e81!15i11021!9m2!2d{}!3d{}!10d25'
|
52 |
+
# https://www.google.com/maps/photometa/v1?authuser=0&hl=en&gl=us&pb=!1m4!1smaps_sv.tactile!11m2!2m1!1b1!2m2!1sen!2sus!3m3!1m2!1e2!2s{}!4m61!1e1!1e2!1e3!1e4!1e5!1e6!1e8!1e12!1e17!2m1!1e1!4m1!1i48!5m1!1e1!5m1!1e2!6m1!1e1!6m1!1e2!9m36!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e3!2b1!3e2!1m3!1e3!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e1!2b0!3e3!1m3!1e4!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e3!11m2!3m1!4b1 # vmSzE7zkK2eETwAP_r8UdQ
|
53 |
+
# https://www.google.ca/maps/photometa/v1?authuser=0&hl=en&gl=us&pb=!1m4!1smaps_sv.tactile!11m2!2m1!1b1!2m2!1sen!2sus!3m3!1m2!1e2!2s{}!4m61!1e1!1e2!1e3!1e4!1e5!1e6!1e8!1e12!1e17!2m1!1e1!4m1!1i48!5m1!1e1!5m1!1e2!6m1!1e1!6m1!1e2!9m36!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e3!2b1!3e2!1m3!1e3!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e1!2b0!3e3!1m3!1e4!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e3!11m2!3m1!4b1 # -9HfuNFUDOw_IP5SA5IspA
|
54 |
+
self.photometa_req = 'https://www.google.com/maps/photometa/v1?authuser=0&hl=en&gl=us&pb=!1m4!1smaps_sv.tactile!11m2!2m1!1b1!2m2!1sen!2sus!3m5!1m2!1e2!2s{}!2m1!5s0x87d8b49f53fc92e9:0x6ecb6e520c6f4d9f!4m57!1e1!1e2!1e3!1e4!1e5!1e6!1e8!1e12!2m1!1e1!4m1!1i48!5m1!1e1!5m1!1e2!6m1!1e1!6m1!1e2!9m36!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e3!2b1!3e2!1m3!1e3!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e1!2b0!3e3!1m3!1e4!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e3'
|
55 |
+
self.panimg_req = 'https://streetviewpixels-pa.googleapis.com/v1/tile?cb_client=maps_sv.tactile&panoid={}&x={}&y={}&zoom={}'
|
56 |
+
def get_image_id(self, lat, lon):
|
57 |
+
"""Get Street View panorama ID for given coordinates"""
|
58 |
+
null = None
|
59 |
+
pr_response = requests.get(self.panoid_req.format(lon, lat, lon, lat))
|
60 |
+
if pr_response.status_code != 200:
|
61 |
+
error_message = f"Error fetching panorama ID: HTTP {pr_response.status_code}"
|
62 |
+
if pr_response.status_code == 400:
|
63 |
+
error_message += " - Bad request. Check coordinates format."
|
64 |
+
elif pr_response.status_code == 401 or pr_response.status_code == 403:
|
65 |
+
error_message += " - Authentication error. Check API key and permissions."
|
66 |
+
elif pr_response.status_code == 404:
|
67 |
+
error_message += " - No panorama found at these coordinates."
|
68 |
+
elif pr_response.status_code == 429:
|
69 |
+
error_message += " - Rate limit exceeded. Try again later."
|
70 |
+
elif pr_response.status_code >= 500:
|
71 |
+
error_message += " - Server error. Try again later."
|
72 |
+
return None
|
73 |
+
|
74 |
+
pr = BytesIO(pr_response.content).getvalue().decode('utf-8')
|
75 |
+
pr = eval(pr[pr.index('\n'):])
|
76 |
+
try:
|
77 |
+
panoid = pr[0][0][0]
|
78 |
+
except:
|
79 |
+
return None
|
80 |
+
|
81 |
+
return panoid
|
82 |
+
|
83 |
+
def download_image(self, lat, lon, zoom=1):
|
84 |
+
"""Download Street View panorama and metadata"""
|
85 |
+
null = None
|
86 |
+
panoid = self.get_image_id(lat, lon)
|
87 |
+
if panoid is None:
|
88 |
+
raise ValueError(f"get_image_id failed() at coordinates: {lat}, {lon}")
|
89 |
+
|
90 |
+
# Get metadata
|
91 |
+
pm_response = requests.get(self.photometa_req.format(panoid))
|
92 |
+
pm = BytesIO(pm_response.content).getvalue().decode('utf-8')
|
93 |
+
pm = eval(pm[pm.index('\n'):])
|
94 |
+
pan_list = pm[1][0][5][0][3][0]
|
95 |
+
|
96 |
+
# Extract relevant info
|
97 |
+
pid = pan_list[0][0][1]
|
98 |
+
plat = pan_list[0][2][0][2]
|
99 |
+
plon = pan_list[0][2][0][3]
|
100 |
+
p_orient = pan_list[0][2][2][0]
|
101 |
+
|
102 |
+
# Download image tiles and assemble panorama
|
103 |
+
img_part_inds = [(x, y) for x in range(2**zoom) for y in range(2**(zoom-1))]
|
104 |
+
img = np.zeros((512*(2**(zoom-1)), 512*(2**zoom), 3), dtype=np.uint8)
|
105 |
+
|
106 |
+
for x, y in img_part_inds:
|
107 |
+
sub_img_response = requests.get(self.panimg_req.format(pid, x, y, zoom))
|
108 |
+
sub_img = np.array(Image.open(BytesIO(sub_img_response.content)))
|
109 |
+
img[512*y:512*(y+1), 512*x:512*(x+1)] = sub_img
|
110 |
+
|
111 |
+
if (img[-1] == 0).all():
|
112 |
+
# raise ValueError("Failed to download complete panorama")
|
113 |
+
print("Failed to download complete panorama")
|
114 |
+
|
115 |
+
return img, pid, plat, plon, p_orient
|
116 |
+
|
117 |
+
|
118 |
+
class PerspectiveExtractor:
|
119 |
+
"""Extracts perspective views from panoramic images"""
|
120 |
+
|
121 |
+
def __init__(self, output_shape=(256, 256), fov=(90, 90)):
|
122 |
+
self.output_shape = output_shape
|
123 |
+
self.fov = fov
|
124 |
+
|
125 |
+
def extract_views(self, pano_img, face_size=512):
|
126 |
+
"""Extract front, back, left, and right views based on orientation"""
|
127 |
+
# orientations = {
|
128 |
+
# "front": (0, p_orient, 0), # Align front with real orientation
|
129 |
+
# "back": (0, p_orient + 180, 0), # Behind
|
130 |
+
# "left": (0, p_orient - 90, 0), # Left side
|
131 |
+
# "right": (0, p_orient + 90, 0), # Right side
|
132 |
+
# }
|
133 |
+
|
134 |
+
# cutouts = {}
|
135 |
+
# for view, rot in orientations.items():
|
136 |
+
# cutout, fov, applied_rot = sample_perspective_img(
|
137 |
+
# pano_img, self.output_shape, fov=self.fov, rot=rot
|
138 |
+
# )
|
139 |
+
# cutouts[view] = cutout
|
140 |
+
|
141 |
+
# return cutouts
|
142 |
+
"""
|
143 |
+
Convert ERP panorama to four cubic faces: Front, Left, Back, Right.
|
144 |
+
Args:
|
145 |
+
erp_img (numpy.ndarray): The input equirectangular image.
|
146 |
+
face_size (int): The size of each cubic face.
|
147 |
+
Returns:
|
148 |
+
dict: A dictionary with the four cube faces.
|
149 |
+
"""
|
150 |
+
# Get ERP dimensions
|
151 |
+
h_erp, w_erp, _ = pano_img.shape
|
152 |
+
# Define cube face directions (yaw, pitch, roll)
|
153 |
+
cube_faces = {
|
154 |
+
"front": (0, 0),
|
155 |
+
"left": (90, 0),
|
156 |
+
"back": (180, 0),
|
157 |
+
"right": (-90, 0),
|
158 |
+
}
|
159 |
+
# Output faces
|
160 |
+
faces = {}
|
161 |
+
# Generate each face
|
162 |
+
for face_name, (yaw, pitch) in cube_faces.items():
|
163 |
+
# Create a perspective transformation matrix
|
164 |
+
fov = 90 # Field of view
|
165 |
+
K = np.array([
|
166 |
+
[face_size / (2 * np.tan(np.radians(fov / 2))), 0, face_size / 2],
|
167 |
+
[0, face_size / (2 * np.tan(np.radians(fov / 2))), face_size / 2],
|
168 |
+
[0, 0, 1]
|
169 |
+
])
|
170 |
+
# Generate 3D world coordinates for the cube face
|
171 |
+
x, y = np.meshgrid(np.linspace(-1, 1, face_size), np.linspace(-1, 1, face_size))
|
172 |
+
z = np.ones_like(x)
|
173 |
+
# Normalize 3D points
|
174 |
+
points_3d = np.stack((x, y, z), axis=-1) # Shape: (H, W, 3)
|
175 |
+
points_3d /= np.linalg.norm(points_3d, axis=-1, keepdims=True)
|
176 |
+
# Apply rotation to align with the cube face
|
177 |
+
yaw_rad, pitch_rad = np.radians(yaw), np.radians(pitch)
|
178 |
+
Ry = np.array([[np.cos(yaw_rad), 0, np.sin(yaw_rad)], [0, 1, 0], [-np.sin(yaw_rad), 0, np.cos(yaw_rad)]])
|
179 |
+
Rx = np.array([[1, 0, 0], [0, np.cos(pitch_rad), -np.sin(pitch_rad)], [0, np.sin(pitch_rad), np.cos(pitch_rad)]])
|
180 |
+
R = Ry @ Rx
|
181 |
+
# Rotate points
|
182 |
+
points_3d_rot = np.einsum('ij,hwj->hwi', R, points_3d)
|
183 |
+
# Convert 3D to spherical coordinates
|
184 |
+
lon = np.arctan2(points_3d_rot[..., 0], points_3d_rot[..., 2])
|
185 |
+
lat = np.arcsin(points_3d_rot[..., 1])
|
186 |
+
# Map spherical coordinates to ERP image coordinates
|
187 |
+
x_erp = (w_erp * (lon / (2 * np.pi) + 0.5)).astype(np.float32)
|
188 |
+
y_erp = (h_erp * (0.5 - lat / np.pi)).astype(np.float32)
|
189 |
+
# Sample pixels from ERP image
|
190 |
+
face_img = cv2.remap(pano_img, x_erp, y_erp, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
|
191 |
+
cv2.rotate(face_img, cv2.ROTATE_180, face_img)
|
192 |
+
faces[face_name] = face_img
|
193 |
+
return faces
|
194 |
+
|
195 |
+
|
196 |
+
class ImageAnalyzer:
|
197 |
+
"""Analyzes images using Vision-Language Models"""
|
198 |
+
|
199 |
+
def __init__(self, model_name="intern_2_5-4B", use_cuda=True):
|
200 |
+
self.model_name = model_name
|
201 |
+
self.use_cuda = use_cuda and torch.cuda.is_available()
|
202 |
+
self.model, self.tokenizer, self.device = self._load_model()
|
203 |
+
|
204 |
+
def _load_model(self):
|
205 |
+
"""Load selected Vision-Language Model"""
|
206 |
+
if self.model_name not in MODEL_LEADERBOARD:
|
207 |
+
raise ValueError(f"Model '{self.model_name}' not found. Choose from: {list(MODEL_LEADERBOARD.keys())}")
|
208 |
+
|
209 |
+
model_path = MODEL_LEADERBOARD[self.model_name]
|
210 |
+
|
211 |
+
# Configure device and parameters
|
212 |
+
if self.use_cuda:
|
213 |
+
device = torch.device("cuda")
|
214 |
+
torch_dtype = torch.bfloat16
|
215 |
+
use_flash_attn = True
|
216 |
+
else:
|
217 |
+
device = torch.device("cpu")
|
218 |
+
torch_dtype = torch.float32
|
219 |
+
use_flash_attn = False
|
220 |
+
|
221 |
+
# Load model and tokenizer
|
222 |
+
model = AutoModel.from_pretrained(
|
223 |
+
model_path,
|
224 |
+
torch_dtype=torch_dtype,
|
225 |
+
load_in_8bit=False,
|
226 |
+
low_cpu_mem_usage=True,
|
227 |
+
use_flash_attn=use_flash_attn,
|
228 |
+
trust_remote_code=True,
|
229 |
+
).eval().to(device)
|
230 |
+
|
231 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
232 |
+
model_path,
|
233 |
+
trust_remote_code=True,
|
234 |
+
use_fast=False
|
235 |
+
)
|
236 |
+
|
237 |
+
return model, tokenizer, device
|
238 |
+
|
239 |
+
def _build_transform(self, input_size=448):
|
240 |
+
"""Create image transformation pipeline"""
|
241 |
+
transform = T.Compose([
|
242 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
243 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
244 |
+
T.ToTensor(),
|
245 |
+
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
246 |
+
])
|
247 |
+
return transform
|
248 |
+
|
249 |
+
def _find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
250 |
+
"""Find closest aspect ratio for image tiling"""
|
251 |
+
best_ratio_diff = float('inf')
|
252 |
+
best_ratio = (1, 1)
|
253 |
+
area = width * height
|
254 |
+
for ratio in target_ratios:
|
255 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
256 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
257 |
+
if ratio_diff < best_ratio_diff:
|
258 |
+
best_ratio_diff = ratio_diff
|
259 |
+
best_ratio = ratio
|
260 |
+
elif ratio_diff == best_ratio_diff:
|
261 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
262 |
+
best_ratio = ratio
|
263 |
+
return best_ratio
|
264 |
+
|
265 |
+
def _preprocess_image(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
266 |
+
"""Preprocess image for model input"""
|
267 |
+
orig_width, orig_height = image.size
|
268 |
+
aspect_ratio = orig_width / orig_height
|
269 |
+
|
270 |
+
# Calculate possible image aspect ratios
|
271 |
+
target_ratios = set(
|
272 |
+
(i, j) for n in range(min_num, max_num + 1)
|
273 |
+
for i in range(1, n + 1)
|
274 |
+
for j in range(1, n + 1)
|
275 |
+
if i * j <= max_num and i * j >= min_num
|
276 |
+
)
|
277 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
278 |
+
|
279 |
+
# Find closest aspect ratio
|
280 |
+
target_aspect_ratio = self._find_closest_aspect_ratio(
|
281 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size
|
282 |
+
)
|
283 |
+
|
284 |
+
# Calculate target dimensions
|
285 |
+
target_width = image_size * target_aspect_ratio[0]
|
286 |
+
target_height = image_size * target_aspect_ratio[1]
|
287 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
288 |
+
|
289 |
+
# Resize and split image
|
290 |
+
resized_img = image.resize((target_width, target_height))
|
291 |
+
processed_images = []
|
292 |
+
for i in range(blocks):
|
293 |
+
box = (
|
294 |
+
(i % (target_width // image_size)) * image_size,
|
295 |
+
(i // (target_width // image_size)) * image_size,
|
296 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
297 |
+
((i // (target_width // image_size)) + 1) * image_size
|
298 |
+
)
|
299 |
+
split_img = resized_img.crop(box)
|
300 |
+
processed_images.append(split_img)
|
301 |
+
|
302 |
+
assert len(processed_images) == blocks
|
303 |
+
if use_thumbnail and len(processed_images) != 1:
|
304 |
+
thumbnail_img = image.resize((image_size, image_size))
|
305 |
+
processed_images.append(thumbnail_img)
|
306 |
+
|
307 |
+
return processed_images
|
308 |
+
|
309 |
+
def load_image(self, image_path, input_size=448, max_num=12):
|
310 |
+
"""Load and process image for analysis"""
|
311 |
+
image = Image.open(image_path).convert('RGB')
|
312 |
+
transform = self._build_transform(input_size)
|
313 |
+
images = self._preprocess_image(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
314 |
+
pixel_values = [transform(image) for image in images]
|
315 |
+
pixel_values = torch.stack(pixel_values)
|
316 |
+
return pixel_values
|
317 |
+
|
318 |
+
def analyze_image(self, image_path, max_num=12):
|
319 |
+
"""Analyze image for expected sounds"""
|
320 |
+
# Load and process image
|
321 |
+
pixel_values = self.load_image(image_path, max_num=max_num)
|
322 |
+
|
323 |
+
# Move to device with appropriate dtype
|
324 |
+
if self.device.type == "cuda":
|
325 |
+
pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
|
326 |
+
else:
|
327 |
+
pixel_values = pixel_values.to(torch.float32).to(self.device)
|
328 |
+
|
329 |
+
# Create sound-focused query
|
330 |
+
query = soundscape_query
|
331 |
+
|
332 |
+
# Generate response
|
333 |
+
generation_config = dict(max_new_tokens=1024, do_sample=True)
|
334 |
+
response = self.model.chat(self.tokenizer, pixel_values, query, generation_config)
|
335 |
+
|
336 |
+
return response
|
337 |
+
|
338 |
+
|
339 |
+
class StreetSoundTextPipeline:
|
340 |
+
"""Complete pipeline for Street View sound analysis"""
|
341 |
+
|
342 |
+
def __init__(self, log_dir="logs", model_name="intern_2_5-4B", use_cuda=True):
|
343 |
+
# Create log directory if it doesn't exist
|
344 |
+
self.log_dir = log_dir
|
345 |
+
os.makedirs(log_dir, exist_ok=True)
|
346 |
+
|
347 |
+
# Initialize components
|
348 |
+
self.downloader = StreetViewDownloader()
|
349 |
+
self.extractor = PerspectiveExtractor()
|
350 |
+
# self.analyzer = ImageAnalyzer(model_name=model_name, use_cuda=use_cuda)
|
351 |
+
self.analyzer = None
|
352 |
+
self.model_name = model_name
|
353 |
+
self.use_cuda = use_cuda
|
354 |
+
|
355 |
+
def _load_analyzer(self):
|
356 |
+
if self.analyzer is None:
|
357 |
+
self.analyzer = ImageAnalyzer(model_name=self.model_name, use_cuda=self.use_cuda)
|
358 |
+
|
359 |
+
def _unload_analyzer(self):
|
360 |
+
if self.analyzer is not None:
|
361 |
+
if hasattr(self.analyzer, 'model') and self.analyzer.model is not None:
|
362 |
+
self.analyzer.model = self.analyzer.model.to("cpu")
|
363 |
+
del self.analyzer.model
|
364 |
+
self.analyzer.model = None
|
365 |
+
torch.cuda.empty_cache()
|
366 |
+
self.analyzer = None
|
367 |
+
|
368 |
+
def process(self, lat, lon, view, panoramic=False):
|
369 |
+
"""
|
370 |
+
Process a location to generate sound description for specified view or all views
|
371 |
+
|
372 |
+
Args:
|
373 |
+
lat (float): Latitude
|
374 |
+
lon (float): Longitude
|
375 |
+
view (str): Perspective view ('front', 'back', 'left', 'right')
|
376 |
+
panoramic (bool): If True, process all views instead of just the specified one
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
dict or list: Results including panorama info and sound description(s)
|
380 |
+
"""
|
381 |
+
if view not in ["front", "back", "left", "right"]:
|
382 |
+
raise ValueError(f"Invalid view: {view}. Choose from: front, back, left, right")
|
383 |
+
|
384 |
+
# Step 1: Download panoramic image
|
385 |
+
print(f"Downloading Street View panorama for coordinates: {lat}, {lon}")
|
386 |
+
|
387 |
+
pano_path = os.path.join(self.log_dir, "panorama.jpg")
|
388 |
+
pano_img, pid, plat, plon, p_orient = self.downloader.download_image(lat, lon)
|
389 |
+
Image.fromarray(pano_img).save(pano_path)
|
390 |
+
|
391 |
+
# Step 2: Extract perspective views
|
392 |
+
print(f"Extracting perspective views with orientation: {p_orient}°")
|
393 |
+
cutouts = self.extractor.extract_views(pano_img, 512)
|
394 |
+
|
395 |
+
# Save all views
|
396 |
+
for v, img in cutouts.items():
|
397 |
+
view_path = os.path.join(self.log_dir, f"{v}.jpg")
|
398 |
+
Image.fromarray(img).save(view_path)
|
399 |
+
|
400 |
+
self._load_analyzer()
|
401 |
+
print("\n[DEBUG] Current soundscape query:")
|
402 |
+
print(soundscape_query)
|
403 |
+
print("-" * 50)
|
404 |
+
if panoramic:
|
405 |
+
# Process all views
|
406 |
+
print(f"Analyzing all views for sound information")
|
407 |
+
results = []
|
408 |
+
|
409 |
+
for current_view in ["front", "back", "left", "right"]:
|
410 |
+
view_path = os.path.join(self.log_dir, f"{current_view}.jpg")
|
411 |
+
sound_description = self.analyzer.analyze_image(view_path)
|
412 |
+
|
413 |
+
view_result = {
|
414 |
+
"panorama_id": pid,
|
415 |
+
"coordinates": {"lat": plat, "lon": plon},
|
416 |
+
"orientation": p_orient,
|
417 |
+
"view": current_view,
|
418 |
+
"sound_description": sound_description,
|
419 |
+
"files": {
|
420 |
+
"panorama": pano_path,
|
421 |
+
"view_path": view_path
|
422 |
+
}
|
423 |
+
}
|
424 |
+
results.append(view_result)
|
425 |
+
|
426 |
+
self._unload_analyzer()
|
427 |
+
return results
|
428 |
+
else:
|
429 |
+
# Process only the selected view
|
430 |
+
view_path = os.path.join(self.log_dir, f"{view}.jpg")
|
431 |
+
print(f"Analyzing {view} view for sound information")
|
432 |
+
sound_description = self.analyzer.analyze_image(view_path)
|
433 |
+
|
434 |
+
self._unload_analyzer()
|
435 |
+
|
436 |
+
# Prepare results
|
437 |
+
results = {
|
438 |
+
"panorama_id": pid,
|
439 |
+
"coordinates": {"lat": plat, "lon": plon},
|
440 |
+
"orientation": p_orient,
|
441 |
+
"view": view,
|
442 |
+
"sound_description": sound_description,
|
443 |
+
"files": {
|
444 |
+
"panorama": pano_path,
|
445 |
+
"views": {v: os.path.join(self.log_dir, f"{v}.jpg") for v in cutouts.keys()}
|
446 |
+
}
|
447 |
+
}
|
448 |
+
|
449 |
+
return results
|
450 |
+
|
451 |
+
|
452 |
+
def parse_location(location_str):
|
453 |
+
"""Parse location string in format 'lat,lon' into float tuple"""
|
454 |
+
try:
|
455 |
+
lat, lon = map(float, location_str.split(','))
|
456 |
+
return lat, lon
|
457 |
+
except ValueError:
|
458 |
+
raise argparse.ArgumentTypeError("Location must be in format 'latitude,longitude'")
|
459 |
+
|
460 |
+
|
461 |
+
def generate_caption(lat, lon, view="front", model="intern_2_5-4B", cpu_only=False, panoramic=False):
|
462 |
+
"""
|
463 |
+
Generate sound captions for one or all views of a street view location
|
464 |
+
|
465 |
+
Args:
|
466 |
+
lat (float/str): Latitude
|
467 |
+
lon (float/str): Longitude
|
468 |
+
view (str): Perspective view ('front', 'back', 'left', 'right')
|
469 |
+
model (str): Model name to use for analysis
|
470 |
+
cpu_only (bool): Whether to force CPU usage
|
471 |
+
panoramic (bool): If True, process all views instead of just the specified one
|
472 |
+
|
473 |
+
Returns:
|
474 |
+
dict or list: Results with sound descriptions
|
475 |
+
"""
|
476 |
+
pipeline = StreetSoundTextPipeline(
|
477 |
+
log_dir=log_dir,
|
478 |
+
model_name=model,
|
479 |
+
use_cuda=not cpu_only
|
480 |
+
)
|
481 |
+
|
482 |
+
try:
|
483 |
+
results = pipeline.process(lat, lon, view, panoramic=panoramic)
|
484 |
+
|
485 |
+
if panoramic:
|
486 |
+
# Process results for all views
|
487 |
+
print(f"Generated captions for all views at location: {lat}, {lon}")
|
488 |
+
else:
|
489 |
+
print(f"Generated caption for {view} view at location: {lat}, {lon}")
|
490 |
+
|
491 |
+
return results
|
492 |
+
except Exception as e:
|
493 |
+
print(f"Error: {str(e)}")
|
494 |
+
return None
|
README.md
CHANGED
@@ -1,13 +1,50 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**A training-free pipeline utilizing pre-trained generative models to synthesize sound for any street on Earth with available Street View panoramic images.**
|
2 |
+
|
3 |
+
1. Change to this directory:
|
4 |
+
```
|
5 |
+
cd SoundingStreet
|
6 |
+
```
|
7 |
+
|
8 |
+
2. Create the conda environment:
|
9 |
+
```
|
10 |
+
conda env create -f environment.yml
|
11 |
+
conda activate geosynthsound
|
12 |
+
```
|
13 |
+
|
14 |
+
3. Make sure to create necessary directories:
|
15 |
+
```
|
16 |
+
mkdir -p logs output
|
17 |
+
```
|
18 |
+
|
19 |
+
4. Download checkpoint for depth estimator model:
|
20 |
+
```
|
21 |
+
wget https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt -P external_models/depth-fm/checkpoints/
|
22 |
+
```
|
23 |
+
|
24 |
+
5. Run the `SoundingStreet` demo:
|
25 |
+
```
|
26 |
+
python main.py --panoramic --location "52.3436723,4.8529625"
|
27 |
+
```
|
28 |
+
Intermediate files such as the downloaded panoramic image and perspective cut-outs can be found in `./logs/`, and output audios for each view as well as the composite audio for the location are saved as `./output/panoramic_composition.wav`
|
29 |
+
|
30 |
+
|
31 |
+
## Acknowledgements
|
32 |
+
|
33 |
+
- **InternVL2.5-8B-MPO**
|
34 |
+
For vision-language modeling, we employ InternVL2.5-8B-MPO, which is released under the MIT License.
|
35 |
+
GitHub: https://github.com/OpenGVLab/InternVL
|
36 |
+
|
37 |
+
- **Grounding DINO**
|
38 |
+
We use Grounding DINO for open-set object detection. Grounding DINO is released under the Apache 2.0 License.
|
39 |
+
GitHub: https://github.com/IDEA-Research/GroundingDINO
|
40 |
+
|
41 |
+
- **DepthFM**
|
42 |
+
We utilize the DepthFM model for monocular depth estimation. DepthFM is released under the MIT License.
|
43 |
+
GitHub: https://github.com/CompVis/depth-fm
|
44 |
+
|
45 |
+
- **TangoFlux**
|
46 |
+
We incorporate TangoFlux for text-to-audio generation. TangoFlux is available for non-commercial research use only and is subject to the Stability AI Community License, WavCaps license, and the original licenses of the datasets used in training.
|
47 |
+
GitHub: https://github.com/declare-lab/TangoFlux
|
48 |
+
|
49 |
+
|
50 |
+
Our repository's license and usage terms adhere to the respective licenses of these models.
|
SoundMapper.py
ADDED
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from DepthEstimator import DepthEstimator
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import os
|
5 |
+
from GenerateCaptions import generate_caption
|
6 |
+
import re
|
7 |
+
from config import LOGS_DIR
|
8 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
9 |
+
import torch
|
10 |
+
from PIL import Image, ImageDraw, ImageFont
|
11 |
+
import spacy
|
12 |
+
import gc
|
13 |
+
|
14 |
+
class SoundMapper:
|
15 |
+
def __init__(self):
|
16 |
+
self.depth_estimator = DepthEstimator()
|
17 |
+
# List of depth maps in dict["predicted_depth" ,"depth"] in (tensor, PIL.Image) format
|
18 |
+
self.device = "cuda"
|
19 |
+
# self.map_list = self.depth_estimator.estimate_depth(self.depth_estimator.image_dir)
|
20 |
+
self.map_list = None
|
21 |
+
self.image_dir = self.depth_estimator.image_dir
|
22 |
+
# self.nlp = spacy.load("en_core_web_sm")
|
23 |
+
self.nlp = None
|
24 |
+
self.dino = None
|
25 |
+
self.dino_processor = None
|
26 |
+
# self.dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(self.device)
|
27 |
+
# self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
|
28 |
+
|
29 |
+
def _load_nlp(self):
|
30 |
+
if self.nlp is None:
|
31 |
+
self.nlp = spacy.load("en_core_web_sm")
|
32 |
+
return self.nlp
|
33 |
+
|
34 |
+
def _load_depth_maps(self):
|
35 |
+
if self.map_list is None:
|
36 |
+
self.map_list = self.depth_estimator.estimate_depth(self.depth_estimator.image_dir)
|
37 |
+
return self.map_list
|
38 |
+
|
39 |
+
def process_depth_maps(self) -> list:
|
40 |
+
depth_maps = self._load_depth_maps()
|
41 |
+
processed_maps = []
|
42 |
+
for item in depth_maps:
|
43 |
+
depth_map = item["depth"]
|
44 |
+
depth_array = np.array(depth_map)
|
45 |
+
normalization = depth_array / 255.0
|
46 |
+
processed_maps.append({
|
47 |
+
"original": depth_map,
|
48 |
+
"normalization": normalization
|
49 |
+
})
|
50 |
+
return processed_maps
|
51 |
+
|
52 |
+
# def create_depth_zone(self, processed_maps : list, num_zones = 3):
|
53 |
+
# zones_data = []
|
54 |
+
# for depth_data in processed_maps:
|
55 |
+
# normalized = depth_data["normalization"]
|
56 |
+
# thresholds = np.linspace(0, 1, num_zones+1)
|
57 |
+
# zones = []
|
58 |
+
# for i in range(num_zones):
|
59 |
+
# zone_mask = (normalized >= thresholds[i]) & (normalized < thresholds[i+1])
|
60 |
+
# zone_percentage = zone_mask.sum() / zone_mask.size
|
61 |
+
# zones.append({
|
62 |
+
# "range": (thresholds[i], thresholds[i+1]),
|
63 |
+
# "percentage": zone_percentage,
|
64 |
+
# "mask": zone_mask
|
65 |
+
# })
|
66 |
+
# zones_data.append(zones)
|
67 |
+
# return zones_data
|
68 |
+
|
69 |
+
def detect_sound_sources(self, caption_text: str) -> dict:
|
70 |
+
"""
|
71 |
+
Extract nouns and their sound descriptions from caption text.
|
72 |
+
Returns a dictionary mapping nouns to their descriptions.
|
73 |
+
"""
|
74 |
+
sound_sources = {}
|
75 |
+
nlp = self._load_nlp()
|
76 |
+
|
77 |
+
print(f"\n[DEBUG] Beginning sound source detection")
|
78 |
+
print(f"Raw caption text length: {len(caption_text)}")
|
79 |
+
print(f"First 100 chars: {caption_text[:100]}...")
|
80 |
+
|
81 |
+
# Split the caption by newlines to separate entries
|
82 |
+
lines = caption_text.strip().split('\n')
|
83 |
+
print(f"Found {len(lines)} lines after splitting")
|
84 |
+
|
85 |
+
for i, line in enumerate(lines):
|
86 |
+
# Skip empty lines
|
87 |
+
if not line.strip():
|
88 |
+
continue
|
89 |
+
|
90 |
+
print(f"Processing line {i}: {line[:50]}{'...' if len(line) > 50 else ''}")
|
91 |
+
|
92 |
+
# Check if line matches the expected format (Noun: description)
|
93 |
+
if ':' in line:
|
94 |
+
parts = line.split(':', 1) # Split only on the first colon
|
95 |
+
|
96 |
+
# Clean up the noun part - remove numbers and leading/trailing whitespace
|
97 |
+
noun_part = parts[0].strip().lower()
|
98 |
+
# Remove list numbering (e.g., "1. ", "2. ", etc.)
|
99 |
+
noun_part = re.sub(r'^\d+\.\s*', '', noun_part)
|
100 |
+
|
101 |
+
description = parts[1].strip()
|
102 |
+
|
103 |
+
# Clean any markdown formatting
|
104 |
+
noun = re.sub(r'[*()]', '', noun_part).strip()
|
105 |
+
description = re.sub(r'[*()]', '', description).strip()
|
106 |
+
|
107 |
+
# Separate the description at em dash if present
|
108 |
+
if ' — ' in description:
|
109 |
+
description = description.split(' — ', 1)[0].strip()
|
110 |
+
elif ' - ' in description:
|
111 |
+
description = description.split(' - ', 1)[0].strip()
|
112 |
+
|
113 |
+
print(f" - Found potential noun: '{noun}' with description: '{description[:30]}...'")
|
114 |
+
|
115 |
+
# Skip if noun contains invalid characters or is too short
|
116 |
+
if '##' not in noun and len(noun) > 1 and noun[0].isalpha():
|
117 |
+
sound_sources[noun] = description
|
118 |
+
print(f" √ Added to sound sources")
|
119 |
+
else:
|
120 |
+
print(f" × Skipped (invalid format)")
|
121 |
+
|
122 |
+
# If no structured format found, try to extract nouns from the text
|
123 |
+
if not sound_sources:
|
124 |
+
print("No structured format found, falling back to noun extraction")
|
125 |
+
all_nouns = []
|
126 |
+
doc = nlp(caption_text)
|
127 |
+
for token in doc:
|
128 |
+
if token.pos_ == "NOUN" and len(token.text) > 1:
|
129 |
+
if token.text[0].isalpha():
|
130 |
+
all_nouns.append(token.text.lower())
|
131 |
+
print(f" - Extracted noun: '{token.text.lower()}'")
|
132 |
+
|
133 |
+
for noun in all_nouns:
|
134 |
+
sound_sources[noun] = "" # Empty description
|
135 |
+
|
136 |
+
print(f"[DEBUG] Final detected sound sources: {list(sound_sources.keys())}")
|
137 |
+
return sound_sources
|
138 |
+
|
139 |
+
def map_bbox_to_depth_zone(self, bbox, depth_map, num_zones=3):
|
140 |
+
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
141 |
+
|
142 |
+
height, width = depth_map.shape
|
143 |
+
x1, y1 = max(0, x1), max(0, y1)
|
144 |
+
x2, y2 = min(width, x2), min(height, y2)
|
145 |
+
|
146 |
+
depth_roi = depth_map[y1:y2, x1:x2]
|
147 |
+
|
148 |
+
if depth_roi.size == 0:
|
149 |
+
return num_zones - 1
|
150 |
+
|
151 |
+
mean_depth = np.mean(depth_roi)
|
152 |
+
|
153 |
+
thresholds = self.create_histogram_depth_zones(depth_map, num_zones)
|
154 |
+
for i in range(num_zones):
|
155 |
+
if thresholds[i] <= mean_depth < thresholds[i+1]:
|
156 |
+
return i
|
157 |
+
return num_zones - 1
|
158 |
+
|
159 |
+
def detect_objects(self, nouns : list, image: Image):
|
160 |
+
filtered_nouns = []
|
161 |
+
for noun in nouns:
|
162 |
+
if '##' not in noun and len(noun) > 1 and noun[0].isalpha():
|
163 |
+
filtered_nouns.append(noun)
|
164 |
+
|
165 |
+
print(f"Detecting objects for nouns: {filtered_nouns}")
|
166 |
+
|
167 |
+
if self.dino is None:
|
168 |
+
self.dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(self.device)
|
169 |
+
self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
|
170 |
+
else:
|
171 |
+
self.dino = self.dino.to(self.device)
|
172 |
+
|
173 |
+
text_prompt = " . ".join(filtered_nouns)
|
174 |
+
inputs = self.dino_processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
|
175 |
+
|
176 |
+
with torch.no_grad():
|
177 |
+
outputs = self.dino(**inputs)
|
178 |
+
results = self.dino_processor.post_process_grounded_object_detection(
|
179 |
+
outputs,
|
180 |
+
inputs.input_ids,
|
181 |
+
box_threshold=0.25,
|
182 |
+
text_threshold=0.25,
|
183 |
+
target_sizes=[image.size[::-1]]
|
184 |
+
)
|
185 |
+
|
186 |
+
result = results[0]
|
187 |
+
labels = result["labels"]
|
188 |
+
bboxes = result["boxes"]
|
189 |
+
|
190 |
+
clean_labels = []
|
191 |
+
for label in labels:
|
192 |
+
clean_label = re.sub(r'##\w+', '', label)
|
193 |
+
clean_label = self._split_combined_words(clean_label, filtered_nouns)
|
194 |
+
clean_labels.append(clean_label)
|
195 |
+
|
196 |
+
self.dino = self.dino.to("cpu")
|
197 |
+
torch.cuda.empty_cache()
|
198 |
+
del inputs, outputs, results
|
199 |
+
|
200 |
+
print(f"Detected objects: {clean_labels}")
|
201 |
+
|
202 |
+
return (clean_labels, bboxes)
|
203 |
+
|
204 |
+
def _split_combined_words(self, text, nouns=None):
|
205 |
+
nlp = self._load_nlp()
|
206 |
+
if nouns is None:
|
207 |
+
known_words = set()
|
208 |
+
doc = nlp(text)
|
209 |
+
for token in doc:
|
210 |
+
if token.pos_ == "NOUN" and len(token.text) > 1:
|
211 |
+
known_words.add(token.text.lower())
|
212 |
+
else:
|
213 |
+
known_words = set(nouns)
|
214 |
+
|
215 |
+
result = []
|
216 |
+
for word in text.split():
|
217 |
+
if word in known_words:
|
218 |
+
result.append(word)
|
219 |
+
continue
|
220 |
+
|
221 |
+
found = False
|
222 |
+
for known in known_words:
|
223 |
+
if known in word and len(known) > 2:
|
224 |
+
result.append(known)
|
225 |
+
found = True
|
226 |
+
|
227 |
+
if not found:
|
228 |
+
result.append(word)
|
229 |
+
|
230 |
+
return " ".join(result)
|
231 |
+
|
232 |
+
def process_dino_labels(self, labels):
|
233 |
+
processed_labels = []
|
234 |
+
nlp = self._load_nlp()
|
235 |
+
|
236 |
+
for label in labels:
|
237 |
+
if label.startswith('##'):
|
238 |
+
continue
|
239 |
+
label = re.sub(r'[*()]', '', label).strip()
|
240 |
+
|
241 |
+
parts = label.split()
|
242 |
+
for part in parts:
|
243 |
+
if part.startswith('##'):
|
244 |
+
continue
|
245 |
+
doc = nlp(part)
|
246 |
+
for token in doc:
|
247 |
+
if token.pos_ == "NOUN" and len(token.text) > 1:
|
248 |
+
processed_labels.append(token.text.lower())
|
249 |
+
|
250 |
+
unique_labels = []
|
251 |
+
for label in processed_labels:
|
252 |
+
if label not in unique_labels:
|
253 |
+
unique_labels.append(label)
|
254 |
+
|
255 |
+
return unique_labels
|
256 |
+
|
257 |
+
|
258 |
+
def create_histogram_depth_zones(self, depth_map, num_zones = 3):
|
259 |
+
# using 50 bins because it is faster
|
260 |
+
hist, bin_edge = np.histogram(depth_map.flatten(), bins=50, range=(0, 1))
|
261 |
+
cumulative = np.cumsum(hist) / np.sum(hist)
|
262 |
+
thresholds = [0.0]
|
263 |
+
for i in range(1, num_zones):
|
264 |
+
target = i / num_zones
|
265 |
+
idx = np.argmin(np.abs(cumulative - target))
|
266 |
+
thresholds.append(bin_edge[idx + 1])
|
267 |
+
thresholds.append(1.0)
|
268 |
+
|
269 |
+
return thresholds
|
270 |
+
|
271 |
+
|
272 |
+
def analyze_object_depths(self, image_path, depth_map, lat, lon, caption_data=None, all_objects=False):
|
273 |
+
image = Image.open(image_path)
|
274 |
+
|
275 |
+
if caption_data is None:
|
276 |
+
caption = generate_caption(lat, lon)
|
277 |
+
if not caption:
|
278 |
+
print(f"Failed to generate caption for {image_path}")
|
279 |
+
return []
|
280 |
+
caption_text = caption.get("sound_description", "")
|
281 |
+
else:
|
282 |
+
caption_text = caption_data.get("sound_description", "")
|
283 |
+
|
284 |
+
# Debug: Print the raw caption text
|
285 |
+
print(f"\n[DEBUG] Raw caption text for {os.path.basename(image_path)}:")
|
286 |
+
print(caption_text)
|
287 |
+
print("-" * 50)
|
288 |
+
|
289 |
+
if not caption_text:
|
290 |
+
print(f"No caption text available for {image_path}")
|
291 |
+
return []
|
292 |
+
|
293 |
+
# Extract nouns and their sound descriptions
|
294 |
+
sound_sources = self.detect_sound_sources(caption_text)
|
295 |
+
|
296 |
+
# Debug: Print the extracted sound sources
|
297 |
+
print(f"[DEBUG] Extracted sound sources:")
|
298 |
+
for noun, desc in sound_sources.items():
|
299 |
+
print(f" - {noun}: {desc}")
|
300 |
+
print("-" * 50)
|
301 |
+
|
302 |
+
if not sound_sources:
|
303 |
+
print(f"No sound sources detected in caption for {image_path}")
|
304 |
+
return []
|
305 |
+
|
306 |
+
# Get list of nouns only for object detection
|
307 |
+
nouns = list(sound_sources.keys())
|
308 |
+
|
309 |
+
# Debug: Print the list of nouns being used for detection
|
310 |
+
print(f"[DEBUG] Nouns for object detection: {nouns}")
|
311 |
+
print("-" * 50)
|
312 |
+
|
313 |
+
labels, bboxes = self.detect_objects(nouns, image)
|
314 |
+
if len(labels) == 0 or len(bboxes) == 0:
|
315 |
+
print(f"No objects detected in {image_path}")
|
316 |
+
return []
|
317 |
+
|
318 |
+
object_data = []
|
319 |
+
known_objects = set(nouns) if nouns else set()
|
320 |
+
|
321 |
+
for i, (label, bbox) in enumerate(zip(labels, bboxes)):
|
322 |
+
if '##' in label:
|
323 |
+
continue
|
324 |
+
|
325 |
+
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
326 |
+
height, width = depth_map.shape
|
327 |
+
x1, y1 = max(0, x1), max(0, y1)
|
328 |
+
x2, y2 = min(width, x2), min(height, y2)
|
329 |
+
|
330 |
+
depth_roi = depth_map[y1:y2, x1:x2]
|
331 |
+
if depth_roi.size == 0:
|
332 |
+
continue
|
333 |
+
|
334 |
+
mean_depth = np.mean(depth_roi)
|
335 |
+
|
336 |
+
matched_noun = None
|
337 |
+
matched_desc = None
|
338 |
+
|
339 |
+
for word in label.split():
|
340 |
+
word = word.lower()
|
341 |
+
if word in sound_sources:
|
342 |
+
matched_noun = word
|
343 |
+
matched_desc = sound_sources[word]
|
344 |
+
break
|
345 |
+
if matched_noun is None:
|
346 |
+
for noun in sound_sources:
|
347 |
+
if noun in label.lower():
|
348 |
+
matched_noun = noun
|
349 |
+
matched_desc = sound_sources[noun]
|
350 |
+
break
|
351 |
+
if matched_noun is None:
|
352 |
+
for word in label.split():
|
353 |
+
if len(word) > 1 and word[0].isalpha() and '##' not in word:
|
354 |
+
matched_noun = word.lower()
|
355 |
+
matched_desc = "" # No description available
|
356 |
+
break
|
357 |
+
|
358 |
+
if matched_noun:
|
359 |
+
thresholds = self.create_histogram_depth_zones(depth_map, num_zones=3)
|
360 |
+
zone = 0 # The default is 0 which is the closest zone
|
361 |
+
for i in range(3):
|
362 |
+
if thresholds[i] <= mean_depth < thresholds[i+1]:
|
363 |
+
zone = i
|
364 |
+
break
|
365 |
+
|
366 |
+
object_data.append({
|
367 |
+
"original_label": matched_noun,
|
368 |
+
"bbox": bbox.tolist(),
|
369 |
+
"depth_zone": zone,
|
370 |
+
"zone_description": ["near", "medium", "far"][zone],
|
371 |
+
"mean_depth": mean_depth,
|
372 |
+
"weight": 1.0 - mean_depth,
|
373 |
+
"sound_description": matched_desc
|
374 |
+
})
|
375 |
+
if all_objects:
|
376 |
+
object_data.sort(key=lambda x: x["mean_depth"])
|
377 |
+
return object_data
|
378 |
+
else:
|
379 |
+
if not object_data:
|
380 |
+
return []
|
381 |
+
closest_object = min(object_data, key=lambda x: x["mean_depth"])
|
382 |
+
return [closest_object]
|
383 |
+
|
384 |
+
def cleanup(self):
|
385 |
+
if hasattr(self, 'depth_estimator') and self.depth_estimator is not None:
|
386 |
+
del self.depth_estimator
|
387 |
+
self.depth_estimator = None
|
388 |
+
|
389 |
+
if self.map_list is not None:
|
390 |
+
del self.map_list
|
391 |
+
self.map_list = None
|
392 |
+
|
393 |
+
if self.dino is not None:
|
394 |
+
self.dino = self.dino.to("cpu")
|
395 |
+
del self.dino
|
396 |
+
self.dino = None
|
397 |
+
del self.dino_processor
|
398 |
+
self.dino_processor = None
|
399 |
+
|
400 |
+
if self.nlp is not None:
|
401 |
+
del self.nlp
|
402 |
+
self.nlp = None
|
403 |
+
torch.cuda.empty_cache()
|
404 |
+
gc.collect()
|
405 |
+
|
406 |
+
def test_object_depth_analysis(self):
|
407 |
+
"""
|
408 |
+
Test the object depth analysis on all images in the directory.
|
409 |
+
"""
|
410 |
+
# Process depth maps first
|
411 |
+
processed_maps = self.process_depth_maps()
|
412 |
+
|
413 |
+
# Get list of original image paths
|
414 |
+
image_dir = self.depth_estimator.image_dir
|
415 |
+
image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".jpg")]
|
416 |
+
|
417 |
+
results = []
|
418 |
+
|
419 |
+
# For each image and its corresponding depth map
|
420 |
+
for i, (image_path, processed_map) in enumerate(zip(image_paths, processed_maps)):
|
421 |
+
# Extract the normalized depth map
|
422 |
+
depth_map = processed_map["normalization"]
|
423 |
+
|
424 |
+
# Analyze objects and their depths
|
425 |
+
object_depths = self.analyze_object_depths(image_path, depth_map)
|
426 |
+
|
427 |
+
# Store results
|
428 |
+
results.append({
|
429 |
+
"image_path": image_path,
|
430 |
+
"object_depths": object_depths
|
431 |
+
})
|
432 |
+
|
433 |
+
# Print some information for debugging
|
434 |
+
print(f"Analyzed {image_path}:")
|
435 |
+
for obj in object_depths:
|
436 |
+
print(f" - {obj['original_label']} (Zone: {obj['zone_description']})")
|
437 |
+
|
438 |
+
return results
|
app.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
|
9 |
+
from config import LOGS_DIR, OUTPUT_DIR
|
10 |
+
from SoundMapper import SoundMapper
|
11 |
+
from GenerateAudio import GenerateAudio
|
12 |
+
from GenerateCaptions import generate_caption
|
13 |
+
from audio_mixer import compose_audio
|
14 |
+
|
15 |
+
# Ensure required directories exist
|
16 |
+
os.makedirs(LOGS_DIR, exist_ok=True)
|
17 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
18 |
+
# Prepare external model dir and download checkpoint if missing
|
19 |
+
from pathlib import Path
|
20 |
+
depthfm_ckpt = Path('external_models/depth-fm/checkpoints/depthfm-v1.ckpt')
|
21 |
+
if not depthfm_ckpt.exists():
|
22 |
+
depthfm_ckpt.parent.mkdir(parents=True, exist_ok=True)
|
23 |
+
os.system('wget https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt -P external_models/depth-fm/checkpoints/')
|
24 |
+
|
25 |
+
|
26 |
+
# Clear CUDA cache between runs
|
27 |
+
def clear_cuda():
|
28 |
+
torch.cuda.empty_cache()
|
29 |
+
gc.collect()
|
30 |
+
|
31 |
+
|
32 |
+
def process_images(
|
33 |
+
image_dir: str,
|
34 |
+
output_dir: str,
|
35 |
+
panoramic: bool,
|
36 |
+
view: str,
|
37 |
+
model: str,
|
38 |
+
location: str,
|
39 |
+
audio_duration: int,
|
40 |
+
cpu_only: bool
|
41 |
+
) -> None:
|
42 |
+
# Existing processing logic, generates files in OUTPUT_DIR
|
43 |
+
lat, lon = location.split(",")
|
44 |
+
os.makedirs(output_dir, exist_ok=True)
|
45 |
+
sound_mapper = SoundMapper()
|
46 |
+
audio_generator = GenerateAudio()
|
47 |
+
|
48 |
+
if panoramic:
|
49 |
+
# Panoramic: generate per-view audio then composition
|
50 |
+
view_results = generate_caption(lat, lon, view=view, model=model,
|
51 |
+
cpu_only=cpu_only, panoramic=True)
|
52 |
+
processed_maps = sound_mapper.process_depth_maps()
|
53 |
+
image_paths = sorted(Path(image_dir).glob("*.jpg"))
|
54 |
+
audios = {}
|
55 |
+
for vr in view_results:
|
56 |
+
cv = vr["view"]
|
57 |
+
img_file = Path(image_dir) / f"{cv}.jpg"
|
58 |
+
if not img_file.exists():
|
59 |
+
continue
|
60 |
+
idx = [i for i, p in enumerate(image_paths) if p.name == img_file.name]
|
61 |
+
if not idx:
|
62 |
+
continue
|
63 |
+
depth_map = processed_maps[idx[0]]["normalization"]
|
64 |
+
obj_depths = sound_mapper.analyze_object_depths(
|
65 |
+
str(img_file), depth_map, lat, lon,
|
66 |
+
caption_data=vr, all_objects=False
|
67 |
+
)
|
68 |
+
if not obj_depths:
|
69 |
+
continue
|
70 |
+
out_wav = Path(output_dir) / f"sound_{cv}.wav"
|
71 |
+
audio, sr = audio_generator.process_and_generate_audio(
|
72 |
+
obj_depths, duration=audio_duration
|
73 |
+
)
|
74 |
+
if audio.dim() == 3:
|
75 |
+
audio = audio.squeeze(0)
|
76 |
+
elif audio.dim() == 1:
|
77 |
+
audio = audio.unsqueeze(0)
|
78 |
+
torchaudio.save(str(out_wav), audio, sr)
|
79 |
+
audios[cv] = str(out_wav)
|
80 |
+
# final panoramic composition
|
81 |
+
comp = Path(output_dir) / "panoramic_composition.wav"
|
82 |
+
compose_audio(list(audios.values()), [1.0]*len(audios), str(comp))
|
83 |
+
audios['panorama'] = str(comp)
|
84 |
+
clear_cuda()
|
85 |
+
return
|
86 |
+
|
87 |
+
# Single-view: generate one audio
|
88 |
+
vr = generate_caption(lat, lon, view=view, model=model,
|
89 |
+
cpu_only=cpu_only, panoramic=False)
|
90 |
+
img_file = Path(image_dir) / f"{view}.jpg"
|
91 |
+
processed_maps = sound_mapper.process_depth_maps()
|
92 |
+
image_paths = sorted(Path(image_dir).glob("*.jpg"))
|
93 |
+
idx = [i for i, p in enumerate(image_paths) if p.name == img_file.name]
|
94 |
+
depth_map = processed_maps[idx[0]]["normalization"]
|
95 |
+
obj_depths = sound_mapper.analyze_object_depths(
|
96 |
+
str(img_file), depth_map, lat, lon,
|
97 |
+
caption_data=vr, all_objects=True
|
98 |
+
)
|
99 |
+
out_wav = Path(output_dir) / f"sound_{view}.wav"
|
100 |
+
audio, sr = audio_generator.process_and_generate_audio(obj_depths, duration=audio_duration)
|
101 |
+
if audio.dim() == 3:
|
102 |
+
audio = audio.squeeze(0)
|
103 |
+
elif audio.dim() == 1:
|
104 |
+
audio = audio.unsqueeze(0)
|
105 |
+
torchaudio.save(str(out_wav), audio, sr)
|
106 |
+
clear_cuda()
|
107 |
+
|
108 |
+
# Gradio UI
|
109 |
+
demo = gr.Blocks(title="Panoramic Audio Generator")
|
110 |
+
with demo:
|
111 |
+
gr.Markdown("""
|
112 |
+
# Panoramic Audio Generator
|
113 |
+
|
114 |
+
Displays each view with its audio side by side.
|
115 |
+
"""
|
116 |
+
)
|
117 |
+
|
118 |
+
with gr.Row():
|
119 |
+
panoramic = gr.Checkbox(label="Panoramic (multi-view)", value=False)
|
120 |
+
view = gr.Dropdown(["front", "back", "left", "right"], value="front", label="View")
|
121 |
+
location = gr.Textbox(value="52.3436723,4.8529625", label="Location (lat,lon)")
|
122 |
+
# model = gr.Textbox(value="intern_2_5-4B", label="Vision-Language Model")
|
123 |
+
model = "intern_2_5-4B"
|
124 |
+
audio_duration = gr.Slider(1, 60, value=10, step=1, label="Audio Duration (sec)")
|
125 |
+
cpu_only = gr.Checkbox(label="CPU Only", value=False)
|
126 |
+
btn = gr.Button("Generate")
|
127 |
+
|
128 |
+
# Output layout: two rows of two
|
129 |
+
with gr.Row():
|
130 |
+
with gr.Column():
|
131 |
+
img_front = gr.Image(label="Front View", type="filepath")
|
132 |
+
aud_front = gr.Audio(label="Front Audio", type="filepath")
|
133 |
+
with gr.Column():
|
134 |
+
img_back = gr.Image(label="Back View", type="filepath")
|
135 |
+
aud_back = gr.Audio(label="Back Audio", type="filepath")
|
136 |
+
with gr.Row():
|
137 |
+
with gr.Column():
|
138 |
+
img_left = gr.Image(label="Left View", type="filepath")
|
139 |
+
aud_left = gr.Audio(label="Left Audio", type="filepath")
|
140 |
+
with gr.Column():
|
141 |
+
img_right = gr.Image(label="Right View", type="filepath")
|
142 |
+
aud_right = gr.Audio(label="Right Audio", type="filepath")
|
143 |
+
# Panorama at bottom
|
144 |
+
img_pan = gr.Image(label="Panorama View", type="filepath")
|
145 |
+
aud_pan = gr.Audio(label="Panoramic Audio", type="filepath")
|
146 |
+
|
147 |
+
# Preview update
|
148 |
+
def run_all(pan, vw, loc, mdl, dur, cpu):
|
149 |
+
# generate files
|
150 |
+
process_images(LOGS_DIR, OUTPUT_DIR, pan, vw, mdl, loc, dur, cpu)
|
151 |
+
# collect files
|
152 |
+
views = ["front", "back", "left", "right", "panorama"]
|
153 |
+
paths = {}
|
154 |
+
for v in views:
|
155 |
+
img = Path(LOGS_DIR) / f"{v}.jpg"
|
156 |
+
audio = Path(OUTPUT_DIR) / ("panoramic_composition.wav" if v == "panorama" else f"sound_{v}.wav")
|
157 |
+
paths[v] = {
|
158 |
+
'img': str(img) if img.exists() else None,
|
159 |
+
'aud': str(audio) if audio.exists() else None
|
160 |
+
}
|
161 |
+
return (
|
162 |
+
paths['front']['img'], paths['front']['aud'],
|
163 |
+
paths['back']['img'], paths['back']['aud'],
|
164 |
+
paths['left']['img'], paths['left']['aud'],
|
165 |
+
paths['right']['img'], paths['right']['aud'],
|
166 |
+
paths['panorama']['img'], paths['panorama']['aud']
|
167 |
+
)
|
168 |
+
|
169 |
+
btn.click(
|
170 |
+
fn=run_all,
|
171 |
+
inputs=[panoramic, view, location, model, audio_duration, cpu_only],
|
172 |
+
outputs=[
|
173 |
+
img_front, aud_front,
|
174 |
+
img_back, aud_back,
|
175 |
+
img_left, aud_left,
|
176 |
+
img_right, aud_right,
|
177 |
+
img_pan, aud_pan
|
178 |
+
]
|
179 |
+
)
|
180 |
+
|
181 |
+
if __name__ == "__main__":
|
182 |
+
demo.launch(share=True)
|
audio_mixer.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import torchaudio.transforms as T
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import os
|
7 |
+
from typing import List, Tuple
|
8 |
+
from config import LOGS_DIR
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
##Some utils:
|
13 |
+
def load_audio_files(file_paths: List[str]) -> List[Tuple[torch.Tensor, int]]:
|
14 |
+
"""
|
15 |
+
Load multiple audio files and ensure they have the same length.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
file_paths: List of paths to audio files
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
List of tuples containing audio data and sample rate
|
22 |
+
"""
|
23 |
+
audio_data = []
|
24 |
+
|
25 |
+
for path in file_paths:
|
26 |
+
# Load audio file
|
27 |
+
waveform, sample_rate = torchaudio.load(path)
|
28 |
+
# Convert to mono if stereo
|
29 |
+
if waveform.shape[0] > 1:
|
30 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
31 |
+
audio_data.append((waveform.squeeze(), sample_rate))
|
32 |
+
|
33 |
+
# Verify all audio files have the same length and sample rate
|
34 |
+
lengths = [len(audio) for audio, _ in audio_data]
|
35 |
+
sample_rates = [sr for _, sr in audio_data]
|
36 |
+
|
37 |
+
if len(set(lengths)) > 1:
|
38 |
+
raise ValueError(f"Audio files have different lengths: {lengths}")
|
39 |
+
if len(set(sample_rates)) > 1:
|
40 |
+
raise ValueError(f"Audio files have different sample rates: {sample_rates}")
|
41 |
+
|
42 |
+
return audio_data
|
43 |
+
|
44 |
+
|
45 |
+
def normalize_audio_volumes(audio_data: List[Tuple[torch.Tensor, int]]) -> List[Tuple[torch.Tensor, int]]:
|
46 |
+
"""
|
47 |
+
Normalize the volume of each audio file to have the same energy level.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
audio_data: List of tuples containing audio data and sample rate
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
List of tuples containing normalized audio data and sample rate
|
54 |
+
"""
|
55 |
+
normalized_data = []
|
56 |
+
|
57 |
+
# Calculate RMS (Root Mean Square) for each audio
|
58 |
+
rms_values = []
|
59 |
+
for audio, sr in audio_data:
|
60 |
+
# Calculate energy (squared amplitude)
|
61 |
+
energy = torch.mean(audio ** 2)
|
62 |
+
# Calculate RMS (square root of mean energy)
|
63 |
+
rms = torch.sqrt(energy)
|
64 |
+
rms_values.append(rms)
|
65 |
+
|
66 |
+
# Find the target RMS (we'll use the median to avoid outliers)
|
67 |
+
target_rms = torch.median(torch.tensor(rms_values))
|
68 |
+
|
69 |
+
# Normalize each audio to the target RMS
|
70 |
+
for (audio, sr), rms in zip(audio_data, rms_values):
|
71 |
+
if rms > 0: # Avoid division by zero
|
72 |
+
# Calculate scaling factor
|
73 |
+
scaling_factor = target_rms / rms
|
74 |
+
# Apply scaling
|
75 |
+
normalized_audio = audio * scaling_factor
|
76 |
+
else:
|
77 |
+
normalized_audio = audio
|
78 |
+
|
79 |
+
normalized_data.append((normalized_audio, sr))
|
80 |
+
|
81 |
+
return normalized_data
|
82 |
+
|
83 |
+
def plot_energy_comparison(original_metrics: List[dict], normalized_metrics: List[dict], file_names: List[str], output_path: str = "./logs/energy_comparison.png") -> None:
|
84 |
+
"""
|
85 |
+
Plot a comparison of energy metrics before and after normalization.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
original_metrics: List of dictionaries containing metrics for original audio
|
89 |
+
normalized_metrics: List of dictionaries containing metrics for normalized audio
|
90 |
+
file_names: List of audio file names
|
91 |
+
output_path: Path to save the plot
|
92 |
+
"""
|
93 |
+
fig, axs = plt.subplots(2, 2, figsize=(14, 10))
|
94 |
+
|
95 |
+
# Extract metrics
|
96 |
+
orig_rms = [m['rms'] for m in original_metrics]
|
97 |
+
norm_rms = [m['rms'] for m in normalized_metrics]
|
98 |
+
|
99 |
+
orig_peak = [m['peak'] for m in original_metrics]
|
100 |
+
norm_peak = [m['peak'] for m in normalized_metrics]
|
101 |
+
|
102 |
+
orig_dr = [m['dynamic_range_db'] for m in original_metrics]
|
103 |
+
norm_dr = [m['dynamic_range_db'] for m in normalized_metrics]
|
104 |
+
|
105 |
+
orig_cf = [m['crest_factor'] for m in original_metrics]
|
106 |
+
norm_cf = [m['crest_factor'] for m in normalized_metrics]
|
107 |
+
|
108 |
+
# Prepare x-axis
|
109 |
+
x = np.arange(len(file_names))
|
110 |
+
width = 0.35
|
111 |
+
|
112 |
+
# Plot RMS (volume)
|
113 |
+
axs[0, 0].bar(x - width/2, orig_rms, width, label='Original')
|
114 |
+
axs[0, 0].bar(x + width/2, norm_rms, width, label='Normalized')
|
115 |
+
axs[0, 0].set_title('RMS Energy (Volume)')
|
116 |
+
axs[0, 0].set_xticks(x)
|
117 |
+
axs[0, 0].set_xticklabels(file_names, rotation=45, ha='right')
|
118 |
+
axs[0, 0].set_ylabel('RMS Value')
|
119 |
+
axs[0, 0].legend()
|
120 |
+
|
121 |
+
# Plot Peak Amplitude
|
122 |
+
axs[0, 1].bar(x - width/2, orig_peak, width, label='Original')
|
123 |
+
axs[0, 1].bar(x + width/2, norm_peak, width, label='Normalized')
|
124 |
+
axs[0, 1].set_title('Peak Amplitude')
|
125 |
+
axs[0, 1].set_xticks(x)
|
126 |
+
axs[0, 1].set_xticklabels(file_names, rotation=45, ha='right')
|
127 |
+
axs[0, 1].set_ylabel('Peak Value')
|
128 |
+
axs[0, 1].legend()
|
129 |
+
|
130 |
+
# Plot Dynamic Range
|
131 |
+
axs[1, 0].bar(x - width/2, orig_dr, width, label='Original')
|
132 |
+
axs[1, 0].bar(x + width/2, norm_dr, width, label='Normalized')
|
133 |
+
axs[1, 0].set_title('Dynamic Range (dB)')
|
134 |
+
axs[1, 0].set_xticks(x)
|
135 |
+
axs[1, 0].set_xticklabels(file_names, rotation=45, ha='right')
|
136 |
+
axs[1, 0].set_ylabel('dB')
|
137 |
+
axs[1, 0].legend()
|
138 |
+
|
139 |
+
# Plot Crest Factor
|
140 |
+
axs[1, 1].bar(x - width/2, orig_cf, width, label='Original')
|
141 |
+
axs[1, 1].bar(x + width/2, norm_cf, width, label='Normalized')
|
142 |
+
axs[1, 1].set_title('Crest Factor (Peak-to-RMS Ratio)')
|
143 |
+
axs[1, 1].set_xticks(x)
|
144 |
+
axs[1, 1].set_xticklabels(file_names, rotation=45, ha='right')
|
145 |
+
axs[1, 1].set_ylabel('Ratio')
|
146 |
+
axs[1, 1].legend()
|
147 |
+
|
148 |
+
plt.tight_layout()
|
149 |
+
|
150 |
+
# Create directory if it doesn't exist
|
151 |
+
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
|
152 |
+
|
153 |
+
# Save the plot
|
154 |
+
plt.savefig(output_path)
|
155 |
+
plt.close()
|
156 |
+
|
157 |
+
def calculate_audio_metrics(audio_data: List[Tuple[torch.Tensor, int]]) -> List[dict]:
|
158 |
+
"""
|
159 |
+
Calculate various audio metrics for each audio file.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
audio_data: List of tuples containing audio data and sample rate
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
List of dictionaries containing metrics
|
166 |
+
"""
|
167 |
+
metrics = []
|
168 |
+
|
169 |
+
for audio, sr in audio_data:
|
170 |
+
# Calculate RMS (Root Mean Square)
|
171 |
+
energy = torch.mean(audio ** 2)
|
172 |
+
rms = torch.sqrt(energy)
|
173 |
+
|
174 |
+
# Calculate peak amplitude
|
175 |
+
peak = torch.max(torch.abs(audio))
|
176 |
+
|
177 |
+
# Calculate dynamic range
|
178 |
+
if torch.min(torch.abs(audio[audio != 0])) > 0:
|
179 |
+
min_non_zero = torch.min(torch.abs(audio[audio != 0]))
|
180 |
+
dynamic_range_db = 20 * torch.log10(peak / min_non_zero)
|
181 |
+
else:
|
182 |
+
dynamic_range_db = torch.tensor(float('inf'))
|
183 |
+
|
184 |
+
# Calculate crest factor (peak to RMS ratio)
|
185 |
+
crest_factor = peak / rms if rms > 0 else torch.tensor(float('inf'))
|
186 |
+
|
187 |
+
metrics.append({
|
188 |
+
'rms': rms.item(),
|
189 |
+
'peak': peak.item(),
|
190 |
+
'dynamic_range_db': dynamic_range_db.item() if not torch.isinf(dynamic_range_db) else float('inf'),
|
191 |
+
'crest_factor': crest_factor.item() if not torch.isinf(crest_factor) else float('inf')
|
192 |
+
})
|
193 |
+
|
194 |
+
return metrics
|
195 |
+
|
196 |
+
|
197 |
+
def create_weighted_composite(
|
198 |
+
audio_data: List[Tuple[torch.Tensor, int]],
|
199 |
+
weights: List[float]
|
200 |
+
) -> torch.Tensor:
|
201 |
+
"""
|
202 |
+
Create a weighted composite of multiple audio files.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
audio_data: List of tuples containing audio data and sample rate
|
206 |
+
weights: List of weights for each audio file
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
Weighted composite audio data
|
210 |
+
"""
|
211 |
+
if len(audio_data) != len(weights):
|
212 |
+
raise ValueError("Number of audio files and weights must match")
|
213 |
+
|
214 |
+
# Normalize weights to sum to 1
|
215 |
+
weights = torch.tensor(weights) / sum(weights)
|
216 |
+
|
217 |
+
# Initialize composite audio with zeros
|
218 |
+
composite = torch.zeros_like(audio_data[0][0])
|
219 |
+
|
220 |
+
# Add weighted audio data
|
221 |
+
for (audio, _), weight in zip(audio_data, weights):
|
222 |
+
composite += audio * weight
|
223 |
+
|
224 |
+
# Normalize to prevent clipping
|
225 |
+
max_val = torch.max(torch.abs(composite))
|
226 |
+
if max_val > 1.0:
|
227 |
+
composite = composite / max_val
|
228 |
+
|
229 |
+
return composite
|
230 |
+
|
231 |
+
|
232 |
+
def create_melspectrograms(
|
233 |
+
audio_data: List[Tuple[torch.Tensor, int]],
|
234 |
+
composite: torch.Tensor,
|
235 |
+
sr: int
|
236 |
+
) -> List[torch.Tensor]:
|
237 |
+
"""
|
238 |
+
Create melspectrograms for individual audio files and the composite.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
audio_data: List of tuples containing audio data and sample rate
|
242 |
+
composite: Composite audio data
|
243 |
+
sr: Sample rate
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
List of melspectrogram data
|
247 |
+
"""
|
248 |
+
specs = []
|
249 |
+
|
250 |
+
# Create mel spectrogram transform
|
251 |
+
mel_transform = T.MelSpectrogram(
|
252 |
+
sample_rate=sr,
|
253 |
+
n_fft=2048,
|
254 |
+
win_length=2048,
|
255 |
+
hop_length=512,
|
256 |
+
n_mels=128,
|
257 |
+
f_max=8000
|
258 |
+
)
|
259 |
+
|
260 |
+
# Generate spectrograms for individual audio files
|
261 |
+
for audio, _ in audio_data:
|
262 |
+
melspec = mel_transform(audio)
|
263 |
+
specs.append(melspec)
|
264 |
+
|
265 |
+
# Generate spectrogram for composite audio
|
266 |
+
composite_melspec = mel_transform(composite)
|
267 |
+
specs.append(composite_melspec)
|
268 |
+
|
269 |
+
return specs
|
270 |
+
|
271 |
+
|
272 |
+
def plot_melspectrograms(
|
273 |
+
specs: List[torch.Tensor],
|
274 |
+
sr: int,
|
275 |
+
file_names: List[str],
|
276 |
+
weights: List[float],
|
277 |
+
output_path: str = "melspectrograms.png"
|
278 |
+
) -> None:
|
279 |
+
"""
|
280 |
+
Plot melspectrograms for individual audio files and the composite.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
specs: List of melspectrogram data
|
284 |
+
sr: Sample rate
|
285 |
+
file_names: List of audio file names
|
286 |
+
weights: List of weights for each audio file
|
287 |
+
output_path: Path to save the plot
|
288 |
+
"""
|
289 |
+
fig, axs = plt.subplots(len(specs), 1, figsize=(12, 4 * len(specs)))
|
290 |
+
|
291 |
+
# Create labels for the plots
|
292 |
+
labels = [f"{name} (weight: {weight:.2f})" for name, weight in zip(file_names, weights)]
|
293 |
+
labels.append("Composite.wav")
|
294 |
+
|
295 |
+
# Convert to dB scale (similar to librosa's power_to_db)
|
296 |
+
def power_to_db(spec):
|
297 |
+
return 10 * torch.log10(spec + 1e-10)
|
298 |
+
|
299 |
+
# Plot each melspectrogram
|
300 |
+
for i, (spec, label) in enumerate(zip(specs, labels)):
|
301 |
+
spec_db = power_to_db(spec).numpy().squeeze()
|
302 |
+
|
303 |
+
# For single subplot case
|
304 |
+
if len(specs) == 1:
|
305 |
+
ax = axs
|
306 |
+
else:
|
307 |
+
ax = axs[i]
|
308 |
+
|
309 |
+
img = ax.imshow(
|
310 |
+
spec_db,
|
311 |
+
aspect='auto',
|
312 |
+
origin='lower',
|
313 |
+
interpolation='none',
|
314 |
+
extent=[0, spec_db.shape[1], 0, sr/2]
|
315 |
+
)
|
316 |
+
ax.set_title(label)
|
317 |
+
ax.set_ylabel('Frequency (Hz)')
|
318 |
+
ax.set_xlabel('Time Frames')
|
319 |
+
|
320 |
+
# No colorbar as requested
|
321 |
+
|
322 |
+
plt.tight_layout()
|
323 |
+
|
324 |
+
# Create directory if it doesn't exist
|
325 |
+
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
|
326 |
+
# Save the plot
|
327 |
+
plt.savefig(output_path,dpi=300)
|
328 |
+
plt.close()
|
329 |
+
|
330 |
+
|
331 |
+
def compose_audio(
|
332 |
+
file_paths: List[str],
|
333 |
+
weights: List[float],
|
334 |
+
output_audio_path: str = os.path.join(LOGS_DIR, "composite.wav"),
|
335 |
+
output_plot_path: str = os.path.join(LOGS_DIR, "plot/melspectrograms.png"),
|
336 |
+
energy_plot_path: str = os.path.join(LOGS_DIR, "plot/energy_comparison.png")
|
337 |
+
) -> None:
|
338 |
+
"""
|
339 |
+
Main function to process audio files and create visualizations.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
file_paths: List of paths to audio files (supports 4 audio files)
|
343 |
+
weights: List of weights for each audio file
|
344 |
+
output_audio_path: Path to save the composite audio
|
345 |
+
output_plot_path: Path to save the melspectrogram plot
|
346 |
+
energy_plot_path: Path to save the energy comparison plot
|
347 |
+
"""
|
348 |
+
# Load audio files
|
349 |
+
audio_data = load_audio_files(file_paths)
|
350 |
+
|
351 |
+
# # Calculate metrics for original audio
|
352 |
+
print("Calculating metrics for original audio...")
|
353 |
+
original_metrics = calculate_audio_metrics(audio_data)
|
354 |
+
|
355 |
+
# Normalize audio volumes to have same energy level
|
356 |
+
print("Normalizing audio volumes...")
|
357 |
+
normalized_audio_data = normalize_audio_volumes(audio_data)
|
358 |
+
|
359 |
+
# Calculate metrics for normalized audio
|
360 |
+
print("Calculating metrics for normalized audio...")
|
361 |
+
normalized_metrics = calculate_audio_metrics(normalized_audio_data)
|
362 |
+
|
363 |
+
# Print energy comparison
|
364 |
+
print("\nAudio Energy Comparison (RMS values):")
|
365 |
+
print("-" * 50)
|
366 |
+
print(f"{'File':<20} {'Original':<15} {'Normalized':<15} {'Scaling Factor':<15}")
|
367 |
+
print("-" * 50)
|
368 |
+
for i, path in enumerate(file_paths):
|
369 |
+
file_name = path.split("/")[-1]
|
370 |
+
orig_rms = original_metrics[i]['rms']
|
371 |
+
norm_rms = normalized_metrics[i]['rms']
|
372 |
+
scaling = norm_rms / orig_rms if orig_rms > 0 else float('inf')
|
373 |
+
print(f"{file_name[:20]:<20} {orig_rms:<15.6f} {norm_rms:<15.6f} {scaling:<15.6f}")
|
374 |
+
|
375 |
+
# Create energy comparison plot
|
376 |
+
print("\nCreating energy comparison plot...")
|
377 |
+
file_names = [path.split("/")[-1] for path in file_paths]
|
378 |
+
plot_energy_comparison(original_metrics, normalized_metrics, file_names, energy_plot_path)
|
379 |
+
|
380 |
+
# Get sample rate (all files have the same sample rate)
|
381 |
+
sr = normalized_audio_data[0][1]
|
382 |
+
|
383 |
+
# Create weighted composite
|
384 |
+
print("\nCreating weighted composite...")
|
385 |
+
composite = create_weighted_composite(normalized_audio_data, weights)
|
386 |
+
|
387 |
+
# Create directory if it doesn't exist
|
388 |
+
os.makedirs(os.path.dirname(output_audio_path) or '.', exist_ok=True)
|
389 |
+
|
390 |
+
# Save composite audio
|
391 |
+
print("Saving composite audio...")
|
392 |
+
torchaudio.save(output_audio_path, composite.unsqueeze(0), sr)
|
393 |
+
|
394 |
+
# Create melspectrograms for normalized audio (not original)
|
395 |
+
print("Creating melspectrograms for normalized audio...")
|
396 |
+
specs = create_melspectrograms(normalized_audio_data, composite, sr)
|
397 |
+
|
398 |
+
# Get file names without path
|
399 |
+
labeled_file_names = [path.split("/")[-1] for path in file_paths]
|
400 |
+
|
401 |
+
# Plot melspectrograms
|
402 |
+
print("Plotting melspectrograms...")
|
403 |
+
plot_melspectrograms(specs, sr, labeled_file_names, weights, output_plot_path)
|
404 |
+
|
405 |
+
print(f"\nComposite audio saved to {output_audio_path}")
|
406 |
+
print(f"Melspectrograms saved to {output_plot_path}")
|
407 |
+
print(f"Energy comparison saved to {energy_plot_path}")
|
408 |
+
|
409 |
+
print(f"Composite audio saved to {output_audio_path}")
|
410 |
+
print(f"Melspectrograms saved to {output_plot_path}")
|
411 |
+
|
412 |
+
|
413 |
+
# if __name__ == "__main__":
|
414 |
+
# import argparse
|
415 |
+
|
416 |
+
# parser = argparse.ArgumentParser(description="Mix audio files with weights and create melspectrograms")
|
417 |
+
# parser.add_argument("--files", nargs="+", required=True, help="Paths to audio files")
|
418 |
+
# parser.add_argument("--weights", nargs="+", type=float, required=True, help="Weights for each audio file")
|
419 |
+
# parser.add_argument("--output-audio", default="./logs/composite.wav", help="Path to save the composite audio")
|
420 |
+
# parser.add_argument("--output-plot", default="./logs/melspectrograms.png", help="Path to save the melspectrogram plot")
|
421 |
+
|
422 |
+
# args = parser.parse_args()
|
423 |
+
# os.makedirs("./logs", exist_ok=True)
|
424 |
+
# main(args.files, args.weights, args.output_audio, args.output_plot)
|
425 |
+
|
426 |
+
|
427 |
+
# Example usage:
|
428 |
+
# python audio_mixer.py --files audio1.wav audio2.wav audio3.wav audio4.wav --weights 0.4 0.3 0.2 0.1
|
config.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Base directories
|
4 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
5 |
+
LOGS_DIR = os.path.join(BASE_DIR, "logs")
|
6 |
+
OUTPUT_DIR = os.path.join(BASE_DIR, "output")
|
7 |
+
|
8 |
+
# Model paths
|
9 |
+
EXTERNAL_MODELS_DIR = os.path.join(BASE_DIR, "external_models")
|
10 |
+
DEPTH_FM_DIR = os.path.join(EXTERNAL_MODELS_DIR, "depth-fm")
|
11 |
+
DEPTH_FM_CHECKPOINT = os.path.join(DEPTH_FM_DIR, "checkpoints/depthfm-v1.ckpt") # You will need to download the checkpoint manually. Here is the link: https://github.com/CompVis/depth-fm/tree/main/checkpoints
|
12 |
+
TANGO_FLUX_DIR = os.path.join(EXTERNAL_MODELS_DIR, "TangoFlux")
|
13 |
+
|
14 |
+
# Create required directories
|
15 |
+
os.makedirs(LOGS_DIR, exist_ok=True)
|
16 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
environment.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: geosynthsound
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.10
|
7 |
+
- pip:
|
8 |
+
- -r requirements.txt
|
external_models/TangoFlux/.gitignore
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
*.py[cod]
|
3 |
+
*$py.class
|
4 |
+
|
5 |
+
# C extensions
|
6 |
+
*.so
|
7 |
+
|
8 |
+
# Distribution / packaging
|
9 |
+
.Python
|
10 |
+
build/
|
11 |
+
develop-eggs/
|
12 |
+
dist/
|
13 |
+
downloads/
|
14 |
+
eggs/
|
15 |
+
.eggs/
|
16 |
+
lib/
|
17 |
+
lib64/
|
18 |
+
parts/
|
19 |
+
sdist/
|
20 |
+
var/
|
21 |
+
wheels/
|
22 |
+
share/python-wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.nox/
|
42 |
+
.coverage
|
43 |
+
.coverage.*
|
44 |
+
.cache
|
45 |
+
nosetests.xml
|
46 |
+
coverage.xml
|
47 |
+
*.cover
|
48 |
+
*.py,cover
|
49 |
+
.hypothesis/
|
50 |
+
.pytest_cache/
|
51 |
+
cover/
|
52 |
+
|
53 |
+
# Translations
|
54 |
+
*.mo
|
55 |
+
*.pot
|
56 |
+
|
57 |
+
# Django stuff:
|
58 |
+
*.log
|
59 |
+
local_settings.py
|
60 |
+
db.sqlite3
|
61 |
+
db.sqlite3-journal
|
62 |
+
|
63 |
+
# Flask stuff:
|
64 |
+
instance/
|
65 |
+
.webassets-cache
|
66 |
+
|
67 |
+
# Scrapy stuff:
|
68 |
+
.scrapy
|
69 |
+
|
70 |
+
# Sphinx documentation
|
71 |
+
docs/_build/
|
72 |
+
|
73 |
+
# PyBuilder
|
74 |
+
.pybuilder/
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
# For a library or package, you might want to ignore these files since the code is
|
86 |
+
# intended to run in multiple environments; otherwise, check them in:
|
87 |
+
# .python-version
|
88 |
+
|
89 |
+
# pipenv
|
90 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
91 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
92 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
93 |
+
# install all needed dependencies.
|
94 |
+
#Pipfile.lock
|
95 |
+
|
96 |
+
# UV
|
97 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
98 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
99 |
+
# commonly ignored for libraries.
|
100 |
+
#uv.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
115 |
+
.pdm.toml
|
116 |
+
.pdm-python
|
117 |
+
.pdm-build/
|
118 |
+
|
119 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
120 |
+
__pypackages__/
|
121 |
+
|
122 |
+
# Celery stuff
|
123 |
+
celerybeat-schedule
|
124 |
+
celerybeat.pid
|
125 |
+
|
126 |
+
# SageMath parsed files
|
127 |
+
*.sage.py
|
128 |
+
|
129 |
+
# Environments
|
130 |
+
.env
|
131 |
+
.venv
|
132 |
+
env/
|
133 |
+
venv/
|
134 |
+
ENV/
|
135 |
+
env.bak/
|
136 |
+
venv.bak/
|
137 |
+
|
138 |
+
# Spyder project settings
|
139 |
+
.spyderproject
|
140 |
+
.spyproject
|
141 |
+
|
142 |
+
# Rope project settings
|
143 |
+
.ropeproject
|
144 |
+
|
145 |
+
# mkdocs documentation
|
146 |
+
/site
|
147 |
+
|
148 |
+
# mypy
|
149 |
+
.mypy_cache/
|
150 |
+
.dmypy.json
|
151 |
+
dmypy.json
|
152 |
+
|
153 |
+
# Pyre type checker
|
154 |
+
.pyre/
|
155 |
+
|
156 |
+
# pytype static type analyzer
|
157 |
+
.pytype/
|
158 |
+
|
159 |
+
# Cython debug symbols
|
160 |
+
cython_debug/
|
161 |
+
|
162 |
+
# PyCharm
|
163 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
164 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
165 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
166 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
167 |
+
#.idea/
|
168 |
+
|
169 |
+
# PyPI configuration file
|
170 |
+
.pypirc
|
171 |
+
|
172 |
+
|
173 |
+
.DS_Store
|
174 |
+
|
175 |
+
*.wav
|
external_models/TangoFlux/Demo.ipynb
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "view-in-github",
|
7 |
+
"colab_type": "text"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/OIEIEIO/TangoFlux/blob/main/Demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": null,
|
16 |
+
"metadata": {
|
17 |
+
"collapsed": true,
|
18 |
+
"id": "xiaRzuzPOP4H"
|
19 |
+
},
|
20 |
+
"outputs": [],
|
21 |
+
"source": [
|
22 |
+
"!pip install git+https://github.com/declare-lab/TangoFlux.git"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"metadata": {
|
29 |
+
"collapsed": true,
|
30 |
+
"id": "Hfu3zXTDOP4J"
|
31 |
+
},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"import IPython\n",
|
35 |
+
"import torchaudio\n",
|
36 |
+
"from tangoflux import TangoFluxInference\n",
|
37 |
+
"from IPython.display import Audio\n",
|
38 |
+
"\n",
|
39 |
+
"model = TangoFluxInference(name='declare-lab/TangoFlux')"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": null,
|
45 |
+
"metadata": {
|
46 |
+
"id": "oFiak5QIOP4K"
|
47 |
+
},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"# @title Generate Audio\n",
|
51 |
+
"\n",
|
52 |
+
"prompt = 'a futuristic space craft with unique engine sound' # @param {type:\"string\"}\n",
|
53 |
+
"duration = 10 # @param {type:\"number\"}\n",
|
54 |
+
"steps = 50 # @param {type:\"number\"}\n",
|
55 |
+
"\n",
|
56 |
+
"audio = model.generate(prompt, steps=steps, duration=duration)\n",
|
57 |
+
"\n",
|
58 |
+
"Audio(data=audio, rate=44100)"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"source": [
|
64 |
+
"import IPython\n",
|
65 |
+
"import torchaudio\n",
|
66 |
+
"from tangoflux import TangoFluxInference\n",
|
67 |
+
"from IPython.display import Audio\n",
|
68 |
+
"\n",
|
69 |
+
"model = TangoFluxInference(name='declare-lab/TangoFlux')\n",
|
70 |
+
"\n",
|
71 |
+
"# @title Generate Audio\n",
|
72 |
+
"prompt = 'Melodic human whistling harmonizing with natural birdsong' # @param {type:\"string\"}\n",
|
73 |
+
"duration = 10 # @param {type:\"number\"}\n",
|
74 |
+
"steps = 50 # @param {type:\"number\"}\n",
|
75 |
+
"\n",
|
76 |
+
"# Generate the audio\n",
|
77 |
+
"audio = model.generate(prompt, steps=steps, duration=duration)\n",
|
78 |
+
"\n",
|
79 |
+
"# Ensure audio is in the correct format (2D Tensor: [channels, samples])\n",
|
80 |
+
"if len(audio.shape) == 1: # If mono audio (1D tensor)\n",
|
81 |
+
" audio_tensor = audio.unsqueeze(0) # Add channel dimension to make it [1, samples]\n",
|
82 |
+
"elif len(audio.shape) == 2: # Stereo audio (2D tensor)\n",
|
83 |
+
" audio_tensor = audio # Already in correct format\n",
|
84 |
+
"else:\n",
|
85 |
+
" raise ValueError(f\"Unexpected audio tensor shape: {audio.shape}\")\n",
|
86 |
+
"\n",
|
87 |
+
"# Save the audio as a .wav file\n",
|
88 |
+
"torchaudio.save('generated_audio.wav', audio_tensor, sample_rate=44100)\n",
|
89 |
+
"\n",
|
90 |
+
"# Optionally play the audio in the notebook\n",
|
91 |
+
"Audio(data=audio.numpy(), rate=44100)\n"
|
92 |
+
],
|
93 |
+
"metadata": {
|
94 |
+
"id": "_Z8elHyOHOQ1"
|
95 |
+
},
|
96 |
+
"execution_count": null,
|
97 |
+
"outputs": []
|
98 |
+
}
|
99 |
+
],
|
100 |
+
"metadata": {
|
101 |
+
"language_info": {
|
102 |
+
"name": "python"
|
103 |
+
},
|
104 |
+
"colab": {
|
105 |
+
"provenance": [],
|
106 |
+
"machine_shape": "hm",
|
107 |
+
"private_outputs": true,
|
108 |
+
"include_colab_link": true
|
109 |
+
},
|
110 |
+
"kernelspec": {
|
111 |
+
"name": "python3",
|
112 |
+
"display_name": "Python 3"
|
113 |
+
}
|
114 |
+
},
|
115 |
+
"nbformat": 4,
|
116 |
+
"nbformat_minor": 0
|
117 |
+
}
|
external_models/TangoFlux/Inference.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
external_models/TangoFlux/LICENSE.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LICENSE
|
2 |
+
|
3 |
+
## 1. Model & License Summary
|
4 |
+
|
5 |
+
This repository contains **TangoFlux** (the “Model”) created for **non-commercial, research-only** purposes under the **UK data copyright exemption**. The Model is subject to:
|
6 |
+
|
7 |
+
1. The **Stability AI Community License Agreement**, provided in the file ```STABILITY_AI_COMMUNITY_LICENSE.md```.
|
8 |
+
2. The **WavCaps** license requirement: **only academic uses** are permitted for data sourced from WavCaps.
|
9 |
+
3. The **original licenses** of the datasets used in training.
|
10 |
+
|
11 |
+
By using or distributing this Model, you **agree** to adhere to all applicable licenses and restrictions, as summarized below.
|
12 |
+
|
13 |
+
---
|
14 |
+
|
15 |
+
## 2. Stability AI Community License Requirements
|
16 |
+
|
17 |
+
- You must comply with the **Stability AI Community License Agreement** (the “Agreement”) for any usage, distribution, or modification of this Model.
|
18 |
+
- **Non-Commercial Use**: This Model is for research and academic purposes only. Any commercial usage requires registering with Stability AI or obtaining a separate commercial license.
|
19 |
+
- **Attribution & Notice**:
|
20 |
+
- Retain the notice:
|
21 |
+
```
|
22 |
+
This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved.
|
23 |
+
```
|
24 |
+
- Clearly display “Powered by Stability AI” if you build upon or showcase this Model.
|
25 |
+
- **Disclaimer & Liability**: This Model is provided **“AS IS”** with **no warranties**. Neither we nor Stability AI will be liable for any claim or damages related to Model use.
|
26 |
+
|
27 |
+
See ```STABILITY_AI_COMMUNITY_LICENSE.md``` for the full text.
|
28 |
+
|
29 |
+
---
|
30 |
+
|
31 |
+
## 3. WavCaps & Dataset Usage
|
32 |
+
|
33 |
+
- **Academic-Only for WavCaps**: By accessing any WavCaps-sourced data (including audio clips via provided links), you agree to use them **strictly for non-commercial, academic research** in accordance with WavCaps’ terms.
|
34 |
+
- **WavCaps Audio**: Each WavCaps audio subset has its own license terms. **You** are responsible for reviewing and complying with those licenses, including attribution requirements on your end.
|
35 |
+
|
36 |
+
---
|
37 |
+
|
38 |
+
## 4. UK Data Copyright Exemption
|
39 |
+
|
40 |
+
This Model was developed under the **UK data copyright exemption for non-commercial research**. Distribution or use outside these bounds must **not** violate that exemption or infringe on any underlying dataset’s license.
|
41 |
+
|
42 |
+
---
|
43 |
+
|
44 |
+
## 5. Further Information
|
45 |
+
|
46 |
+
- **Stability AI License Terms**: <https://stability.ai/community-license>
|
47 |
+
- **WavCaps License**: <https://github.com/XinhaoMei/WavCaps?tab=readme-ov-file#license>
|
48 |
+
|
49 |
+
---
|
50 |
+
|
51 |
+
**End of License**.
|
external_models/TangoFlux/Notice
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved
|
external_models/TangoFlux/README.md
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src="assets/tf_opener.png" alt="TangoFluxOpener" width="1000" />
|
3 |
+
|
4 |
+
<br/>
|
5 |
+
|
6 |
+
[](https://arxiv.org/abs/2412.21037) [](https://huggingface.co/declare-lab/TangoFlux) [](https://tangoflux.github.io/) [](https://huggingface.co/spaces/declare-lab/TangoFlux) [](https://huggingface.co/datasets/declare-lab/CRPO) [](https://replicate.com/declare-lab/tangoflux)
|
7 |
+
|
8 |
+
<img src="assets/tf_teaser.png" alt="TangoFlux" width="1000" />
|
9 |
+
<br/>
|
10 |
+
|
11 |
+
</div>
|
12 |
+
|
13 |
+
* Powered by **Stability AI**
|
14 |
+
## News
|
15 |
+
> 📣 1/3/25: We have released CRPO dataset as well as the script to perform CRPO dataset generation!
|
16 |
+
|
17 |
+
## Demos
|
18 |
+
|
19 |
+
[](https://huggingface.co/spaces/declare-lab/TangoFlux)
|
20 |
+
|
21 |
+
[](https://colab.research.google.com/github/declare-lab/TangoFlux/blob/main/Demo.ipynb)
|
22 |
+
|
23 |
+
## Overall Pipeline
|
24 |
+
|
25 |
+
TangoFlux consists of FluxTransformer blocks, which are Diffusion Transformers (DiT) and Multimodal Diffusion Transformers (MMDiT) conditioned on a textual prompt and a duration embedding to generate a 44.1kHz audio up to 30 seconds long. TangoFlux learns a rectified flow trajectory to an audio latent representation encoded by a variational autoencoder (VAE). TangoFlux training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization with CRPO. CRPO, particularly, iteratively generates new synthetic data and constructs preference pairs for preference optimization using DPO loss for flow matching.
|
26 |
+
|
27 |
+

|
28 |
+
|
29 |
+
🚀 **TangoFlux can generate 44.1kHz stereo audio up to 30 seconds in ~3 seconds on a single A40 GPU.**
|
30 |
+
|
31 |
+
## Installation
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install git+https://github.com/declare-lab/TangoFlux
|
35 |
+
```
|
36 |
+
|
37 |
+
## Inference
|
38 |
+
|
39 |
+
TangoFlux can generate audio up to 30 seconds long. You must pass a duration to the `model.generate` function when using the Python API. Please note that duration should be between 1 and 30.
|
40 |
+
|
41 |
+
### Web Interface
|
42 |
+
|
43 |
+
Run the following command to start the web interface:
|
44 |
+
|
45 |
+
```bash
|
46 |
+
tangoflux-demo
|
47 |
+
```
|
48 |
+
|
49 |
+
### CLI
|
50 |
+
|
51 |
+
Use the CLI to generate audio from text.
|
52 |
+
|
53 |
+
```bash
|
54 |
+
tangoflux "Hammer slowly hitting the wooden table" output.wav --duration 10 --steps 50
|
55 |
+
```
|
56 |
+
|
57 |
+
### Python API
|
58 |
+
|
59 |
+
```python
|
60 |
+
import torchaudio
|
61 |
+
from tangoflux import TangoFluxInference
|
62 |
+
|
63 |
+
model = TangoFluxInference(name='declare-lab/TangoFlux')
|
64 |
+
audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)
|
65 |
+
|
66 |
+
torchaudio.save('output.wav', audio, 44100)
|
67 |
+
```
|
68 |
+
|
69 |
+
### [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
|
70 |
+
|
71 |
+
> This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface.
|
72 |
+
|
73 |
+
Check [this](https://github.com/LucipherDev/ComfyUI-TangoFlux) repo for the TangoFlux custom node for *ComfyUI*. (Thanks to [LucipherDev](https://github.com/LucipherDev))
|
74 |
+
|
75 |
+
Our evaluation shows that inference with 50 steps yields the best results. A CFG scale of 3.5, 4, and 4.5 yield similar quality output. Inference with 25 steps yields similar audio quality at a faster speed.
|
76 |
+
|
77 |
+
## Training
|
78 |
+
|
79 |
+
We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` to setup your run configuration. The default accelerate config is in the `configs` folder. Please specify the path to your training files in the `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` have been provided. Replace them with your own audio.
|
80 |
+
|
81 |
+
`tangoflux_config.yaml` defines the training file paths and model hyperparameters:
|
82 |
+
|
83 |
+
```bash
|
84 |
+
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
|
85 |
+
```
|
86 |
+
|
87 |
+
To perform DPO training, modify the training files such that each data point contains "chosen", "reject", "caption" and "duration" fields. Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`. Replace it with your own audio.
|
88 |
+
|
89 |
+
```bash
|
90 |
+
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train_dpo.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
|
91 |
+
```
|
92 |
+
|
93 |
+
## Evaluation
|
94 |
+
|
95 |
+
### TangoFlux vs. Other Audio Generation Models
|
96 |
+
|
97 |
+
This key comparison metrics include:
|
98 |
+
|
99 |
+
- **Output Length**: Represents the duration of the generated audio.
|
100 |
+
- **FD**<sub>openl3</sub>: Fréchet Distance.
|
101 |
+
- **KL**<sub>passt</sub>: KL divergence.
|
102 |
+
- **CLAP**<sub>score</sub>: Alignment score.
|
103 |
+
|
104 |
+
|
105 |
+
All the inference times are observed on the same A40 GPU. The counts of trainable parameters are reported in the **\#Params** column.
|
106 |
+
|
107 |
+
| Model | Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
|
108 |
+
|---|---|---|---|---|---|---|---|---|
|
109 |
+
| **AudioLDM 2 (Large)** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
|
110 |
+
| **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
|
111 |
+
| **Tango 2** | 866M | 10 sec | 200 | 108.4 | 1.11 | 0.447 | 9.0 | 22.8 |
|
112 |
+
| **TangoFlux (Base)** | 515M | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | 3.7 |
|
113 |
+
| **TangoFlux** | 515M | 30 sec | 50 | 75.1 | 1.15 | 0.480 | 12.2 | 3.7 |
|
114 |
+
|
115 |
+
## CRPO dataset generation
|
116 |
+
|
117 |
+
There are 2 py files for CRPO dataset generation.
|
118 |
+
tangoflux/generate_crpo.py generates the crpo dataset by providing path to prompt bank and model weights. You can specify the sample size as well as number of samples per prompt for crpo in the arguments.
|
119 |
+
tangoflux/label_crpo.py labels the generated audio and construct preference pairs. This will also create a train.json in the output dir that can be passed into train_dpo.py
|
120 |
+
|
121 |
+
You can follow the example in crpo.sh which will generate crpo dataset, then perform reward labelling to generate the train.json
|
122 |
+
|
123 |
+
To run CRPO for multiple iteration, you can simply repeat the above the process multiple time through setting the correct model weight.
|
124 |
+
## Citation
|
125 |
+
|
126 |
+
```bibtex
|
127 |
+
@misc{hung2024tangofluxsuperfastfaithful,
|
128 |
+
title={TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization},
|
129 |
+
author={Chia-Yu Hung and Navonil Majumder and Zhifeng Kong and Ambuj Mehrish and Amir Zadeh and Chuan Li and Rafael Valle and Bryan Catanzaro and Soujanya Poria},
|
130 |
+
year={2024},
|
131 |
+
eprint={2412.21037},
|
132 |
+
archivePrefix={arXiv},
|
133 |
+
primaryClass={cs.SD},
|
134 |
+
url={https://arxiv.org/abs/2412.21037},
|
135 |
+
}
|
136 |
+
```
|
137 |
+
|
138 |
+
## LICENSE
|
139 |
+
|
140 |
+
### 1. Model & License Summary
|
141 |
+
|
142 |
+
This repository contains **TangoFlux** (the “Model”) created for **non-commercial, research-only** purposes under the **UK data copyright exemption**. The Model is subject to:
|
143 |
+
|
144 |
+
1. The **Stability AI Community License Agreement**, provided in the file ```STABILITY_AI_COMMUNITY_LICENSE.md```.
|
145 |
+
2. The **WavCaps** license requirement: **only academic uses** are permitted for data sourced from WavCaps.
|
146 |
+
3. The **original licenses** of the datasets used in training.
|
147 |
+
|
148 |
+
By using or distributing this Model, you **agree** to adhere to all applicable licenses and restrictions, as summarized below.
|
149 |
+
|
150 |
+
---
|
151 |
+
|
152 |
+
### 2. Stability AI Community License Requirements
|
153 |
+
|
154 |
+
- You must comply with the **Stability AI Community License Agreement** (the “Agreement”) for any usage, distribution, or modification of this Model.
|
155 |
+
- **Non-Commercial Use**: This Model is for research and academic purposes only. Any commercial usage requires registering with Stability AI or obtaining a separate commercial license.
|
156 |
+
- **Attribution & Notice**:
|
157 |
+
- Retain the notice:
|
158 |
+
```
|
159 |
+
This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved.
|
160 |
+
```
|
161 |
+
- Clearly display “Powered by Stability AI” if you build upon or showcase this Model.
|
162 |
+
- **Disclaimer & Liability**: This Model is provided **“AS IS”** with **no warranties**. Neither we nor Stability AI will be liable for any claim or damages related to Model use.
|
163 |
+
|
164 |
+
See ```STABILITY_AI_COMMUNITY_LICENSE.md``` for the full text.
|
165 |
+
|
166 |
+
---
|
167 |
+
|
168 |
+
### 3. WavCaps & Dataset Usage
|
169 |
+
|
170 |
+
- **Academic-Only for WavCaps**: By accessing any WavCaps-sourced data (including audio clips via provided links), you agree to use them **strictly for non-commercial, academic research** in accordance with WavCaps’ terms.
|
171 |
+
- **WavCaps Audio**: Each WavCaps audio subset has its own license terms. **You** are responsible for reviewing and complying with those licenses, including attribution requirements on your end.
|
172 |
+
|
173 |
+
---
|
174 |
+
|
175 |
+
### 4. UK Data Copyright Exemption
|
176 |
+
|
177 |
+
This Model was developed under the **UK data copyright exemption for non-commercial research**. Distribution or use outside these bounds must **not** violate that exemption or infringe on any underlying dataset’s license.
|
178 |
+
|
179 |
+
---
|
180 |
+
|
181 |
+
### 5. Further Information
|
182 |
+
|
183 |
+
- **Stability AI License Terms**: <https://stability.ai/community-license>
|
184 |
+
- **WavCaps License**: <https://github.com/XinhaoMei/WavCaps?tab=readme-ov-file#license>
|
185 |
+
|
186 |
+
---
|
187 |
+
|
188 |
+
**End of License**.
|
external_models/TangoFlux/STABILITY_AI_COMMUNITY_LICENSE.md
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
STABILITY AI COMMUNITY LICENSE AGREEMENT
|
2 |
+
|
3 |
+
Last Updated: July 5, 2024
|
4 |
+
1. INTRODUCTION
|
5 |
+
|
6 |
+
This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
|
7 |
+
|
8 |
+
This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
|
9 |
+
|
10 |
+
By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf.
|
11 |
+
|
12 |
+
2. RESEARCH & NON-COMMERCIAL USE LICENSE
|
13 |
+
|
14 |
+
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
|
15 |
+
|
16 |
+
3. COMMERCIAL USE LICENSE
|
17 |
+
|
18 |
+
Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations.
|
19 |
+
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
|
20 |
+
|
21 |
+
4. GENERAL TERMS
|
22 |
+
|
23 |
+
Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
|
24 |
+
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified.
|
25 |
+
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
|
26 |
+
c. Intellectual Property.
|
27 |
+
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
|
28 |
+
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
|
29 |
+
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
|
30 |
+
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
|
31 |
+
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback.
|
32 |
+
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
|
33 |
+
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
34 |
+
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
|
35 |
+
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
|
36 |
+
|
37 |
+
5. DEFINITIONS
|
38 |
+
|
39 |
+
“Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
|
40 |
+
|
41 |
+
"Agreement" means this Stability AI Community License Agreement.
|
42 |
+
|
43 |
+
“AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
|
44 |
+
|
45 |
+
"Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model.
|
46 |
+
|
47 |
+
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
|
48 |
+
|
49 |
+
“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
|
50 |
+
|
51 |
+
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
|
52 |
+
|
53 |
+
"Software" means Stability AI’s proprietary software made available under this Agreement now or in the future.
|
54 |
+
|
55 |
+
“Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
|
56 |
+
|
57 |
+
“Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
external_models/TangoFlux/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from .comfyui import *
|
3 |
+
except:
|
4 |
+
pass
|
external_models/TangoFlux/assets/tangoflux.png
ADDED
![]() |
Git LFS Details
|
external_models/TangoFlux/assets/tf_opener.png
ADDED
![]() |
Git LFS Details
|
external_models/TangoFlux/assets/tf_teaser.png
ADDED
![]() |
Git LFS Details
|
external_models/TangoFlux/comfyui/README.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ComfyUI-TangoFlux
|
2 |
+
ComfyUI Custom Nodes for ["TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching"](https://arxiv.org/abs/2412.21037). These nodes, adapted from [the official implementations](https://github.com/declare-lab/TangoFlux/), generates high-quality 44.1kHz audio up to 30 seconds using just a text promptproduction.
|
3 |
+
|
4 |
+
## Installation
|
5 |
+
|
6 |
+
1. Navigate to your ComfyUI's custom_nodes directory:
|
7 |
+
```bash
|
8 |
+
cd ComfyUI/custom_nodes
|
9 |
+
```
|
10 |
+
|
11 |
+
2. Clone this repository:
|
12 |
+
```bash
|
13 |
+
git clone https://github.com/declare-lab/TangoFlux ComfyUI-TangoFlux
|
14 |
+
```
|
15 |
+
|
16 |
+
3. Install requirements:
|
17 |
+
```bash
|
18 |
+
cd ComfyUI-TangoFlux/comfyui
|
19 |
+
python install.py
|
20 |
+
```
|
21 |
+
|
22 |
+
### Or Install via ComfyUI Manager
|
23 |
+
|
24 |
+
#### Check out some demos from [the official demo page](https://tangoflux.github.io/)
|
25 |
+
|
26 |
+
## Example Workflow
|
27 |
+
|
28 |
+

|
29 |
+
|
30 |
+
## Usage
|
31 |
+
|
32 |
+
**All the necessary models should be automatically downloaded when the TangoFluxLoader node is used for the first time.**
|
33 |
+
|
34 |
+
**Models can also be downloaded using the `install.py` script**
|
35 |
+
|
36 |
+

|
37 |
+
|
38 |
+
**Manual Download:**
|
39 |
+
- Download TangoFlux from [here](https://huggingface.co/declare-lab/TangoFlux/tree/main) into `models/tangoflux`
|
40 |
+
- Download text encoders from [here](https://huggingface.co/google/flan-t5-large/tree/main) into `models/text_encoders/google-flan-t5-large`
|
41 |
+
|
42 |
+
*(Include Everything as shown in the screenshot above. Do Not Rename Anything)*
|
43 |
+
|
44 |
+
The nodes can be found in "TangoFlux" category as `TangoFluxLoader`, `TangoFluxSampler`, `TangoFluxVAEDecodeAndPlay`.
|
45 |
+
|
46 |
+

|
47 |
+
|
48 |
+
> [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup TangoFlux 2x without much audio quality degradation, in a training-free manner.
|
49 |
+
>
|
50 |
+
>
|
51 |
+
> ## 📈 Inference Latency Comparisons on a Single A800
|
52 |
+
>
|
53 |
+
>
|
54 |
+
> | TangoFlux | TeaCache (0.25) | TeaCache (0.4) |
|
55 |
+
> |:-------------------:|:----------------------------:|:--------------------:|
|
56 |
+
> | ~4.08 s | ~2.42 s | ~1.95 s |
|
57 |
+
|
58 |
+
## Citation
|
59 |
+
|
60 |
+
```bibtex
|
61 |
+
@misc{hung2024tangofluxsuperfastfaithful,
|
62 |
+
title={TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization},
|
63 |
+
author={Chia-Yu Hung and Navonil Majumder and Zhifeng Kong and Ambuj Mehrish and Rafael Valle and Bryan Catanzaro and Soujanya Poria},
|
64 |
+
year={2024},
|
65 |
+
eprint={2412.21037},
|
66 |
+
archivePrefix={arXiv},
|
67 |
+
primaryClass={cs.SD},
|
68 |
+
url={https://arxiv.org/abs/2412.21037},
|
69 |
+
}
|
70 |
+
```
|
71 |
+
```
|
72 |
+
@article{liu2024timestep,
|
73 |
+
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
|
74 |
+
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
|
75 |
+
journal={arXiv preprint arXiv:2411.19108},
|
76 |
+
year={2024}
|
77 |
+
}
|
78 |
+
```
|
external_models/TangoFlux/comfyui/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .nodes import NODE_CLASS_MAPPINGS
|
2 |
+
from .server import *
|
3 |
+
|
4 |
+
WEB_DIRECTORY = "./comfyui/web"
|
5 |
+
|
6 |
+
__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]
|
external_models/TangoFlux/comfyui/example_workflow.json
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"last_node_id": 13,
|
3 |
+
"last_link_id": 15,
|
4 |
+
"nodes": [
|
5 |
+
{
|
6 |
+
"id": 10,
|
7 |
+
"type": "TangoFluxLoader",
|
8 |
+
"pos": [
|
9 |
+
380,
|
10 |
+
320
|
11 |
+
],
|
12 |
+
"size": [
|
13 |
+
210,
|
14 |
+
102
|
15 |
+
],
|
16 |
+
"flags": {},
|
17 |
+
"order": 0,
|
18 |
+
"mode": 0,
|
19 |
+
"inputs": [],
|
20 |
+
"outputs": [
|
21 |
+
{
|
22 |
+
"name": "model",
|
23 |
+
"type": "TANGOFLUX_MODEL",
|
24 |
+
"links": [
|
25 |
+
11
|
26 |
+
],
|
27 |
+
"slot_index": 0
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"name": "vae",
|
31 |
+
"type": "TANGOFLUX_VAE",
|
32 |
+
"links": [
|
33 |
+
15
|
34 |
+
],
|
35 |
+
"slot_index": 1
|
36 |
+
}
|
37 |
+
],
|
38 |
+
"properties": {
|
39 |
+
"Node name for S&R": "TangoFluxLoader"
|
40 |
+
},
|
41 |
+
"widgets_values": [
|
42 |
+
false,
|
43 |
+
0.25
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"id": 13,
|
48 |
+
"type": "TangoFluxVAEDecodeAndPlay",
|
49 |
+
"pos": [
|
50 |
+
1060,
|
51 |
+
320
|
52 |
+
],
|
53 |
+
"size": [
|
54 |
+
315,
|
55 |
+
126
|
56 |
+
],
|
57 |
+
"flags": {},
|
58 |
+
"order": 2,
|
59 |
+
"mode": 0,
|
60 |
+
"inputs": [
|
61 |
+
{
|
62 |
+
"name": "vae",
|
63 |
+
"type": "TANGOFLUX_VAE",
|
64 |
+
"link": 15
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"name": "latents",
|
68 |
+
"type": "TANGOFLUX_LATENTS",
|
69 |
+
"link": 14
|
70 |
+
}
|
71 |
+
],
|
72 |
+
"outputs": [],
|
73 |
+
"properties": {
|
74 |
+
"Node name for S&R": "TangoFluxVAEDecodeAndPlay"
|
75 |
+
},
|
76 |
+
"widgets_values": [
|
77 |
+
"TangoFlux",
|
78 |
+
"wav",
|
79 |
+
true
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"id": 11,
|
84 |
+
"type": "TangoFluxSampler",
|
85 |
+
"pos": [
|
86 |
+
620,
|
87 |
+
320
|
88 |
+
],
|
89 |
+
"size": [
|
90 |
+
400,
|
91 |
+
220
|
92 |
+
],
|
93 |
+
"flags": {},
|
94 |
+
"order": 1,
|
95 |
+
"mode": 0,
|
96 |
+
"inputs": [
|
97 |
+
{
|
98 |
+
"name": "model",
|
99 |
+
"type": "TANGOFLUX_MODEL",
|
100 |
+
"link": 11
|
101 |
+
}
|
102 |
+
],
|
103 |
+
"outputs": [
|
104 |
+
{
|
105 |
+
"name": "latents",
|
106 |
+
"type": "TANGOFLUX_LATENTS",
|
107 |
+
"links": [
|
108 |
+
14
|
109 |
+
],
|
110 |
+
"slot_index": 0
|
111 |
+
}
|
112 |
+
],
|
113 |
+
"properties": {
|
114 |
+
"Node name for S&R": "TangoFluxSampler"
|
115 |
+
},
|
116 |
+
"widgets_values": [
|
117 |
+
"A dog barking near the ocean, ocean waves crashing.",
|
118 |
+
50,
|
119 |
+
3,
|
120 |
+
10,
|
121 |
+
106139285587780,
|
122 |
+
"randomize",
|
123 |
+
1
|
124 |
+
]
|
125 |
+
}
|
126 |
+
],
|
127 |
+
"links": [
|
128 |
+
[
|
129 |
+
11,
|
130 |
+
10,
|
131 |
+
0,
|
132 |
+
11,
|
133 |
+
0,
|
134 |
+
"TANGOFLUX_MODEL"
|
135 |
+
],
|
136 |
+
[
|
137 |
+
14,
|
138 |
+
11,
|
139 |
+
0,
|
140 |
+
13,
|
141 |
+
1,
|
142 |
+
"TANGOFLUX_LATENTS"
|
143 |
+
],
|
144 |
+
[
|
145 |
+
15,
|
146 |
+
10,
|
147 |
+
1,
|
148 |
+
13,
|
149 |
+
0,
|
150 |
+
"TANGOFLUX_VAE"
|
151 |
+
]
|
152 |
+
],
|
153 |
+
"groups": [],
|
154 |
+
"config": {},
|
155 |
+
"extra": {
|
156 |
+
"ds": {
|
157 |
+
"scale": 0.9480295566502464,
|
158 |
+
"offset": [
|
159 |
+
-200.83333333333337,
|
160 |
+
-102.2460379319304
|
161 |
+
]
|
162 |
+
},
|
163 |
+
"node_versions": {
|
164 |
+
"comfyui-tangoflux": "1.0.4"
|
165 |
+
}
|
166 |
+
},
|
167 |
+
"version": 0.4
|
168 |
+
}
|
external_models/TangoFlux/comfyui/install.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
import subprocess
|
5 |
+
import traceback
|
6 |
+
import json
|
7 |
+
import re
|
8 |
+
|
9 |
+
log = logging.getLogger("TangoFlux")
|
10 |
+
|
11 |
+
download_models = True
|
12 |
+
|
13 |
+
EXT_PATH = os.path.dirname(os.path.abspath(__file__))
|
14 |
+
|
15 |
+
try:
|
16 |
+
folder_paths_path = os.path.abspath(os.path.join(EXT_PATH, "..", "..", "..", "folder_paths.py"))
|
17 |
+
|
18 |
+
sys.path.append(os.path.dirname(folder_paths_path))
|
19 |
+
|
20 |
+
import folder_paths
|
21 |
+
|
22 |
+
TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
|
23 |
+
TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
|
24 |
+
except:
|
25 |
+
download_models = False
|
26 |
+
|
27 |
+
try:
|
28 |
+
log.info("Installing requirements")
|
29 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", f"{EXT_PATH}/requirements.txt", "--no-warn-script-location"])
|
30 |
+
|
31 |
+
if download_models:
|
32 |
+
from huggingface_hub import snapshot_download
|
33 |
+
|
34 |
+
log.info("Downloading Necessary models")
|
35 |
+
|
36 |
+
try:
|
37 |
+
log.info(f"Downloading TangoFlux models to: {TANGOFLUX_DIR}")
|
38 |
+
snapshot_download(
|
39 |
+
repo_id="declare-lab/TangoFlux",
|
40 |
+
allow_patterns=["*.json", "*.safetensors"],
|
41 |
+
local_dir=TANGOFLUX_DIR,
|
42 |
+
local_dir_use_symlinks=False,
|
43 |
+
)
|
44 |
+
except Exception:
|
45 |
+
traceback.print_exc()
|
46 |
+
log.error("Failed to download TangoFlux models")
|
47 |
+
|
48 |
+
log.info("Loading config")
|
49 |
+
|
50 |
+
with open(os.path.join(TANGOFLUX_DIR, "config.json"), "r") as f:
|
51 |
+
config = json.load(f)
|
52 |
+
|
53 |
+
try:
|
54 |
+
text_encoder = re.sub(r'[<>:"/\\|?*]', '-', config.get("text_encoder_name", "google/flan-t5-large"))
|
55 |
+
text_encoder_path = os.path.join(TEXT_ENCODER_DIR, text_encoder)
|
56 |
+
|
57 |
+
log.info(f"Downloading text encoders to: {text_encoder_path}")
|
58 |
+
snapshot_download(
|
59 |
+
repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
|
60 |
+
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
61 |
+
local_dir=text_encoder_path,
|
62 |
+
local_dir_use_symlinks=False,
|
63 |
+
)
|
64 |
+
except Exception:
|
65 |
+
traceback.print_exc()
|
66 |
+
log.error("Failed to download text encoders")
|
67 |
+
|
68 |
+
try:
|
69 |
+
log.info("Installing TangoFlux module")
|
70 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", os.path.join(EXT_PATH, "..")])
|
71 |
+
except Exception:
|
72 |
+
traceback.print_exc()
|
73 |
+
log.error("Failed to install TangoFlux module")
|
74 |
+
|
75 |
+
log.info("TangoFlux Installation completed")
|
76 |
+
|
77 |
+
except Exception:
|
78 |
+
traceback.print_exc()
|
79 |
+
log.error("TangoFlux Installation failed")
|
external_models/TangoFlux/comfyui/nodes.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
import re
|
8 |
+
|
9 |
+
from diffusers import AutoencoderOobleck, FluxTransformer2DModel
|
10 |
+
from huggingface_hub import snapshot_download
|
11 |
+
|
12 |
+
from comfy.utils import load_torch_file, ProgressBar
|
13 |
+
import folder_paths
|
14 |
+
|
15 |
+
from tangoflux.model import TangoFlux
|
16 |
+
from .teacache import teacache_forward
|
17 |
+
|
18 |
+
log = logging.getLogger("TangoFlux")
|
19 |
+
|
20 |
+
TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
|
21 |
+
if "tangoflux" not in folder_paths.folder_names_and_paths:
|
22 |
+
current_paths = [TANGOFLUX_DIR]
|
23 |
+
else:
|
24 |
+
current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
|
25 |
+
folder_paths.folder_names_and_paths["tangoflux"] = (
|
26 |
+
current_paths,
|
27 |
+
folder_paths.supported_pt_extensions,
|
28 |
+
)
|
29 |
+
TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
|
30 |
+
|
31 |
+
|
32 |
+
class TangoFluxLoader:
|
33 |
+
@classmethod
|
34 |
+
def INPUT_TYPES(cls):
|
35 |
+
return {
|
36 |
+
"required": {
|
37 |
+
"enable_teacache": ("BOOLEAN", {"default": False}),
|
38 |
+
"rel_l1_thresh": (
|
39 |
+
"FLOAT",
|
40 |
+
{"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
|
41 |
+
),
|
42 |
+
},
|
43 |
+
}
|
44 |
+
|
45 |
+
RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
|
46 |
+
RETURN_NAMES = ("model", "vae")
|
47 |
+
OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux Vae")
|
48 |
+
|
49 |
+
CATEGORY = "TangoFlux"
|
50 |
+
FUNCTION = "load_tangoflux"
|
51 |
+
DESCRIPTION = "Load TangoFlux model"
|
52 |
+
|
53 |
+
def __init__(self):
|
54 |
+
self.model = None
|
55 |
+
self.vae = None
|
56 |
+
self.enable_teacache = False
|
57 |
+
self.rel_l1_thresh = 0.25
|
58 |
+
self.original_forward = FluxTransformer2DModel.forward
|
59 |
+
|
60 |
+
def load_tangoflux(
|
61 |
+
self,
|
62 |
+
enable_teacache=False,
|
63 |
+
rel_l1_thresh=0.25,
|
64 |
+
tangoflux_path=TANGOFLUX_DIR,
|
65 |
+
text_encoder_path=TEXT_ENCODER_DIR,
|
66 |
+
device="cuda",
|
67 |
+
):
|
68 |
+
if self.model is None or self.enable_teacache != enable_teacache:
|
69 |
+
|
70 |
+
pbar = ProgressBar(6)
|
71 |
+
|
72 |
+
snapshot_download(
|
73 |
+
repo_id="declare-lab/TangoFlux",
|
74 |
+
allow_patterns=["*.json", "*.safetensors"],
|
75 |
+
local_dir=tangoflux_path,
|
76 |
+
local_dir_use_symlinks=False,
|
77 |
+
)
|
78 |
+
|
79 |
+
pbar.update(1)
|
80 |
+
|
81 |
+
log.info("Loading config")
|
82 |
+
|
83 |
+
with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
|
84 |
+
config = json.load(f)
|
85 |
+
|
86 |
+
pbar.update(1)
|
87 |
+
|
88 |
+
text_encoder = re.sub(
|
89 |
+
r'[<>:"/\\|?*]',
|
90 |
+
"-",
|
91 |
+
config.get("text_encoder_name", "google/flan-t5-large"),
|
92 |
+
)
|
93 |
+
text_encoder_path = os.path.join(text_encoder_path, text_encoder)
|
94 |
+
|
95 |
+
snapshot_download(
|
96 |
+
repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
|
97 |
+
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
98 |
+
local_dir=text_encoder_path,
|
99 |
+
local_dir_use_symlinks=False,
|
100 |
+
)
|
101 |
+
|
102 |
+
pbar.update(1)
|
103 |
+
|
104 |
+
log.info("Loading TangoFlux models")
|
105 |
+
|
106 |
+
del self.model
|
107 |
+
self.model = None
|
108 |
+
|
109 |
+
model_weights = load_torch_file(
|
110 |
+
os.path.join(tangoflux_path, "tangoflux.safetensors"),
|
111 |
+
device=torch.device(device),
|
112 |
+
)
|
113 |
+
|
114 |
+
pbar.update(1)
|
115 |
+
|
116 |
+
if enable_teacache:
|
117 |
+
log.info("Enabling TeaCache")
|
118 |
+
FluxTransformer2DModel.forward = teacache_forward
|
119 |
+
else:
|
120 |
+
log.info("Disabling TeaCache")
|
121 |
+
FluxTransformer2DModel.forward = self.original_forward
|
122 |
+
|
123 |
+
model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
|
124 |
+
|
125 |
+
model.load_state_dict(model_weights, strict=False)
|
126 |
+
model.to(device)
|
127 |
+
|
128 |
+
if enable_teacache:
|
129 |
+
model.transformer.__class__.enable_teacache = True
|
130 |
+
model.transformer.__class__.cnt = 0
|
131 |
+
model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
|
132 |
+
model.transformer.__class__.accumulated_rel_l1_distance = 0
|
133 |
+
model.transformer.__class__.previous_modulated_input = None
|
134 |
+
model.transformer.__class__.previous_residual = None
|
135 |
+
|
136 |
+
pbar.update(1)
|
137 |
+
|
138 |
+
self.model = model
|
139 |
+
del model
|
140 |
+
self.enable_teacache = enable_teacache
|
141 |
+
self.rel_l1_thresh = rel_l1_thresh
|
142 |
+
|
143 |
+
if self.vae is None:
|
144 |
+
log.info("Loading TangoFlux VAE")
|
145 |
+
|
146 |
+
vae_weights = load_torch_file(
|
147 |
+
os.path.join(tangoflux_path, "vae.safetensors")
|
148 |
+
)
|
149 |
+
self.vae = AutoencoderOobleck()
|
150 |
+
self.vae.load_state_dict(vae_weights)
|
151 |
+
self.vae.to(device)
|
152 |
+
|
153 |
+
pbar.update(1)
|
154 |
+
|
155 |
+
if self.enable_teacache == True and self.rel_l1_thresh != rel_l1_thresh:
|
156 |
+
self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
|
157 |
+
|
158 |
+
self.rel_l1_thresh = rel_l1_thresh
|
159 |
+
|
160 |
+
return (self.model, self.vae)
|
161 |
+
|
162 |
+
|
163 |
+
class TangoFluxSampler:
|
164 |
+
@classmethod
|
165 |
+
def INPUT_TYPES(cls):
|
166 |
+
return {
|
167 |
+
"required": {
|
168 |
+
"model": ("TANGOFLUX_MODEL",),
|
169 |
+
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
170 |
+
"steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
|
171 |
+
"guidance_scale": (
|
172 |
+
"FLOAT",
|
173 |
+
{"default": 3, "min": 1, "max": 100, "step": 1},
|
174 |
+
),
|
175 |
+
"duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
|
176 |
+
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
|
177 |
+
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
178 |
+
},
|
179 |
+
}
|
180 |
+
|
181 |
+
RETURN_TYPES = ("TANGOFLUX_LATENTS",)
|
182 |
+
RETURN_NAMES = ("latents",)
|
183 |
+
OUTPUT_TOOLTIPS = "TangoFlux Sample"
|
184 |
+
|
185 |
+
CATEGORY = "TangoFlux"
|
186 |
+
FUNCTION = "sample"
|
187 |
+
DESCRIPTION = "Sampler for TangoFlux"
|
188 |
+
|
189 |
+
def sample(
|
190 |
+
self,
|
191 |
+
model,
|
192 |
+
prompt,
|
193 |
+
steps=50,
|
194 |
+
guidance_scale=3,
|
195 |
+
duration=10,
|
196 |
+
seed=0,
|
197 |
+
batch_size=1,
|
198 |
+
device="cuda",
|
199 |
+
):
|
200 |
+
pbar = ProgressBar(steps)
|
201 |
+
|
202 |
+
with torch.no_grad():
|
203 |
+
model.to(device)
|
204 |
+
|
205 |
+
try:
|
206 |
+
if model.transformer.__class__.enable_teacache:
|
207 |
+
model.transformer.__class__.num_steps = steps
|
208 |
+
except:
|
209 |
+
pass
|
210 |
+
|
211 |
+
log.info("Generating latents with TangoFlux")
|
212 |
+
|
213 |
+
latents = model.inference_flow(
|
214 |
+
prompt,
|
215 |
+
duration=duration,
|
216 |
+
num_inference_steps=steps,
|
217 |
+
guidance_scale=guidance_scale,
|
218 |
+
seed=seed,
|
219 |
+
num_samples_per_prompt=batch_size,
|
220 |
+
callback_on_step_end=lambda: pbar.update(1),
|
221 |
+
)
|
222 |
+
|
223 |
+
return ({"latents": latents, "duration": duration},)
|
224 |
+
|
225 |
+
|
226 |
+
class TangoFluxVAEDecodeAndPlay:
|
227 |
+
@classmethod
|
228 |
+
def INPUT_TYPES(cls):
|
229 |
+
return {
|
230 |
+
"required": {
|
231 |
+
"vae": ("TANGOFLUX_VAE",),
|
232 |
+
"latents": ("TANGOFLUX_LATENTS",),
|
233 |
+
"filename_prefix": ("STRING", {"default": "TangoFlux"}),
|
234 |
+
"format": (
|
235 |
+
["wav", "mp3", "flac", "aac", "wma"],
|
236 |
+
{"default": "wav"},
|
237 |
+
),
|
238 |
+
"save_output": ("BOOLEAN", {"default": True}),
|
239 |
+
},
|
240 |
+
}
|
241 |
+
|
242 |
+
RETURN_TYPES = ()
|
243 |
+
OUTPUT_NODE = True
|
244 |
+
|
245 |
+
CATEGORY = "TangoFlux"
|
246 |
+
FUNCTION = "play"
|
247 |
+
DESCRIPTION = "Decoder and Player for TangoFlux"
|
248 |
+
|
249 |
+
def decode(self, vae, latents):
|
250 |
+
results = []
|
251 |
+
|
252 |
+
for latent in latents:
|
253 |
+
decoded = vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
|
254 |
+
results.append(decoded)
|
255 |
+
|
256 |
+
results = torch.cat(results, dim=0)
|
257 |
+
|
258 |
+
return results
|
259 |
+
|
260 |
+
def play(
|
261 |
+
self,
|
262 |
+
vae,
|
263 |
+
latents,
|
264 |
+
filename_prefix="TangoFlux",
|
265 |
+
format="wav",
|
266 |
+
save_output=True,
|
267 |
+
device="cuda",
|
268 |
+
):
|
269 |
+
audios = []
|
270 |
+
pbar = ProgressBar(len(latents) + 2)
|
271 |
+
|
272 |
+
if save_output:
|
273 |
+
output_dir = folder_paths.get_output_directory()
|
274 |
+
prefix_append = ""
|
275 |
+
type = "output"
|
276 |
+
else:
|
277 |
+
output_dir = folder_paths.get_temp_directory()
|
278 |
+
prefix_append = "_temp_" + "".join(
|
279 |
+
random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)
|
280 |
+
)
|
281 |
+
type = "temp"
|
282 |
+
|
283 |
+
filename_prefix += prefix_append
|
284 |
+
full_output_folder, filename, counter, subfolder, _ = (
|
285 |
+
folder_paths.get_save_image_path(filename_prefix, output_dir)
|
286 |
+
)
|
287 |
+
|
288 |
+
os.makedirs(full_output_folder, exist_ok=True)
|
289 |
+
|
290 |
+
pbar.update(1)
|
291 |
+
|
292 |
+
duration = latents["duration"]
|
293 |
+
latents = latents["latents"]
|
294 |
+
|
295 |
+
vae.to(device)
|
296 |
+
|
297 |
+
log.info("Decoding Tangoflux latents")
|
298 |
+
|
299 |
+
waves = self.decode(vae, latents)
|
300 |
+
|
301 |
+
pbar.update(1)
|
302 |
+
|
303 |
+
for wave in waves:
|
304 |
+
waveform_end = int(duration * vae.config.sampling_rate)
|
305 |
+
wave = wave[:, :waveform_end]
|
306 |
+
|
307 |
+
file = f"{filename}_{counter:05}_.{format}"
|
308 |
+
|
309 |
+
torchaudio.save(
|
310 |
+
os.path.join(full_output_folder, file), wave, sample_rate=44100
|
311 |
+
)
|
312 |
+
|
313 |
+
counter += 1
|
314 |
+
|
315 |
+
audios.append({"filename": file, "subfolder": subfolder, "type": type})
|
316 |
+
|
317 |
+
pbar.update(1)
|
318 |
+
|
319 |
+
return {
|
320 |
+
"ui": {"audios": audios},
|
321 |
+
}
|
322 |
+
|
323 |
+
|
324 |
+
NODE_CLASS_MAPPINGS = {
|
325 |
+
"TangoFluxLoader": TangoFluxLoader,
|
326 |
+
"TangoFluxSampler": TangoFluxSampler,
|
327 |
+
"TangoFluxVAEDecodeAndPlay": TangoFluxVAEDecodeAndPlay,
|
328 |
+
}
|
external_models/TangoFlux/comfyui/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchaudio
|
2 |
+
torchlibrosa
|
3 |
+
torchvision
|
4 |
+
diffusers
|
5 |
+
accelerate
|
6 |
+
datasets
|
7 |
+
librosa
|
8 |
+
wandb
|
9 |
+
tqdm
|
external_models/TangoFlux/comfyui/server.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import server
|
3 |
+
import folder_paths
|
4 |
+
|
5 |
+
web = server.web
|
6 |
+
|
7 |
+
|
8 |
+
@server.PromptServer.instance.routes.get("/tangoflux/playaudio")
|
9 |
+
async def play_audio(request):
|
10 |
+
query = request.rel_url.query
|
11 |
+
|
12 |
+
filename = query.get("filename", None)
|
13 |
+
|
14 |
+
if filename is None:
|
15 |
+
return web.Response(status=404)
|
16 |
+
|
17 |
+
if filename[0] == "/" or ".." in filename:
|
18 |
+
return web.Response(status=403)
|
19 |
+
|
20 |
+
filename, output_dir = folder_paths.annotated_filepath(filename)
|
21 |
+
|
22 |
+
if not output_dir:
|
23 |
+
file_type = query.get("type", "output")
|
24 |
+
output_dir = folder_paths.get_directory_by_type(file_type)
|
25 |
+
|
26 |
+
if output_dir is None:
|
27 |
+
return web.Response(status=400)
|
28 |
+
|
29 |
+
subfolder = query.get("subfolder", None)
|
30 |
+
if subfolder:
|
31 |
+
full_output_dir = os.path.join(output_dir, subfolder)
|
32 |
+
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
33 |
+
return web.Response(status=403)
|
34 |
+
output_dir = full_output_dir
|
35 |
+
|
36 |
+
filename = os.path.basename(filename)
|
37 |
+
file_path = os.path.join(output_dir, filename)
|
38 |
+
|
39 |
+
if not os.path.isfile(file_path):
|
40 |
+
return web.Response(status=404)
|
41 |
+
|
42 |
+
_, ext = os.path.splitext(filename)
|
43 |
+
ext = ext.lower()
|
44 |
+
|
45 |
+
content_types = {
|
46 |
+
".wav": "audio/wav",
|
47 |
+
".mp3": "audio/mpeg",
|
48 |
+
".flac": "audio/flac",
|
49 |
+
".aac": "audio/aac",
|
50 |
+
".wma": "audio/x-ms-wma",
|
51 |
+
}
|
52 |
+
|
53 |
+
content_type = content_types.get(ext, None)
|
54 |
+
|
55 |
+
if content_type is None:
|
56 |
+
return web.Response(status=400)
|
57 |
+
|
58 |
+
try:
|
59 |
+
with open(file_path, "rb") as file:
|
60 |
+
data = file.read()
|
61 |
+
except:
|
62 |
+
return web.Response(status=500)
|
63 |
+
|
64 |
+
return web.Response(body=data, content_type=content_type)
|
external_models/TangoFlux/comfyui/teacache.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code from https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4TangoFlux/teacache_tango_flux.py
|
2 |
+
|
3 |
+
from typing import Any, Dict, Optional, Union
|
4 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
5 |
+
from diffusers.utils import (
|
6 |
+
USE_PEFT_BACKEND,
|
7 |
+
is_torch_version,
|
8 |
+
logging,
|
9 |
+
scale_lora_layers,
|
10 |
+
unscale_lora_layers,
|
11 |
+
)
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
17 |
+
|
18 |
+
|
19 |
+
def teacache_forward(
|
20 |
+
self,
|
21 |
+
hidden_states: torch.Tensor,
|
22 |
+
encoder_hidden_states: torch.Tensor = None,
|
23 |
+
pooled_projections: torch.Tensor = None,
|
24 |
+
timestep: torch.LongTensor = None,
|
25 |
+
img_ids: torch.Tensor = None,
|
26 |
+
txt_ids: torch.Tensor = None,
|
27 |
+
guidance: torch.Tensor = None,
|
28 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
29 |
+
return_dict: bool = True,
|
30 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
31 |
+
"""
|
32 |
+
The [`FluxTransformer2DModel`] forward method.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
36 |
+
Input `hidden_states`.
|
37 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
38 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
39 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
40 |
+
from the embeddings of input conditions.
|
41 |
+
timestep ( `torch.LongTensor`):
|
42 |
+
Used to indicate denoising step.
|
43 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
44 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
45 |
+
joint_attention_kwargs (`dict`, *optional*):
|
46 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
47 |
+
`self.processor` in
|
48 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
49 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
50 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
51 |
+
tuple.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
55 |
+
`tuple` where the first element is the sample tensor.
|
56 |
+
"""
|
57 |
+
if joint_attention_kwargs is not None:
|
58 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
59 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
60 |
+
else:
|
61 |
+
lora_scale = 1.0
|
62 |
+
|
63 |
+
if USE_PEFT_BACKEND:
|
64 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
65 |
+
scale_lora_layers(self, lora_scale)
|
66 |
+
else:
|
67 |
+
if (
|
68 |
+
joint_attention_kwargs is not None
|
69 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
70 |
+
):
|
71 |
+
logger.warning(
|
72 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
73 |
+
)
|
74 |
+
hidden_states = self.x_embedder(hidden_states)
|
75 |
+
|
76 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
77 |
+
if guidance is not None:
|
78 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
79 |
+
else:
|
80 |
+
guidance = None
|
81 |
+
temb = (
|
82 |
+
self.time_text_embed(timestep, pooled_projections)
|
83 |
+
if guidance is None
|
84 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
85 |
+
)
|
86 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
87 |
+
|
88 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
89 |
+
image_rotary_emb = self.pos_embed(ids)
|
90 |
+
|
91 |
+
if self.enable_teacache:
|
92 |
+
inp = hidden_states.clone()
|
93 |
+
temb_ = temb.clone()
|
94 |
+
modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
95 |
+
self.transformer_blocks[0].norm1(inp, emb=temb_)
|
96 |
+
)
|
97 |
+
if self.cnt == 0 or self.cnt == self.num_steps - 1:
|
98 |
+
should_calc = True
|
99 |
+
self.accumulated_rel_l1_distance = 0
|
100 |
+
else:
|
101 |
+
coefficients = [
|
102 |
+
4.98651651e02,
|
103 |
+
-2.83781631e02,
|
104 |
+
5.58554382e01,
|
105 |
+
-3.82021401e00,
|
106 |
+
2.64230861e-01,
|
107 |
+
]
|
108 |
+
rescale_func = np.poly1d(coefficients)
|
109 |
+
self.accumulated_rel_l1_distance += rescale_func(
|
110 |
+
(
|
111 |
+
(modulated_inp - self.previous_modulated_input).abs().mean()
|
112 |
+
/ self.previous_modulated_input.abs().mean()
|
113 |
+
)
|
114 |
+
.cpu()
|
115 |
+
.item()
|
116 |
+
)
|
117 |
+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
118 |
+
should_calc = False
|
119 |
+
else:
|
120 |
+
should_calc = True
|
121 |
+
self.accumulated_rel_l1_distance = 0
|
122 |
+
self.previous_modulated_input = modulated_inp
|
123 |
+
self.cnt += 1
|
124 |
+
if self.cnt == self.num_steps:
|
125 |
+
self.cnt = 0
|
126 |
+
|
127 |
+
if self.enable_teacache:
|
128 |
+
if not should_calc:
|
129 |
+
hidden_states += self.previous_residual
|
130 |
+
else:
|
131 |
+
ori_hidden_states = hidden_states.clone()
|
132 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
133 |
+
if self.training and self.gradient_checkpointing:
|
134 |
+
|
135 |
+
def create_custom_forward(module, return_dict=None):
|
136 |
+
def custom_forward(*inputs):
|
137 |
+
if return_dict is not None:
|
138 |
+
return module(*inputs, return_dict=return_dict)
|
139 |
+
else:
|
140 |
+
return module(*inputs)
|
141 |
+
|
142 |
+
return custom_forward
|
143 |
+
|
144 |
+
ckpt_kwargs: Dict[str, Any] = (
|
145 |
+
{"use_reentrant": False}
|
146 |
+
if is_torch_version(">=", "1.11.0")
|
147 |
+
else {}
|
148 |
+
)
|
149 |
+
encoder_hidden_states, hidden_states = (
|
150 |
+
torch.utils.checkpoint.checkpoint(
|
151 |
+
create_custom_forward(block),
|
152 |
+
hidden_states,
|
153 |
+
encoder_hidden_states,
|
154 |
+
temb,
|
155 |
+
image_rotary_emb,
|
156 |
+
**ckpt_kwargs,
|
157 |
+
)
|
158 |
+
)
|
159 |
+
|
160 |
+
else:
|
161 |
+
encoder_hidden_states, hidden_states = block(
|
162 |
+
hidden_states=hidden_states,
|
163 |
+
encoder_hidden_states=encoder_hidden_states,
|
164 |
+
temb=temb,
|
165 |
+
image_rotary_emb=image_rotary_emb,
|
166 |
+
)
|
167 |
+
|
168 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
169 |
+
|
170 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
171 |
+
if self.training and self.gradient_checkpointing:
|
172 |
+
|
173 |
+
def create_custom_forward(module, return_dict=None):
|
174 |
+
def custom_forward(*inputs):
|
175 |
+
if return_dict is not None:
|
176 |
+
return module(*inputs, return_dict=return_dict)
|
177 |
+
else:
|
178 |
+
return module(*inputs)
|
179 |
+
|
180 |
+
return custom_forward
|
181 |
+
|
182 |
+
ckpt_kwargs: Dict[str, Any] = (
|
183 |
+
{"use_reentrant": False}
|
184 |
+
if is_torch_version(">=", "1.11.0")
|
185 |
+
else {}
|
186 |
+
)
|
187 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
188 |
+
create_custom_forward(block),
|
189 |
+
hidden_states,
|
190 |
+
temb,
|
191 |
+
image_rotary_emb,
|
192 |
+
**ckpt_kwargs,
|
193 |
+
)
|
194 |
+
|
195 |
+
else:
|
196 |
+
hidden_states = block(
|
197 |
+
hidden_states=hidden_states,
|
198 |
+
temb=temb,
|
199 |
+
image_rotary_emb=image_rotary_emb,
|
200 |
+
)
|
201 |
+
|
202 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
203 |
+
self.previous_residual = hidden_states - ori_hidden_states
|
204 |
+
else:
|
205 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
206 |
+
if self.training and self.gradient_checkpointing:
|
207 |
+
|
208 |
+
def create_custom_forward(module, return_dict=None):
|
209 |
+
def custom_forward(*inputs):
|
210 |
+
if return_dict is not None:
|
211 |
+
return module(*inputs, return_dict=return_dict)
|
212 |
+
else:
|
213 |
+
return module(*inputs)
|
214 |
+
|
215 |
+
return custom_forward
|
216 |
+
|
217 |
+
ckpt_kwargs: Dict[str, Any] = (
|
218 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
219 |
+
)
|
220 |
+
encoder_hidden_states, hidden_states = (
|
221 |
+
torch.utils.checkpoint.checkpoint(
|
222 |
+
create_custom_forward(block),
|
223 |
+
hidden_states,
|
224 |
+
encoder_hidden_states,
|
225 |
+
temb,
|
226 |
+
image_rotary_emb,
|
227 |
+
**ckpt_kwargs,
|
228 |
+
)
|
229 |
+
)
|
230 |
+
|
231 |
+
else:
|
232 |
+
encoder_hidden_states, hidden_states = block(
|
233 |
+
hidden_states=hidden_states,
|
234 |
+
encoder_hidden_states=encoder_hidden_states,
|
235 |
+
temb=temb,
|
236 |
+
image_rotary_emb=image_rotary_emb,
|
237 |
+
)
|
238 |
+
|
239 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
240 |
+
|
241 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
242 |
+
if self.training and self.gradient_checkpointing:
|
243 |
+
|
244 |
+
def create_custom_forward(module, return_dict=None):
|
245 |
+
def custom_forward(*inputs):
|
246 |
+
if return_dict is not None:
|
247 |
+
return module(*inputs, return_dict=return_dict)
|
248 |
+
else:
|
249 |
+
return module(*inputs)
|
250 |
+
|
251 |
+
return custom_forward
|
252 |
+
|
253 |
+
ckpt_kwargs: Dict[str, Any] = (
|
254 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
255 |
+
)
|
256 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
257 |
+
create_custom_forward(block),
|
258 |
+
hidden_states,
|
259 |
+
temb,
|
260 |
+
image_rotary_emb,
|
261 |
+
**ckpt_kwargs,
|
262 |
+
)
|
263 |
+
|
264 |
+
else:
|
265 |
+
hidden_states = block(
|
266 |
+
hidden_states=hidden_states,
|
267 |
+
temb=temb,
|
268 |
+
image_rotary_emb=image_rotary_emb,
|
269 |
+
)
|
270 |
+
|
271 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
272 |
+
|
273 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
274 |
+
output = self.proj_out(hidden_states)
|
275 |
+
|
276 |
+
if USE_PEFT_BACKEND:
|
277 |
+
# remove `lora_scale` from each PEFT layer
|
278 |
+
unscale_lora_layers(self, lora_scale)
|
279 |
+
|
280 |
+
if not return_dict:
|
281 |
+
return (output,)
|
282 |
+
|
283 |
+
return Transformer2DModelOutput(sample=output)
|
external_models/TangoFlux/comfyui/web/js/playAudio.js
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { app } from "../../../scripts/app.js";
|
2 |
+
import { api } from "../../../scripts/api.js";
|
3 |
+
|
4 |
+
app.registerExtension({
|
5 |
+
name: "TangoFlux.playAudio",
|
6 |
+
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
7 |
+
if (nodeData.name === "TangoFluxVAEDecodeAndPlay") {
|
8 |
+
const originalNodeCreated = nodeType.prototype.onNodeCreated;
|
9 |
+
|
10 |
+
nodeType.prototype.onNodeCreated = async function () {
|
11 |
+
originalNodeCreated?.apply(this, arguments);
|
12 |
+
this.widgets_count = this.widgets?.length || 0;
|
13 |
+
|
14 |
+
this.addAudioWidgets = (audios) => {
|
15 |
+
if (this.widgets) {
|
16 |
+
for (let i = 0; i < this.widgets.length; i++) {
|
17 |
+
if (this.widgets[i].name.startsWith("_playaudio")) {
|
18 |
+
this.widgets[i].onRemove?.();
|
19 |
+
}
|
20 |
+
}
|
21 |
+
this.widgets.length = this.widgets_count;
|
22 |
+
}
|
23 |
+
|
24 |
+
let index = 0
|
25 |
+
for (const params of audios) {
|
26 |
+
const audioElement = document.createElement("audio");
|
27 |
+
audioElement.controls = true;
|
28 |
+
|
29 |
+
this.addDOMWidget("_playaudio" + index, "playaudio", audioElement, {
|
30 |
+
serialize: false,
|
31 |
+
hideOnZoom: false,
|
32 |
+
});
|
33 |
+
audioElement.src = api.apiURL(
|
34 |
+
`/tangoflux/playaudio?${new URLSearchParams(params)}`
|
35 |
+
);
|
36 |
+
index++
|
37 |
+
}
|
38 |
+
|
39 |
+
requestAnimationFrame(() => {
|
40 |
+
const newSize = this.computeSize();
|
41 |
+
newSize[0] = Math.max(newSize[0], this.size[0]);
|
42 |
+
newSize[1] = Math.max(newSize[1], this.size[1]);
|
43 |
+
this.onResize?.(newSize);
|
44 |
+
app.graph.setDirtyCanvas(true, false);
|
45 |
+
});
|
46 |
+
};
|
47 |
+
};
|
48 |
+
|
49 |
+
const originalNodeExecuted = nodeType.prototype.onExecuted;
|
50 |
+
|
51 |
+
nodeType.prototype.onExecuted = async function (message) {
|
52 |
+
originalNodeExecuted?.apply(this, arguments);
|
53 |
+
if (message?.audios) {
|
54 |
+
this.addAudioWidgets(message.audios);
|
55 |
+
}
|
56 |
+
};
|
57 |
+
}
|
58 |
+
},
|
59 |
+
});
|
external_models/TangoFlux/configs/__init__.py
ADDED
File without changes
|
external_models/TangoFlux/configs/accelerator_config.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"compute_environment": "LOCAL_MACHINE",
|
3 |
+
"distributed_type": "MULTI_GPU",
|
4 |
+
"main_process_port": 29512,
|
5 |
+
"downcast_bf16": false,
|
6 |
+
"machine_rank": 0,
|
7 |
+
"gpu_ids": "0,1",
|
8 |
+
"main_training_function": "main",
|
9 |
+
"mixed_precision": "no",
|
10 |
+
"num_machines": 1,
|
11 |
+
"num_processes": 2,
|
12 |
+
"rdzv_backend": "static",
|
13 |
+
"same_network": true,
|
14 |
+
"tpu_use_cluster": false,
|
15 |
+
"tpu_use_sudo": false,
|
16 |
+
"use_cpu": false
|
17 |
+
}
|
external_models/TangoFlux/configs/tangoflux_config.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Absolute paths for different resources
|
3 |
+
paths:
|
4 |
+
train_file: "data/train.json"
|
5 |
+
val_file: "data/val.json"
|
6 |
+
test_file: "data/val.json"
|
7 |
+
resume_from_checkpoint: ""
|
8 |
+
output_dir: "outputs/"
|
9 |
+
|
10 |
+
# Training-related parameters
|
11 |
+
training:
|
12 |
+
per_device_batch_size: 4
|
13 |
+
learning_rate: 5e-4
|
14 |
+
gradient_accumulation_steps: 1
|
15 |
+
num_train_epochs: 80
|
16 |
+
num_warmup_steps: 1000
|
17 |
+
max_audio_duration: 30
|
18 |
+
|
19 |
+
|
20 |
+
# Model and optimizer parameters,
|
21 |
+
model:
|
22 |
+
num_layers: 6
|
23 |
+
num_single_layers: 18
|
24 |
+
in_channels: 64
|
25 |
+
attention_head_dim: 128
|
26 |
+
joint_attention_dim: 1024
|
27 |
+
num_attention_heads: 8
|
28 |
+
audio_seq_len: 645
|
29 |
+
max_duration: 30
|
30 |
+
uncondition: false
|
31 |
+
text_encoder_name: "google/flan-t5-large"
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
external_models/TangoFlux/crpo.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
python3 tangoflux/generate_crpo.py --json_path='path_to_prompt_bank.json' --sample_size=50 --model='path_to_tangoflux.safetensors' --num_samples=5 --output_dir='outputs'
|
2 |
+
python3 tangoflux/label_crpo.py --json_path='outputs/results.json' --output_dir='outputs/crpo_iteration1' --num_samples=5
|
external_models/TangoFlux/inference.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
from tangoflux import TangoFluxInference
|
3 |
+
|
4 |
+
model = TangoFluxInference(name="declare-lab/TangoFlux")
|
5 |
+
audio = model.generate("Hammer slowly hitting the wooden table", steps=50, duration=10)
|
6 |
+
|
7 |
+
torchaudio.save("output.wav", audio, sample_rate=44100)
|
external_models/TangoFlux/replicate_demo/cog.yaml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for Cog ⚙️
|
2 |
+
# Reference: https://cog.run/yaml
|
3 |
+
|
4 |
+
build:
|
5 |
+
# set to true if your model requires a GPU
|
6 |
+
gpu: true
|
7 |
+
|
8 |
+
# a list of ubuntu apt packages to install
|
9 |
+
system_packages:
|
10 |
+
- "libgl1-mesa-glx"
|
11 |
+
- "libglib2.0-0"
|
12 |
+
|
13 |
+
# python version in the form '3.11' or '3.11.4'
|
14 |
+
python_version: "3.11"
|
15 |
+
|
16 |
+
# a list of packages in the format <package-name>==<version>
|
17 |
+
python_packages:
|
18 |
+
- torch==2.4.0
|
19 |
+
- torchaudio==2.4.0
|
20 |
+
- torchlibrosa==0.1.0
|
21 |
+
- torchvision==0.19.0
|
22 |
+
- transformers==4.44.0
|
23 |
+
- diffusers==0.30.0
|
24 |
+
- accelerate==0.34.2
|
25 |
+
- datasets==2.21.0
|
26 |
+
- librosa
|
27 |
+
- ipython
|
28 |
+
|
29 |
+
run:
|
30 |
+
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
|
31 |
+
predict: "predict.py:Predictor"
|
external_models/TangoFlux/replicate_demo/predict.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prediction interface for Cog ⚙️
|
2 |
+
# https://cog.run/python
|
3 |
+
|
4 |
+
import os
|
5 |
+
import subprocess
|
6 |
+
import time
|
7 |
+
import json
|
8 |
+
from cog import BasePredictor, Input, Path
|
9 |
+
from diffusers import AutoencoderOobleck
|
10 |
+
import soundfile as sf
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
from huggingface_hub import snapshot_download
|
13 |
+
from tangoflux.model import TangoFlux
|
14 |
+
from tangoflux import TangoFluxInference
|
15 |
+
|
16 |
+
MODEL_CACHE = "model_cache"
|
17 |
+
MODEL_URL = (
|
18 |
+
"https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class CachedTangoFluxInference(TangoFluxInference):
|
23 |
+
## load the weights from replicate.delivery for faster booting
|
24 |
+
def __init__(self, name="declare-lab/TangoFlux", device="cuda", cached_paths=None):
|
25 |
+
if cached_paths:
|
26 |
+
paths = cached_paths
|
27 |
+
else:
|
28 |
+
paths = snapshot_download(repo_id=name)
|
29 |
+
|
30 |
+
self.vae = AutoencoderOobleck()
|
31 |
+
vae_weights = load_file(f"{paths}/vae.safetensors")
|
32 |
+
self.vae.load_state_dict(vae_weights)
|
33 |
+
weights = load_file(f"{paths}/tangoflux.safetensors")
|
34 |
+
|
35 |
+
with open(f"{paths}/config.json", "r") as f:
|
36 |
+
config = json.load(f)
|
37 |
+
self.model = TangoFlux(config)
|
38 |
+
self.model.load_state_dict(weights, strict=False)
|
39 |
+
self.vae.to(device)
|
40 |
+
self.model.to(device)
|
41 |
+
|
42 |
+
|
43 |
+
def download_weights(url, dest):
|
44 |
+
start = time.time()
|
45 |
+
print("downloading url: ", url)
|
46 |
+
print("downloading to: ", dest)
|
47 |
+
subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
|
48 |
+
print("downloading took: ", time.time() - start)
|
49 |
+
|
50 |
+
|
51 |
+
class Predictor(BasePredictor):
|
52 |
+
def setup(self) -> None:
|
53 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
54 |
+
|
55 |
+
if not os.path.exists(MODEL_CACHE):
|
56 |
+
print("downloading")
|
57 |
+
download_weights(MODEL_URL, MODEL_CACHE)
|
58 |
+
|
59 |
+
self.model = CachedTangoFluxInference(
|
60 |
+
cached_paths=f"{MODEL_CACHE}/declare-lab/TangoFlux"
|
61 |
+
)
|
62 |
+
|
63 |
+
def predict(
|
64 |
+
self,
|
65 |
+
prompt: str = Input(
|
66 |
+
description="Input prompt", default="Hammer slowly hitting the wooden table"
|
67 |
+
),
|
68 |
+
duration: int = Input(
|
69 |
+
description="Duration of the output audio in seconds", default=10
|
70 |
+
),
|
71 |
+
steps: int = Input(
|
72 |
+
description="Number of inference steps", ge=1, le=200, default=25
|
73 |
+
),
|
74 |
+
guidance_scale: float = Input(
|
75 |
+
description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
|
76 |
+
),
|
77 |
+
) -> Path:
|
78 |
+
"""Run a single prediction on the model"""
|
79 |
+
|
80 |
+
audio = self.model.generate(
|
81 |
+
prompt,
|
82 |
+
steps=steps,
|
83 |
+
guidance_scale=guidance_scale,
|
84 |
+
duration=duration,
|
85 |
+
)
|
86 |
+
audio_numpy = audio.numpy()
|
87 |
+
out_path = "/tmp/out.wav"
|
88 |
+
|
89 |
+
sf.write(
|
90 |
+
out_path, audio_numpy.T, samplerate=self.model.vae.config.sampling_rate
|
91 |
+
)
|
92 |
+
return Path(out_path)
|
external_models/TangoFlux/requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.4.0
|
2 |
+
torchaudio==2.4.0
|
3 |
+
torchlibrosa==0.1.0
|
4 |
+
torchvision==0.19.0
|
5 |
+
transformers==4.44.0
|
6 |
+
diffusers==0.30.0
|
7 |
+
accelerate==0.34.2
|
8 |
+
datasets==2.21.0
|
9 |
+
librosa
|
10 |
+
tqdm
|
11 |
+
wandb
|
12 |
+
|
external_models/TangoFlux/setup.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name="tangoflux",
|
5 |
+
description="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching",
|
6 |
+
version="0.1.0",
|
7 |
+
packages=["tangoflux"],
|
8 |
+
install_requires=[
|
9 |
+
"torch==2.4.0",
|
10 |
+
"torchaudio==2.4.0",
|
11 |
+
"torchlibrosa==0.1.0",
|
12 |
+
"torchvision==0.19.0",
|
13 |
+
"transformers==4.44.0",
|
14 |
+
"diffusers==0.30.0",
|
15 |
+
"accelerate==0.34.2",
|
16 |
+
"datasets==2.21.0",
|
17 |
+
"librosa",
|
18 |
+
"tqdm",
|
19 |
+
"wandb",
|
20 |
+
"click",
|
21 |
+
"gradio",
|
22 |
+
"torchaudio",
|
23 |
+
],
|
24 |
+
entry_points={
|
25 |
+
"console_scripts": [
|
26 |
+
"tangoflux=tangoflux.cli:main",
|
27 |
+
"tangoflux-demo=tangoflux.demo:main",
|
28 |
+
],
|
29 |
+
},
|
30 |
+
)
|
external_models/TangoFlux/tangoflux/__init__.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import AutoencoderOobleck
|
2 |
+
import torch
|
3 |
+
from transformers import T5EncoderModel, T5TokenizerFast
|
4 |
+
from diffusers import FluxTransformer2DModel
|
5 |
+
from torch import nn
|
6 |
+
from typing import List
|
7 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
8 |
+
from diffusers.training_utils import compute_density_for_timestep_sampling
|
9 |
+
import copy
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import numpy as np
|
12 |
+
from tangoflux.model import TangoFlux
|
13 |
+
from huggingface_hub import snapshot_download
|
14 |
+
from tqdm import tqdm
|
15 |
+
from typing import Optional, Union, List
|
16 |
+
from datasets import load_dataset, Audio
|
17 |
+
from math import pi
|
18 |
+
import json
|
19 |
+
import inspect
|
20 |
+
import yaml
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
|
24 |
+
class TangoFluxInference:
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
name="declare-lab/TangoFlux",
|
29 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
30 |
+
):
|
31 |
+
|
32 |
+
self.vae = AutoencoderOobleck()
|
33 |
+
|
34 |
+
paths = snapshot_download(repo_id=name)
|
35 |
+
vae_weights = load_file("{}/vae.safetensors".format(paths))
|
36 |
+
self.vae.load_state_dict(vae_weights)
|
37 |
+
weights = load_file("{}/tangoflux.safetensors".format(paths))
|
38 |
+
|
39 |
+
with open("{}/config.json".format(paths), "r") as f:
|
40 |
+
config = json.load(f)
|
41 |
+
self.model = TangoFlux(config)
|
42 |
+
self.model.load_state_dict(weights, strict=False)
|
43 |
+
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
|
44 |
+
self.vae.to(device)
|
45 |
+
self.model.to(device)
|
46 |
+
|
47 |
+
def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
latents = self.model.inference_flow(
|
51 |
+
prompt,
|
52 |
+
duration=duration,
|
53 |
+
num_inference_steps=steps,
|
54 |
+
guidance_scale=guidance_scale,
|
55 |
+
)
|
56 |
+
|
57 |
+
wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
|
58 |
+
waveform_end = int(duration * self.vae.config.sampling_rate)
|
59 |
+
wave = wave[:, :waveform_end]
|
60 |
+
return wave
|
external_models/TangoFlux/tangoflux/cli.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import click
|
2 |
+
import torchaudio
|
3 |
+
from tangoflux import TangoFluxInference
|
4 |
+
|
5 |
+
@click.command()
|
6 |
+
@click.argument('prompt')
|
7 |
+
@click.argument('output_file')
|
8 |
+
@click.option('--duration', default=10, type=int, help='Duration in seconds (1-30)')
|
9 |
+
@click.option('--steps', default=50, type=int, help='Number of inference steps (10-100)')
|
10 |
+
def main(prompt: str, output_file: str, duration: int, steps: int):
|
11 |
+
"""Generate audio from text using TangoFlux.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
prompt: Text description of the audio to generate
|
15 |
+
output_file: Path to save the generated audio file
|
16 |
+
duration: Duration of generated audio in seconds (default: 10)
|
17 |
+
steps: Number of inference steps (default: 50)
|
18 |
+
"""
|
19 |
+
if not 1 <= duration <= 30:
|
20 |
+
raise click.BadParameter('Duration must be between 1 and 30 seconds')
|
21 |
+
if not 10 <= steps <= 100:
|
22 |
+
raise click.BadParameter('Steps must be between 10 and 100')
|
23 |
+
|
24 |
+
model = TangoFluxInference(name="declare-lab/TangoFlux")
|
25 |
+
audio = model.generate(prompt, steps=steps, duration=duration)
|
26 |
+
torchaudio.save(output_file, audio, sample_rate=44100)
|
27 |
+
|
28 |
+
if __name__ == '__main__':
|
29 |
+
main()
|
external_models/TangoFlux/tangoflux/demo.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torchaudio
|
3 |
+
import click
|
4 |
+
import tempfile
|
5 |
+
from tangoflux import TangoFluxInference
|
6 |
+
|
7 |
+
model = TangoFluxInference(name="declare-lab/TangoFlux")
|
8 |
+
|
9 |
+
|
10 |
+
def generate_audio(prompt, duration, steps):
|
11 |
+
audio = model.generate(prompt, steps=steps, duration=duration)
|
12 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
13 |
+
torchaudio.save(f.name, audio, sample_rate=44100)
|
14 |
+
return f.name
|
15 |
+
|
16 |
+
|
17 |
+
examples = [
|
18 |
+
["Hammer slowly hitting the wooden table", 10, 50],
|
19 |
+
["Gentle rain falling on a tin roof", 15, 50],
|
20 |
+
["Wind chimes tinkling in a light breeze", 10, 50],
|
21 |
+
["Rhythmic wooden table tapping overlaid with steady water pouring sound", 10, 50],
|
22 |
+
]
|
23 |
+
|
24 |
+
with gr.Blocks(title="TangoFlux Text-to-Audio Generation") as demo:
|
25 |
+
gr.Markdown("# TangoFlux Text-to-Audio Generation")
|
26 |
+
gr.Markdown("Generate audio from text descriptions using TangoFlux")
|
27 |
+
|
28 |
+
with gr.Row():
|
29 |
+
with gr.Column():
|
30 |
+
prompt = gr.Textbox(
|
31 |
+
label="Text Prompt", placeholder="Enter your audio description..."
|
32 |
+
)
|
33 |
+
duration = gr.Slider(
|
34 |
+
minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)"
|
35 |
+
)
|
36 |
+
steps = gr.Slider(
|
37 |
+
minimum=10, maximum=100, value=50, step=10, label="Number of Steps"
|
38 |
+
)
|
39 |
+
generate_btn = gr.Button("Generate Audio")
|
40 |
+
|
41 |
+
with gr.Column():
|
42 |
+
audio_output = gr.Audio(label="Generated Audio")
|
43 |
+
|
44 |
+
generate_btn.click(
|
45 |
+
fn=generate_audio, inputs=[prompt, duration, steps], outputs=audio_output
|
46 |
+
)
|
47 |
+
|
48 |
+
gr.Examples(
|
49 |
+
examples=examples,
|
50 |
+
inputs=[prompt, duration, steps],
|
51 |
+
outputs=audio_output,
|
52 |
+
fn=generate_audio,
|
53 |
+
)
|
54 |
+
|
55 |
+
@click.command()
|
56 |
+
@click.option('--host', default='127.0.0.1', help='Host to bind to')
|
57 |
+
@click.option('--port', default=None, help='Port to bind to')
|
58 |
+
@click.option('--share', is_flag=True, help='Enable sharing via Gradio')
|
59 |
+
def main(host, port, share):
|
60 |
+
demo.queue().launch(server_name=host, server_port=port, share=share)
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
main()
|
external_models/TangoFlux/tangoflux/generate_crpo_dataset.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import multiprocessing
|
7 |
+
from tqdm import tqdm
|
8 |
+
from safetensors.torch import load_file
|
9 |
+
from diffusers import AutoencoderOobleck
|
10 |
+
import soundfile as sf
|
11 |
+
from model import TangoFlux
|
12 |
+
import random
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def generate_audio_chunk(args, chunk, gpu_id, output_dir, samplerate, return_dict, process_id):
|
18 |
+
"""
|
19 |
+
Function to generate audio for a chunk of text prompts on a specific GPU.
|
20 |
+
"""
|
21 |
+
try:
|
22 |
+
device = f"cuda:{gpu_id}"
|
23 |
+
torch.cuda.set_device(device)
|
24 |
+
print(f"Process {process_id}: Using device {device}")
|
25 |
+
|
26 |
+
# Initialize model
|
27 |
+
config = {
|
28 |
+
'num_layers': 6,
|
29 |
+
'num_single_layers': 18,
|
30 |
+
'in_channels': 64,
|
31 |
+
'attention_head_dim': 128,
|
32 |
+
'joint_attention_dim': 1024,
|
33 |
+
'num_attention_heads': 8,
|
34 |
+
'audio_seq_len': 645,
|
35 |
+
'max_duration': 30,
|
36 |
+
'uncondition': False,
|
37 |
+
'text_encoder_name': "google/flan-t5-large"
|
38 |
+
}
|
39 |
+
|
40 |
+
model = TangoFlux(config)
|
41 |
+
print(f"Process {process_id}: Loading model from {args.model} on {device}")
|
42 |
+
w1 = load_file(args.model)
|
43 |
+
model.load_state_dict(w1, strict=False)
|
44 |
+
model = model.to(device)
|
45 |
+
model.eval()
|
46 |
+
|
47 |
+
# Initialize VAE
|
48 |
+
vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0", subfolder='vae')
|
49 |
+
vae = vae.to(device)
|
50 |
+
vae.eval()
|
51 |
+
|
52 |
+
outputs = []
|
53 |
+
|
54 |
+
# Corrected loop using enumerate properly with tqdm
|
55 |
+
for idx, item in tqdm(enumerate(chunk), total=len(chunk), desc=f"GPU {gpu_id}"):
|
56 |
+
text = item['captions']
|
57 |
+
|
58 |
+
|
59 |
+
if os.path.exists(os.path.join(output_dir, f"id_{item['id']}_sample1.wav")):
|
60 |
+
print("Exist! Skipping!")
|
61 |
+
continue
|
62 |
+
with torch.no_grad():
|
63 |
+
latent = model.inference_flow(
|
64 |
+
text,
|
65 |
+
num_inference_steps=args.num_steps,
|
66 |
+
guidance_scale=args.guidance_scale,
|
67 |
+
duration=10,
|
68 |
+
num_samples_per_prompt=args.num_samples
|
69 |
+
)
|
70 |
+
|
71 |
+
#waveform_end = int(duration * vae.config.sampling_rate)
|
72 |
+
latent = latent[:, :220, :] ## 220 correspond to the latent length of audiocaps encoded with this vae. You can modify this
|
73 |
+
wave = vae.decode(latent.transpose(2, 1)).sample.cpu()
|
74 |
+
|
75 |
+
for i in range(args.num_samples):
|
76 |
+
filename = f"id_{item['id']}_sample{i+1}.wav"
|
77 |
+
filepath = os.path.join(output_dir, filename)
|
78 |
+
|
79 |
+
sf.write(filepath, wave[i].T, samplerate)
|
80 |
+
outputs.append({
|
81 |
+
"id": item['id'],
|
82 |
+
"sample": i + 1,
|
83 |
+
"path": filepath,
|
84 |
+
"captions": text
|
85 |
+
})
|
86 |
+
|
87 |
+
return_dict[process_id] = outputs
|
88 |
+
print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
|
89 |
+
|
90 |
+
except Exception as e:
|
91 |
+
print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
|
92 |
+
return_dict[process_id] = []
|
93 |
+
|
94 |
+
def split_into_chunks(data, num_chunks):
|
95 |
+
"""
|
96 |
+
Splits data into num_chunks approximately equal parts.
|
97 |
+
"""
|
98 |
+
avg = len(data) // num_chunks
|
99 |
+
chunks = []
|
100 |
+
for i in range(num_chunks):
|
101 |
+
start = i * avg
|
102 |
+
# Ensure the last chunk takes the remainder
|
103 |
+
end = (i + 1) * avg if i != num_chunks - 1 else len(data)
|
104 |
+
chunks.append(data[start:end])
|
105 |
+
return chunks
|
106 |
+
|
107 |
+
def main():
|
108 |
+
parser = argparse.ArgumentParser(description="Generate audio using multiple GPUs")
|
109 |
+
parser.add_argument('--num_steps', type=int, default=50, help='Number of inference steps')
|
110 |
+
parser.add_argument('--model', type=str, required=True, help='Path to tangoflux weights')
|
111 |
+
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples per prompt')
|
112 |
+
parser.add_argument('--output_dir', type=str, default='output', help='Directory to save outputs')
|
113 |
+
parser.add_argument('--json_path', type=str, required=True, help='Path to input JSON file')
|
114 |
+
parser.add_argument('--sample_size', type=int, default=20000, help='Number of prompts to sample for CRPO')
|
115 |
+
parser.add_argument('--guidance_scale', type=float, default=4.5, help='Guidance scale used for generation')
|
116 |
+
args = parser.parse_args()
|
117 |
+
|
118 |
+
# Check GPU availability
|
119 |
+
num_gpus = torch.cuda.device_count()
|
120 |
+
sample_size = args.sample_size
|
121 |
+
|
122 |
+
|
123 |
+
# Load JSON data
|
124 |
+
import json
|
125 |
+
try:
|
126 |
+
with open(args.json_path, 'r') as f:
|
127 |
+
data = json.load(f)
|
128 |
+
|
129 |
+
except Exception as e:
|
130 |
+
print(f"Error loading JSON file {args.json_path}: {e}")
|
131 |
+
return
|
132 |
+
|
133 |
+
if not isinstance(data, list):
|
134 |
+
print("Error: JSON data is not a list.")
|
135 |
+
return
|
136 |
+
|
137 |
+
if len(data) < sample_size:
|
138 |
+
print(f"Warning: JSON data contains only {len(data)} items. Sampling all available data.")
|
139 |
+
sampled = data
|
140 |
+
else:
|
141 |
+
sampled = random.sample(data, sample_size)
|
142 |
+
|
143 |
+
# Split data into chunks based on available GPUs
|
144 |
+
random.shuffle(sampled)
|
145 |
+
chunks = split_into_chunks(sampled, num_gpus)
|
146 |
+
|
147 |
+
# Prepare output directory
|
148 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
149 |
+
samplerate = 44100
|
150 |
+
|
151 |
+
# Manager for inter-process communication
|
152 |
+
manager = multiprocessing.Manager()
|
153 |
+
return_dict = manager.dict()
|
154 |
+
|
155 |
+
processes = []
|
156 |
+
for i in range(num_gpus):
|
157 |
+
p = multiprocessing.Process(
|
158 |
+
target=generate_audio_chunk,
|
159 |
+
args=(
|
160 |
+
args,
|
161 |
+
chunks[i],
|
162 |
+
i, # GPU ID
|
163 |
+
args.output_dir,
|
164 |
+
samplerate,
|
165 |
+
return_dict,
|
166 |
+
i, # Process ID
|
167 |
+
|
168 |
+
)
|
169 |
+
)
|
170 |
+
processes.append(p)
|
171 |
+
p.start()
|
172 |
+
print(f"Started process {i} on GPU {i}")
|
173 |
+
|
174 |
+
for p in processes:
|
175 |
+
p.join()
|
176 |
+
print(f"Process {p.pid} has finished.")
|
177 |
+
|
178 |
+
# Aggregate results
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
audio_info_list = [
|
186 |
+
[{
|
187 |
+
"path": f"{args.output_dir}/id_{sampled[j]['id']}_sample{i}.wav",
|
188 |
+
"duration": sampled[j]["duration"],
|
189 |
+
"captions": sampled[j]["captions"]
|
190 |
+
}
|
191 |
+
for i in range(1, args.num_samples+1) ] for j in range(sample_size)
|
192 |
+
]
|
193 |
+
|
194 |
+
#print(audio_info_list)
|
195 |
+
|
196 |
+
with open(f'{args.output_dir}/results.json','w') as f:
|
197 |
+
json.dump(audio_info_list,f)
|
198 |
+
|
199 |
+
print(f"All audio samples have been generated and saved to {args.output_dir}")
|
200 |
+
|
201 |
+
|
202 |
+
if __name__ == "__main__":
|
203 |
+
multiprocessing.set_start_method('spawn')
|
204 |
+
main()
|
external_models/TangoFlux/tangoflux/label_crpo.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import torch
|
5 |
+
import laion_clap
|
6 |
+
import numpy as np
|
7 |
+
import multiprocessing
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
def parse_args():
|
11 |
+
parser = argparse.ArgumentParser(
|
12 |
+
description="Labelling clap score for crpo dataset"
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--num_samples", type=int, default=5,
|
16 |
+
help="Number of audio samples per prompt"
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--json_path", type=str, required=True,
|
20 |
+
help="Path to input JSON file"
|
21 |
+
)
|
22 |
+
parser.add_argument(
|
23 |
+
"--output_dir", type=str, required=True,
|
24 |
+
help="Directory to save the final JSON with CLAP scores"
|
25 |
+
)
|
26 |
+
return parser.parse_args()
|
27 |
+
|
28 |
+
#python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
|
29 |
+
@torch.no_grad()
|
30 |
+
def compute_clap(model, audio_files, text_data):
|
31 |
+
# Compute audio and text embeddings, then compute the dot product (CLAP score)
|
32 |
+
audio_embed = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
|
33 |
+
text_embed = model.get_text_embedding(text_data, use_tensor=True)
|
34 |
+
return audio_embed @ text_embed.T
|
35 |
+
|
36 |
+
def process_chunk(args, chunk, gpu_id, return_dict, process_id):
|
37 |
+
"""
|
38 |
+
Process a chunk of the data on a specific GPU.
|
39 |
+
Loads the CLAP model on the designated device, then for each item in the chunk,
|
40 |
+
computes the CLAP scores and attaches them to the data.
|
41 |
+
"""
|
42 |
+
try:
|
43 |
+
device = f"cuda:{gpu_id}"
|
44 |
+
torch.cuda.set_device(device)
|
45 |
+
print(f"Process {process_id}: Using device {device}")
|
46 |
+
|
47 |
+
# Initialize the CLAP model on this GPU
|
48 |
+
model = laion_clap.CLAP_Module(enable_fusion=False)
|
49 |
+
model.to(device)
|
50 |
+
model.load_ckpt()
|
51 |
+
model.eval()
|
52 |
+
|
53 |
+
for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
|
54 |
+
# Each item is assumed to be a list of samples.
|
55 |
+
# Skip if already computed.
|
56 |
+
if 'clap_score' in item[0]:
|
57 |
+
continue
|
58 |
+
|
59 |
+
# Collect audio file paths and text data (using the first caption)
|
60 |
+
audio_files = [item[i]['path'] for i in range(args.num_samples)]
|
61 |
+
text_data = [item[0]['captions']]
|
62 |
+
|
63 |
+
try:
|
64 |
+
clap_scores = compute_clap(model, audio_files, text_data)
|
65 |
+
except Exception as e:
|
66 |
+
print(f"Error processing item index {j} on GPU {gpu_id}: {e}")
|
67 |
+
continue
|
68 |
+
|
69 |
+
# Attach the computed score to each sample in the item
|
70 |
+
for k in range(args.num_samples):
|
71 |
+
item[k]['clap_score'] = np.round(clap_scores[k].item(), 3)
|
72 |
+
|
73 |
+
return_dict[process_id] = chunk
|
74 |
+
print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
|
75 |
+
except Exception as e:
|
76 |
+
print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
|
77 |
+
return_dict[process_id] = []
|
78 |
+
|
79 |
+
def split_into_chunks(data, num_chunks):
|
80 |
+
"""
|
81 |
+
Splits data into num_chunks approximately equal parts.
|
82 |
+
"""
|
83 |
+
avg = len(data) // num_chunks
|
84 |
+
chunks = []
|
85 |
+
for i in range(num_chunks):
|
86 |
+
start = i * avg
|
87 |
+
# Ensure the last chunk takes the remainder of the data
|
88 |
+
end = (i + 1) * avg if i != num_chunks - 1 else len(data)
|
89 |
+
chunks.append(data[start:end])
|
90 |
+
return chunks
|
91 |
+
|
92 |
+
def main():
|
93 |
+
args = parse_args()
|
94 |
+
|
95 |
+
# Load data from JSON and slice by start/end if provided
|
96 |
+
with open(args.json_path, 'r') as f:
|
97 |
+
data = json.load(f)
|
98 |
+
|
99 |
+
# Check GPU availability and split data accordingly
|
100 |
+
num_gpus = torch.cuda.device_count()
|
101 |
+
|
102 |
+
print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
|
103 |
+
chunks = split_into_chunks(data, num_gpus)
|
104 |
+
|
105 |
+
# Prepare output directory
|
106 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
107 |
+
|
108 |
+
# Create a manager dict to collect results from all processes
|
109 |
+
manager = multiprocessing.Manager()
|
110 |
+
return_dict = manager.dict()
|
111 |
+
processes = []
|
112 |
+
|
113 |
+
for i in range(num_gpus):
|
114 |
+
p = multiprocessing.Process(
|
115 |
+
target=process_chunk,
|
116 |
+
args=(args, chunks[i], i, return_dict, i)
|
117 |
+
)
|
118 |
+
processes.append(p)
|
119 |
+
p.start()
|
120 |
+
print(f"Started process {i} on GPU {i}")
|
121 |
+
|
122 |
+
for p in processes:
|
123 |
+
p.join()
|
124 |
+
print(f"Process {p.pid} has finished.")
|
125 |
+
|
126 |
+
# Aggregate all chunks back into a single list
|
127 |
+
combined_data = []
|
128 |
+
for i in range(num_gpus):
|
129 |
+
combined_data.extend(return_dict[i])
|
130 |
+
|
131 |
+
# Save the combined results to a single JSON file
|
132 |
+
output_file = f"{args.output_dir}/clap_scores.json"
|
133 |
+
with open(output_file, 'w') as f:
|
134 |
+
json.dump(combined_data, f)
|
135 |
+
print(f"All CLAP scores have been computed and saved to {output_file}")
|
136 |
+
|
137 |
+
max_item = [max(x, key=lambda item: item['clap_score']) for x in combined_data]
|
138 |
+
min_item = [min(x, key=lambda item: item['clap_score']) for x in combined_data]
|
139 |
+
|
140 |
+
crpo_dataset = []
|
141 |
+
for chosen,reject in zip(max_item,min_item):
|
142 |
+
crpo_dataset.append({"captions": chosen['captions'],
|
143 |
+
"duration": chosen['duration'],
|
144 |
+
"chosen": chosen['path'],
|
145 |
+
"reject": reject['path']})
|
146 |
+
|
147 |
+
with open(f"{args.output_dir}/train.json",'w') as f:
|
148 |
+
json.dump(crpo_dataset,f)
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
multiprocessing.set_start_method('spawn')
|
153 |
+
main()
|
external_models/TangoFlux/tangoflux/model.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import T5EncoderModel, T5TokenizerFast
|
2 |
+
import torch
|
3 |
+
from diffusers import FluxTransformer2DModel
|
4 |
+
from torch import nn
|
5 |
+
import random
|
6 |
+
from typing import List
|
7 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
8 |
+
from diffusers.training_utils import compute_density_for_timestep_sampling
|
9 |
+
import copy
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from typing import Optional, Union, List
|
15 |
+
from datasets import load_dataset, Audio
|
16 |
+
from math import pi
|
17 |
+
import inspect
|
18 |
+
import yaml
|
19 |
+
|
20 |
+
|
21 |
+
class StableAudioPositionalEmbedding(nn.Module):
|
22 |
+
"""Used for continuous time
|
23 |
+
Adapted from Stable Audio Open.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, dim: int):
|
27 |
+
super().__init__()
|
28 |
+
assert (dim % 2) == 0
|
29 |
+
half_dim = dim // 2
|
30 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
31 |
+
|
32 |
+
def forward(self, times: torch.Tensor) -> torch.Tensor:
|
33 |
+
times = times[..., None]
|
34 |
+
freqs = times * self.weights[None] * 2 * pi
|
35 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
36 |
+
fouriered = torch.cat((times, fouriered), dim=-1)
|
37 |
+
return fouriered
|
38 |
+
|
39 |
+
|
40 |
+
class DurationEmbedder(nn.Module):
|
41 |
+
"""
|
42 |
+
A simple linear projection model to map numbers to a latent space.
|
43 |
+
|
44 |
+
Code is adapted from
|
45 |
+
https://github.com/Stability-AI/stable-audio-tools
|
46 |
+
|
47 |
+
Args:
|
48 |
+
number_embedding_dim (`int`):
|
49 |
+
Dimensionality of the number embeddings.
|
50 |
+
min_value (`int`):
|
51 |
+
The minimum value of the seconds number conditioning modules.
|
52 |
+
max_value (`int`):
|
53 |
+
The maximum value of the seconds number conditioning modules
|
54 |
+
internal_dim (`int`):
|
55 |
+
Dimensionality of the intermediate number hidden states.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
number_embedding_dim,
|
61 |
+
min_value,
|
62 |
+
max_value,
|
63 |
+
internal_dim: Optional[int] = 256,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
self.time_positional_embedding = nn.Sequential(
|
67 |
+
StableAudioPositionalEmbedding(internal_dim),
|
68 |
+
nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
|
69 |
+
)
|
70 |
+
|
71 |
+
self.number_embedding_dim = number_embedding_dim
|
72 |
+
self.min_value = min_value
|
73 |
+
self.max_value = max_value
|
74 |
+
self.dtype = torch.float32
|
75 |
+
|
76 |
+
def forward(
|
77 |
+
self,
|
78 |
+
floats: torch.Tensor,
|
79 |
+
):
|
80 |
+
floats = floats.clamp(self.min_value, self.max_value)
|
81 |
+
|
82 |
+
normalized_floats = (floats - self.min_value) / (
|
83 |
+
self.max_value - self.min_value
|
84 |
+
)
|
85 |
+
|
86 |
+
# Cast floats to same type as embedder
|
87 |
+
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
|
88 |
+
normalized_floats = normalized_floats.to(embedder_dtype)
|
89 |
+
|
90 |
+
embedding = self.time_positional_embedding(normalized_floats)
|
91 |
+
float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
|
92 |
+
|
93 |
+
return float_embeds
|
94 |
+
|
95 |
+
|
96 |
+
def retrieve_timesteps(
|
97 |
+
scheduler,
|
98 |
+
num_inference_steps: Optional[int] = None,
|
99 |
+
device: Optional[Union[str, torch.device]] = None,
|
100 |
+
timesteps: Optional[List[int]] = None,
|
101 |
+
sigmas: Optional[List[float]] = None,
|
102 |
+
**kwargs,
|
103 |
+
):
|
104 |
+
|
105 |
+
if timesteps is not None and sigmas is not None:
|
106 |
+
raise ValueError(
|
107 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
108 |
+
)
|
109 |
+
if timesteps is not None:
|
110 |
+
accepts_timesteps = "timesteps" in set(
|
111 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
112 |
+
)
|
113 |
+
if not accepts_timesteps:
|
114 |
+
raise ValueError(
|
115 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
116 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
117 |
+
)
|
118 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
119 |
+
timesteps = scheduler.timesteps
|
120 |
+
num_inference_steps = len(timesteps)
|
121 |
+
elif sigmas is not None:
|
122 |
+
accept_sigmas = "sigmas" in set(
|
123 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
124 |
+
)
|
125 |
+
if not accept_sigmas:
|
126 |
+
raise ValueError(
|
127 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
128 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
129 |
+
)
|
130 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
131 |
+
timesteps = scheduler.timesteps
|
132 |
+
num_inference_steps = len(timesteps)
|
133 |
+
else:
|
134 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
135 |
+
timesteps = scheduler.timesteps
|
136 |
+
return timesteps, num_inference_steps
|
137 |
+
|
138 |
+
|
139 |
+
class TangoFlux(nn.Module):
|
140 |
+
|
141 |
+
def __init__(self, config, text_encoder_dir=None, initialize_reference_model=False,):
|
142 |
+
|
143 |
+
super().__init__()
|
144 |
+
|
145 |
+
self.num_layers = config.get("num_layers", 6)
|
146 |
+
self.num_single_layers = config.get("num_single_layers", 18)
|
147 |
+
self.in_channels = config.get("in_channels", 64)
|
148 |
+
self.attention_head_dim = config.get("attention_head_dim", 128)
|
149 |
+
self.joint_attention_dim = config.get("joint_attention_dim", 1024)
|
150 |
+
self.num_attention_heads = config.get("num_attention_heads", 8)
|
151 |
+
self.audio_seq_len = config.get("audio_seq_len", 645)
|
152 |
+
self.max_duration = config.get("max_duration", 30)
|
153 |
+
self.uncondition = config.get("uncondition", False)
|
154 |
+
self.text_encoder_name = config.get("text_encoder_name", "google/flan-t5-large")
|
155 |
+
|
156 |
+
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
157 |
+
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
|
158 |
+
self.max_text_seq_len = 64
|
159 |
+
self.text_encoder = T5EncoderModel.from_pretrained(
|
160 |
+
text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
|
161 |
+
)
|
162 |
+
self.tokenizer = T5TokenizerFast.from_pretrained(
|
163 |
+
text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
|
164 |
+
)
|
165 |
+
self.text_embedding_dim = self.text_encoder.config.d_model
|
166 |
+
|
167 |
+
self.fc = nn.Sequential(
|
168 |
+
nn.Linear(self.text_embedding_dim, self.joint_attention_dim), nn.ReLU()
|
169 |
+
)
|
170 |
+
self.duration_emebdder = DurationEmbedder(
|
171 |
+
self.text_embedding_dim, min_value=0, max_value=self.max_duration
|
172 |
+
)
|
173 |
+
|
174 |
+
self.transformer = FluxTransformer2DModel(
|
175 |
+
in_channels=self.in_channels,
|
176 |
+
num_layers=self.num_layers,
|
177 |
+
num_single_layers=self.num_single_layers,
|
178 |
+
attention_head_dim=self.attention_head_dim,
|
179 |
+
num_attention_heads=self.num_attention_heads,
|
180 |
+
joint_attention_dim=self.joint_attention_dim,
|
181 |
+
pooled_projection_dim=self.text_embedding_dim,
|
182 |
+
guidance_embeds=False,
|
183 |
+
)
|
184 |
+
|
185 |
+
self.beta_dpo = 2000 ## this is used for dpo training
|
186 |
+
|
187 |
+
def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
|
188 |
+
device = self.text_encoder.device
|
189 |
+
sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
|
190 |
+
|
191 |
+
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
|
192 |
+
timesteps = timesteps.to(device)
|
193 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
194 |
+
|
195 |
+
sigma = sigmas[step_indices].flatten()
|
196 |
+
while len(sigma.shape) < n_dim:
|
197 |
+
sigma = sigma.unsqueeze(-1)
|
198 |
+
return sigma
|
199 |
+
|
200 |
+
def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
|
201 |
+
device = self.text_encoder.device
|
202 |
+
batch = self.tokenizer(
|
203 |
+
prompt,
|
204 |
+
max_length=self.tokenizer.model_max_length,
|
205 |
+
padding=True,
|
206 |
+
truncation=True,
|
207 |
+
return_tensors="pt",
|
208 |
+
)
|
209 |
+
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
|
210 |
+
device
|
211 |
+
)
|
212 |
+
|
213 |
+
with torch.no_grad():
|
214 |
+
prompt_embeds = self.text_encoder(
|
215 |
+
input_ids=input_ids, attention_mask=attention_mask
|
216 |
+
)[0]
|
217 |
+
|
218 |
+
prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
|
219 |
+
attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
|
220 |
+
|
221 |
+
# get unconditional embeddings for classifier free guidance
|
222 |
+
uncond_tokens = [""]
|
223 |
+
|
224 |
+
max_length = prompt_embeds.shape[1]
|
225 |
+
uncond_batch = self.tokenizer(
|
226 |
+
uncond_tokens,
|
227 |
+
max_length=max_length,
|
228 |
+
padding="max_length",
|
229 |
+
truncation=True,
|
230 |
+
return_tensors="pt",
|
231 |
+
)
|
232 |
+
uncond_input_ids = uncond_batch.input_ids.to(device)
|
233 |
+
uncond_attention_mask = uncond_batch.attention_mask.to(device)
|
234 |
+
|
235 |
+
with torch.no_grad():
|
236 |
+
negative_prompt_embeds = self.text_encoder(
|
237 |
+
input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
|
238 |
+
)[0]
|
239 |
+
|
240 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
|
241 |
+
num_samples_per_prompt, 0
|
242 |
+
)
|
243 |
+
uncond_attention_mask = uncond_attention_mask.repeat_interleave(
|
244 |
+
num_samples_per_prompt, 0
|
245 |
+
)
|
246 |
+
|
247 |
+
# For classifier free guidance, we need to do two forward passes.
|
248 |
+
# We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
|
249 |
+
|
250 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
251 |
+
prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
|
252 |
+
boolean_prompt_mask = (prompt_mask == 1).to(device)
|
253 |
+
|
254 |
+
return prompt_embeds, boolean_prompt_mask
|
255 |
+
|
256 |
+
@torch.no_grad()
|
257 |
+
def encode_text(self, prompt):
|
258 |
+
device = self.text_encoder.device
|
259 |
+
batch = self.tokenizer(
|
260 |
+
prompt,
|
261 |
+
max_length=self.max_text_seq_len,
|
262 |
+
padding=True,
|
263 |
+
truncation=True,
|
264 |
+
return_tensors="pt",
|
265 |
+
)
|
266 |
+
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
|
267 |
+
device
|
268 |
+
)
|
269 |
+
|
270 |
+
encoder_hidden_states = self.text_encoder(
|
271 |
+
input_ids=input_ids, attention_mask=attention_mask
|
272 |
+
)[0]
|
273 |
+
|
274 |
+
boolean_encoder_mask = (attention_mask == 1).to(device)
|
275 |
+
|
276 |
+
return encoder_hidden_states, boolean_encoder_mask
|
277 |
+
|
278 |
+
def encode_duration(self, duration):
|
279 |
+
return self.duration_emebdder(duration)
|
280 |
+
|
281 |
+
@torch.no_grad()
|
282 |
+
def inference_flow(
|
283 |
+
self,
|
284 |
+
prompt,
|
285 |
+
num_inference_steps=50,
|
286 |
+
timesteps=None,
|
287 |
+
guidance_scale=3,
|
288 |
+
duration=10,
|
289 |
+
seed=0,
|
290 |
+
disable_progress=False,
|
291 |
+
num_samples_per_prompt=1,
|
292 |
+
callback_on_step_end=None,
|
293 |
+
):
|
294 |
+
"""Only tested for single inference. Haven't test for batch inference"""
|
295 |
+
|
296 |
+
torch.manual_seed(seed)
|
297 |
+
if torch.cuda.is_available():
|
298 |
+
torch.cuda.manual_seed(seed)
|
299 |
+
torch.cuda.manual_seed_all(seed)
|
300 |
+
torch.backends.cudnn.deterministic = True
|
301 |
+
|
302 |
+
bsz = num_samples_per_prompt
|
303 |
+
device = self.transformer.device
|
304 |
+
scheduler = self.noise_scheduler
|
305 |
+
|
306 |
+
if not isinstance(prompt, list):
|
307 |
+
prompt = [prompt]
|
308 |
+
if not isinstance(duration, torch.Tensor):
|
309 |
+
duration = torch.tensor([duration], device=device)
|
310 |
+
classifier_free_guidance = guidance_scale > 1.0
|
311 |
+
duration_hidden_states = self.encode_duration(duration)
|
312 |
+
if classifier_free_guidance:
|
313 |
+
bsz = 2 * num_samples_per_prompt
|
314 |
+
|
315 |
+
encoder_hidden_states, boolean_encoder_mask = (
|
316 |
+
self.encode_text_classifier_free(
|
317 |
+
prompt, num_samples_per_prompt=num_samples_per_prompt
|
318 |
+
)
|
319 |
+
)
|
320 |
+
duration_hidden_states = duration_hidden_states.repeat(bsz, 1, 1)
|
321 |
+
|
322 |
+
else:
|
323 |
+
|
324 |
+
encoder_hidden_states, boolean_encoder_mask = self.encode_text(
|
325 |
+
prompt, num_samples_per_prompt=num_samples_per_prompt
|
326 |
+
)
|
327 |
+
|
328 |
+
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
|
329 |
+
encoder_hidden_states
|
330 |
+
)
|
331 |
+
masked_data = torch.where(
|
332 |
+
mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
|
333 |
+
)
|
334 |
+
|
335 |
+
pooled = torch.nanmean(masked_data, dim=1)
|
336 |
+
pooled_projection = self.fc(pooled)
|
337 |
+
|
338 |
+
encoder_hidden_states = torch.cat(
|
339 |
+
[encoder_hidden_states, duration_hidden_states], dim=1
|
340 |
+
) ## (bs,seq_len,dim)
|
341 |
+
|
342 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
343 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
344 |
+
scheduler, num_inference_steps, device, timesteps, sigmas
|
345 |
+
)
|
346 |
+
|
347 |
+
latents = torch.randn(num_samples_per_prompt, self.audio_seq_len, 64)
|
348 |
+
weight_dtype = latents.dtype
|
349 |
+
|
350 |
+
progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
|
351 |
+
|
352 |
+
txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
|
353 |
+
audio_ids = (
|
354 |
+
torch.arange(self.audio_seq_len)
|
355 |
+
.unsqueeze(0)
|
356 |
+
.unsqueeze(-1)
|
357 |
+
.repeat(bsz, 1, 3)
|
358 |
+
.to(device)
|
359 |
+
)
|
360 |
+
|
361 |
+
timesteps = timesteps.to(device)
|
362 |
+
latents = latents.to(device)
|
363 |
+
encoder_hidden_states = encoder_hidden_states.to(device)
|
364 |
+
|
365 |
+
for i, t in enumerate(timesteps):
|
366 |
+
|
367 |
+
latents_input = (
|
368 |
+
torch.cat([latents] * 2) if classifier_free_guidance else latents
|
369 |
+
)
|
370 |
+
|
371 |
+
noise_pred = self.transformer(
|
372 |
+
hidden_states=latents_input,
|
373 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
374 |
+
timestep=torch.tensor([t / 1000], device=device),
|
375 |
+
guidance=None,
|
376 |
+
pooled_projections=pooled_projection,
|
377 |
+
encoder_hidden_states=encoder_hidden_states,
|
378 |
+
txt_ids=txt_ids,
|
379 |
+
img_ids=audio_ids,
|
380 |
+
return_dict=False,
|
381 |
+
)[0]
|
382 |
+
|
383 |
+
if classifier_free_guidance:
|
384 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
385 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
386 |
+
noise_pred_text - noise_pred_uncond
|
387 |
+
)
|
388 |
+
|
389 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
390 |
+
|
391 |
+
progress_bar.update(1)
|
392 |
+
|
393 |
+
if callback_on_step_end is not None:
|
394 |
+
callback_on_step_end()
|
395 |
+
|
396 |
+
return latents
|
397 |
+
|
398 |
+
def forward(self, latents, prompt, duration=torch.tensor([10]), sft=True):
|
399 |
+
|
400 |
+
device = latents.device
|
401 |
+
audio_seq_length = self.audio_seq_len
|
402 |
+
bsz = latents.shape[0]
|
403 |
+
|
404 |
+
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
|
405 |
+
duration_hidden_states = self.encode_duration(duration)
|
406 |
+
|
407 |
+
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
|
408 |
+
encoder_hidden_states
|
409 |
+
)
|
410 |
+
masked_data = torch.where(
|
411 |
+
mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
|
412 |
+
)
|
413 |
+
pooled = torch.nanmean(masked_data, dim=1)
|
414 |
+
pooled_projection = self.fc(pooled)
|
415 |
+
|
416 |
+
## Add duration hidden states to encoder hidden states
|
417 |
+
encoder_hidden_states = torch.cat(
|
418 |
+
[encoder_hidden_states, duration_hidden_states], dim=1
|
419 |
+
) ## (bs,seq_len,dim)
|
420 |
+
|
421 |
+
txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
|
422 |
+
audio_ids = (
|
423 |
+
torch.arange(audio_seq_length)
|
424 |
+
.unsqueeze(0)
|
425 |
+
.unsqueeze(-1)
|
426 |
+
.repeat(bsz, 1, 3)
|
427 |
+
.to(device)
|
428 |
+
)
|
429 |
+
|
430 |
+
if sft:
|
431 |
+
|
432 |
+
if self.uncondition:
|
433 |
+
mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
|
434 |
+
if len(mask_indices) > 0:
|
435 |
+
encoder_hidden_states[mask_indices] = 0
|
436 |
+
|
437 |
+
noise = torch.randn_like(latents)
|
438 |
+
|
439 |
+
u = compute_density_for_timestep_sampling(
|
440 |
+
weighting_scheme="logit_normal",
|
441 |
+
batch_size=bsz,
|
442 |
+
logit_mean=0,
|
443 |
+
logit_std=1,
|
444 |
+
mode_scale=None,
|
445 |
+
)
|
446 |
+
|
447 |
+
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
|
448 |
+
timesteps = self.noise_scheduler_copy.timesteps[indices].to(
|
449 |
+
device=latents.device
|
450 |
+
)
|
451 |
+
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
|
452 |
+
|
453 |
+
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
454 |
+
|
455 |
+
model_pred = self.transformer(
|
456 |
+
hidden_states=noisy_model_input,
|
457 |
+
encoder_hidden_states=encoder_hidden_states,
|
458 |
+
pooled_projections=pooled_projection,
|
459 |
+
img_ids=audio_ids,
|
460 |
+
txt_ids=txt_ids,
|
461 |
+
guidance=None,
|
462 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
463 |
+
timestep=timesteps / 1000,
|
464 |
+
return_dict=False,
|
465 |
+
)[0]
|
466 |
+
|
467 |
+
target = noise - latents
|
468 |
+
loss = torch.mean(
|
469 |
+
((model_pred.float() - target.float()) ** 2).reshape(
|
470 |
+
target.shape[0], -1
|
471 |
+
),
|
472 |
+
1,
|
473 |
+
)
|
474 |
+
loss = loss.mean()
|
475 |
+
raw_model_loss, raw_ref_loss, implicit_acc = (
|
476 |
+
0,
|
477 |
+
0,
|
478 |
+
0,
|
479 |
+
) ## default this to 0 if doing sft
|
480 |
+
|
481 |
+
else:
|
482 |
+
encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
|
483 |
+
pooled_projection = pooled_projection.repeat(2, 1)
|
484 |
+
noise = (
|
485 |
+
torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)
|
486 |
+
) ## Have to sample same noise for preferred and rejected
|
487 |
+
u = compute_density_for_timestep_sampling(
|
488 |
+
weighting_scheme="logit_normal",
|
489 |
+
batch_size=bsz // 2,
|
490 |
+
logit_mean=0,
|
491 |
+
logit_std=1,
|
492 |
+
mode_scale=None,
|
493 |
+
)
|
494 |
+
|
495 |
+
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
|
496 |
+
timesteps = self.noise_scheduler_copy.timesteps[indices].to(
|
497 |
+
device=latents.device
|
498 |
+
)
|
499 |
+
timesteps = timesteps.repeat(2)
|
500 |
+
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
|
501 |
+
|
502 |
+
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
503 |
+
|
504 |
+
model_pred = self.transformer(
|
505 |
+
hidden_states=noisy_model_input,
|
506 |
+
encoder_hidden_states=encoder_hidden_states,
|
507 |
+
pooled_projections=pooled_projection,
|
508 |
+
img_ids=audio_ids,
|
509 |
+
txt_ids=txt_ids,
|
510 |
+
guidance=None,
|
511 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
512 |
+
timestep=timesteps / 1000,
|
513 |
+
return_dict=False,
|
514 |
+
)[0]
|
515 |
+
target = noise - latents
|
516 |
+
|
517 |
+
model_losses = F.mse_loss(
|
518 |
+
model_pred.float(), target.float(), reduction="none"
|
519 |
+
)
|
520 |
+
model_losses = model_losses.mean(
|
521 |
+
dim=list(range(1, len(model_losses.shape)))
|
522 |
+
)
|
523 |
+
model_losses_w, model_losses_l = model_losses.chunk(2)
|
524 |
+
model_diff = model_losses_w - model_losses_l
|
525 |
+
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
|
526 |
+
|
527 |
+
with torch.no_grad():
|
528 |
+
ref_preds = self.ref_transformer(
|
529 |
+
hidden_states=noisy_model_input,
|
530 |
+
encoder_hidden_states=encoder_hidden_states,
|
531 |
+
pooled_projections=pooled_projection,
|
532 |
+
img_ids=audio_ids,
|
533 |
+
txt_ids=txt_ids,
|
534 |
+
guidance=None,
|
535 |
+
timestep=timesteps / 1000,
|
536 |
+
return_dict=False,
|
537 |
+
)[0]
|
538 |
+
|
539 |
+
ref_loss = F.mse_loss(
|
540 |
+
ref_preds.float(), target.float(), reduction="none"
|
541 |
+
)
|
542 |
+
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
543 |
+
|
544 |
+
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
545 |
+
ref_diff = ref_losses_w - ref_losses_l
|
546 |
+
raw_ref_loss = ref_loss.mean()
|
547 |
+
|
548 |
+
scale_term = -0.5 * self.beta_dpo
|
549 |
+
inside_term = scale_term * (model_diff - ref_diff)
|
550 |
+
implicit_acc = (
|
551 |
+
scale_term * (model_diff - ref_diff) > 0
|
552 |
+
).sum().float() / inside_term.size(0)
|
553 |
+
loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
|
554 |
+
|
555 |
+
## raw_model_loss, raw_ref_loss, implicit_acc is used to help to analyze dpo behaviour.
|
556 |
+
return loss, raw_model_loss, raw_ref_loss, implicit_acc
|
external_models/TangoFlux/tangoflux/train.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import yaml
|
8 |
+
from pathlib import Path
|
9 |
+
import diffusers
|
10 |
+
import datasets
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
import wandb
|
14 |
+
import transformers
|
15 |
+
import torch
|
16 |
+
from accelerate import Accelerator
|
17 |
+
from accelerate.logging import get_logger
|
18 |
+
from accelerate.utils import set_seed
|
19 |
+
from datasets import load_dataset
|
20 |
+
from torch.utils.data import Dataset, DataLoader
|
21 |
+
from tqdm.auto import tqdm
|
22 |
+
from transformers import SchedulerType, get_scheduler
|
23 |
+
from model import TangoFlux
|
24 |
+
from datasets import load_dataset, Audio
|
25 |
+
from utils import Text2AudioDataset, read_wav_file, pad_wav
|
26 |
+
|
27 |
+
from diffusers import AutoencoderOobleck
|
28 |
+
import torchaudio
|
29 |
+
|
30 |
+
logger = get_logger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
def parse_args():
|
34 |
+
parser = argparse.ArgumentParser(
|
35 |
+
description="Rectified flow for text to audio generation task."
|
36 |
+
)
|
37 |
+
|
38 |
+
parser.add_argument(
|
39 |
+
"--num_examples",
|
40 |
+
type=int,
|
41 |
+
default=-1,
|
42 |
+
help="How many examples to use for training and validation.",
|
43 |
+
)
|
44 |
+
|
45 |
+
parser.add_argument(
|
46 |
+
"--text_column",
|
47 |
+
type=str,
|
48 |
+
default="captions",
|
49 |
+
help="The name of the column in the datasets containing the input texts.",
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--audio_column",
|
53 |
+
type=str,
|
54 |
+
default="location",
|
55 |
+
help="The name of the column in the datasets containing the audio paths.",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--adam_beta1",
|
59 |
+
type=float,
|
60 |
+
default=0.9,
|
61 |
+
help="The beta1 parameter for the Adam optimizer.",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--adam_beta2",
|
65 |
+
type=float,
|
66 |
+
default=0.95,
|
67 |
+
help="The beta2 parameter for the Adam optimizer.",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--config",
|
71 |
+
type=str,
|
72 |
+
default="tangoflux_config.yaml",
|
73 |
+
help="Config file defining the model size as well as other hyper parameter.",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--prefix",
|
77 |
+
type=str,
|
78 |
+
default="",
|
79 |
+
help="Add prefix in text prompts.",
|
80 |
+
)
|
81 |
+
|
82 |
+
parser.add_argument(
|
83 |
+
"--learning_rate",
|
84 |
+
type=float,
|
85 |
+
default=3e-5,
|
86 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--weight_decay", type=float, default=1e-8, help="Weight decay to use."
|
90 |
+
)
|
91 |
+
|
92 |
+
parser.add_argument(
|
93 |
+
"--max_train_steps",
|
94 |
+
type=int,
|
95 |
+
default=None,
|
96 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
97 |
+
)
|
98 |
+
|
99 |
+
parser.add_argument(
|
100 |
+
"--lr_scheduler_type",
|
101 |
+
type=SchedulerType,
|
102 |
+
default="linear",
|
103 |
+
help="The scheduler type to use.",
|
104 |
+
choices=[
|
105 |
+
"linear",
|
106 |
+
"cosine",
|
107 |
+
"cosine_with_restarts",
|
108 |
+
"polynomial",
|
109 |
+
"constant",
|
110 |
+
"constant_with_warmup",
|
111 |
+
],
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--num_warmup_steps",
|
115 |
+
type=int,
|
116 |
+
default=0,
|
117 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--adam_epsilon",
|
121 |
+
type=float,
|
122 |
+
default=1e-08,
|
123 |
+
help="Epsilon value for the Adam optimizer",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--adam_weight_decay",
|
127 |
+
type=float,
|
128 |
+
default=1e-2,
|
129 |
+
help="Epsilon value for the Adam optimizer",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--checkpointing_steps",
|
136 |
+
type=str,
|
137 |
+
default="best",
|
138 |
+
help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--save_every",
|
142 |
+
type=int,
|
143 |
+
default=5,
|
144 |
+
help="Save model after every how many epochs when checkpointing_steps is set to best.",
|
145 |
+
)
|
146 |
+
|
147 |
+
parser.add_argument(
|
148 |
+
"--resume_from_checkpoint",
|
149 |
+
type=str,
|
150 |
+
default=None,
|
151 |
+
help="If the training should continue from a local checkpoint folder.",
|
152 |
+
)
|
153 |
+
|
154 |
+
parser.add_argument(
|
155 |
+
"--load_from_checkpoint",
|
156 |
+
type=str,
|
157 |
+
default=None,
|
158 |
+
help="Whether to continue training from a model weight",
|
159 |
+
)
|
160 |
+
|
161 |
+
args = parser.parse_args()
|
162 |
+
|
163 |
+
return args
|
164 |
+
|
165 |
+
|
166 |
+
def main():
|
167 |
+
args = parse_args()
|
168 |
+
accelerator_log_kwargs = {}
|
169 |
+
|
170 |
+
def load_config(config_path):
|
171 |
+
with open(config_path, "r") as file:
|
172 |
+
return yaml.safe_load(file)
|
173 |
+
|
174 |
+
config = load_config(args.config)
|
175 |
+
|
176 |
+
learning_rate = float(config["training"]["learning_rate"])
|
177 |
+
num_train_epochs = int(config["training"]["num_train_epochs"])
|
178 |
+
num_warmup_steps = int(config["training"]["num_warmup_steps"])
|
179 |
+
per_device_batch_size = int(config["training"]["per_device_batch_size"])
|
180 |
+
gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
|
181 |
+
|
182 |
+
output_dir = config["paths"]["output_dir"]
|
183 |
+
|
184 |
+
accelerator = Accelerator(
|
185 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
186 |
+
**accelerator_log_kwargs,
|
187 |
+
)
|
188 |
+
|
189 |
+
# Make one log on every process with the configuration for debugging.
|
190 |
+
logging.basicConfig(
|
191 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
192 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
193 |
+
level=logging.INFO,
|
194 |
+
)
|
195 |
+
logger.info(accelerator.state, main_process_only=False)
|
196 |
+
|
197 |
+
datasets.utils.logging.set_verbosity_error()
|
198 |
+
diffusers.utils.logging.set_verbosity_error()
|
199 |
+
transformers.utils.logging.set_verbosity_error()
|
200 |
+
|
201 |
+
# If passed along, set the training seed now.
|
202 |
+
if args.seed is not None:
|
203 |
+
set_seed(args.seed)
|
204 |
+
|
205 |
+
# Handle output directory creation and wandb tracking
|
206 |
+
if accelerator.is_main_process:
|
207 |
+
if output_dir is None or output_dir == "":
|
208 |
+
output_dir = "saved/" + str(int(time.time()))
|
209 |
+
|
210 |
+
if not os.path.exists("saved"):
|
211 |
+
os.makedirs("saved")
|
212 |
+
|
213 |
+
os.makedirs(output_dir, exist_ok=True)
|
214 |
+
|
215 |
+
elif output_dir is not None:
|
216 |
+
os.makedirs(output_dir, exist_ok=True)
|
217 |
+
|
218 |
+
os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
|
219 |
+
with open("{}/summary.jsonl".format(output_dir), "a") as f:
|
220 |
+
f.write(json.dumps(dict(vars(args))) + "\n\n")
|
221 |
+
|
222 |
+
accelerator.project_configuration.automatic_checkpoint_naming = False
|
223 |
+
|
224 |
+
wandb.init(
|
225 |
+
project="Text to Audio Flow matching",
|
226 |
+
settings=wandb.Settings(_disable_stats=True),
|
227 |
+
)
|
228 |
+
|
229 |
+
accelerator.wait_for_everyone()
|
230 |
+
|
231 |
+
# Get the datasets
|
232 |
+
data_files = {}
|
233 |
+
# if args.train_file is not None:
|
234 |
+
if config["paths"]["train_file"] != "":
|
235 |
+
data_files["train"] = config["paths"]["train_file"]
|
236 |
+
# if args.validation_file is not None:
|
237 |
+
if config["paths"]["val_file"] != "":
|
238 |
+
data_files["validation"] = config["paths"]["val_file"]
|
239 |
+
if config["paths"]["test_file"] != "":
|
240 |
+
data_files["test"] = config["paths"]["test_file"]
|
241 |
+
else:
|
242 |
+
data_files["test"] = config["paths"]["val_file"]
|
243 |
+
|
244 |
+
extension = "json"
|
245 |
+
raw_datasets = load_dataset(extension, data_files=data_files)
|
246 |
+
text_column, audio_column = args.text_column, args.audio_column
|
247 |
+
|
248 |
+
model = TangoFlux(config=config["model"])
|
249 |
+
vae = AutoencoderOobleck.from_pretrained(
|
250 |
+
"stabilityai/stable-audio-open-1.0", subfolder="vae"
|
251 |
+
)
|
252 |
+
|
253 |
+
## Freeze vae
|
254 |
+
for param in vae.parameters():
|
255 |
+
vae.requires_grad = False
|
256 |
+
vae.eval()
|
257 |
+
|
258 |
+
## Freeze text encoder param
|
259 |
+
for param in model.text_encoder.parameters():
|
260 |
+
param.requires_grad = False
|
261 |
+
model.text_encoder.eval()
|
262 |
+
|
263 |
+
prefix = args.prefix
|
264 |
+
|
265 |
+
with accelerator.main_process_first():
|
266 |
+
train_dataset = Text2AudioDataset(
|
267 |
+
raw_datasets["train"],
|
268 |
+
prefix,
|
269 |
+
text_column,
|
270 |
+
audio_column,
|
271 |
+
"duration",
|
272 |
+
args.num_examples,
|
273 |
+
)
|
274 |
+
eval_dataset = Text2AudioDataset(
|
275 |
+
raw_datasets["validation"],
|
276 |
+
prefix,
|
277 |
+
text_column,
|
278 |
+
audio_column,
|
279 |
+
"duration",
|
280 |
+
args.num_examples,
|
281 |
+
)
|
282 |
+
test_dataset = Text2AudioDataset(
|
283 |
+
raw_datasets["test"],
|
284 |
+
prefix,
|
285 |
+
text_column,
|
286 |
+
audio_column,
|
287 |
+
"duration",
|
288 |
+
args.num_examples,
|
289 |
+
)
|
290 |
+
|
291 |
+
accelerator.print(
|
292 |
+
"Num instances in train: {}, validation: {}, test: {}".format(
|
293 |
+
train_dataset.get_num_instances(),
|
294 |
+
eval_dataset.get_num_instances(),
|
295 |
+
test_dataset.get_num_instances(),
|
296 |
+
)
|
297 |
+
)
|
298 |
+
|
299 |
+
train_dataloader = DataLoader(
|
300 |
+
train_dataset,
|
301 |
+
shuffle=True,
|
302 |
+
batch_size=config["training"]["per_device_batch_size"],
|
303 |
+
collate_fn=train_dataset.collate_fn,
|
304 |
+
)
|
305 |
+
eval_dataloader = DataLoader(
|
306 |
+
eval_dataset,
|
307 |
+
shuffle=True,
|
308 |
+
batch_size=config["training"]["per_device_batch_size"],
|
309 |
+
collate_fn=eval_dataset.collate_fn,
|
310 |
+
)
|
311 |
+
test_dataloader = DataLoader(
|
312 |
+
test_dataset,
|
313 |
+
shuffle=False,
|
314 |
+
batch_size=config["training"]["per_device_batch_size"],
|
315 |
+
collate_fn=test_dataset.collate_fn,
|
316 |
+
)
|
317 |
+
|
318 |
+
# Optimizer
|
319 |
+
|
320 |
+
optimizer_parameters = list(model.transformer.parameters()) + list(
|
321 |
+
model.fc.parameters()
|
322 |
+
)
|
323 |
+
num_trainable_parameters = sum(
|
324 |
+
p.numel() for p in model.parameters() if p.requires_grad
|
325 |
+
)
|
326 |
+
accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
|
327 |
+
|
328 |
+
if args.load_from_checkpoint:
|
329 |
+
from safetensors.torch import load_file
|
330 |
+
|
331 |
+
w1 = load_file(args.load_from_checkpoint)
|
332 |
+
model.load_state_dict(w1, strict=False)
|
333 |
+
logger.info("Weights loaded from{}".format(args.load_from_checkpoint))
|
334 |
+
|
335 |
+
optimizer = torch.optim.AdamW(
|
336 |
+
optimizer_parameters,
|
337 |
+
lr=learning_rate,
|
338 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
339 |
+
weight_decay=args.adam_weight_decay,
|
340 |
+
eps=args.adam_epsilon,
|
341 |
+
)
|
342 |
+
|
343 |
+
# Scheduler and math around the number of training steps.
|
344 |
+
overrode_max_train_steps = False
|
345 |
+
num_update_steps_per_epoch = math.ceil(
|
346 |
+
len(train_dataloader) / gradient_accumulation_steps
|
347 |
+
)
|
348 |
+
if args.max_train_steps is None:
|
349 |
+
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
350 |
+
overrode_max_train_steps = True
|
351 |
+
|
352 |
+
lr_scheduler = get_scheduler(
|
353 |
+
name=args.lr_scheduler_type,
|
354 |
+
optimizer=optimizer,
|
355 |
+
num_warmup_steps=num_warmup_steps
|
356 |
+
* gradient_accumulation_steps
|
357 |
+
* accelerator.num_processes,
|
358 |
+
num_training_steps=args.max_train_steps * gradient_accumulation_steps,
|
359 |
+
)
|
360 |
+
|
361 |
+
# Prepare everything with our `accelerator`.
|
362 |
+
vae, model, optimizer, lr_scheduler = accelerator.prepare(
|
363 |
+
vae, model, optimizer, lr_scheduler
|
364 |
+
)
|
365 |
+
|
366 |
+
train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
|
367 |
+
train_dataloader, eval_dataloader, test_dataloader
|
368 |
+
)
|
369 |
+
|
370 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
371 |
+
num_update_steps_per_epoch = math.ceil(
|
372 |
+
len(train_dataloader) / gradient_accumulation_steps
|
373 |
+
)
|
374 |
+
if overrode_max_train_steps:
|
375 |
+
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
376 |
+
# Afterwards we recalculate our number of training epochs
|
377 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
378 |
+
|
379 |
+
# Figure out how many steps we should save the Accelerator states
|
380 |
+
checkpointing_steps = args.checkpointing_steps
|
381 |
+
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
382 |
+
checkpointing_steps = int(checkpointing_steps)
|
383 |
+
|
384 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
385 |
+
# The trackers initializes automatically on the main process.
|
386 |
+
|
387 |
+
# Train!
|
388 |
+
total_batch_size = (
|
389 |
+
per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
390 |
+
)
|
391 |
+
|
392 |
+
logger.info("***** Running training *****")
|
393 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
394 |
+
logger.info(f" Num Epochs = {num_train_epochs}")
|
395 |
+
logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
|
396 |
+
logger.info(
|
397 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
398 |
+
)
|
399 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
400 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
401 |
+
|
402 |
+
# Only show the progress bar once on each machine.
|
403 |
+
progress_bar = tqdm(
|
404 |
+
range(args.max_train_steps), disable=not accelerator.is_local_main_process
|
405 |
+
)
|
406 |
+
|
407 |
+
completed_steps = 0
|
408 |
+
starting_epoch = 0
|
409 |
+
# Potentially load in the weights and states from a previous save
|
410 |
+
resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
|
411 |
+
if resume_from_checkpoint != "":
|
412 |
+
accelerator.load_state(resume_from_checkpoint)
|
413 |
+
accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")
|
414 |
+
|
415 |
+
# Duration of the audio clips in seconds
|
416 |
+
best_loss = np.inf
|
417 |
+
length = config["training"]["max_audio_duration"]
|
418 |
+
|
419 |
+
for epoch in range(starting_epoch, num_train_epochs):
|
420 |
+
model.train()
|
421 |
+
total_loss, total_val_loss = 0, 0
|
422 |
+
for step, batch in enumerate(train_dataloader):
|
423 |
+
|
424 |
+
with accelerator.accumulate(model):
|
425 |
+
optimizer.zero_grad()
|
426 |
+
device = model.device
|
427 |
+
text, audios, duration, _ = batch
|
428 |
+
|
429 |
+
with torch.no_grad():
|
430 |
+
audio_list = []
|
431 |
+
|
432 |
+
for audio_path in audios:
|
433 |
+
|
434 |
+
wav = read_wav_file(
|
435 |
+
audio_path, length
|
436 |
+
) ## Only read the first 30 seconds of audio
|
437 |
+
if (
|
438 |
+
wav.shape[0] == 1
|
439 |
+
): ## If this audio is mono, we repeat the channel so it become "fake stereo"
|
440 |
+
wav = wav.repeat(2, 1)
|
441 |
+
audio_list.append(wav)
|
442 |
+
|
443 |
+
audio_input = torch.stack(audio_list, dim=0)
|
444 |
+
audio_input = audio_input.to(device)
|
445 |
+
unwrapped_vae = accelerator.unwrap_model(vae)
|
446 |
+
|
447 |
+
duration = torch.tensor(duration, device=device)
|
448 |
+
duration = torch.clamp(
|
449 |
+
duration, max=length
|
450 |
+
) ## clamp duration to max audio length
|
451 |
+
|
452 |
+
audio_latent = unwrapped_vae.encode(
|
453 |
+
audio_input
|
454 |
+
).latent_dist.sample()
|
455 |
+
audio_latent = audio_latent.transpose(
|
456 |
+
1, 2
|
457 |
+
) ## Tranpose to (bsz, seq_len, channel)
|
458 |
+
|
459 |
+
loss, _, _, _ = model(audio_latent, text, duration=duration)
|
460 |
+
total_loss += loss.detach().float()
|
461 |
+
accelerator.backward(loss)
|
462 |
+
|
463 |
+
if accelerator.sync_gradients:
|
464 |
+
progress_bar.update(1)
|
465 |
+
completed_steps += 1
|
466 |
+
|
467 |
+
optimizer.step()
|
468 |
+
lr_scheduler.step()
|
469 |
+
|
470 |
+
if completed_steps % 10 == 0 and accelerator.is_main_process:
|
471 |
+
|
472 |
+
total_norm = 0.0
|
473 |
+
for p in model.parameters():
|
474 |
+
if p.grad is not None:
|
475 |
+
param_norm = p.grad.data.norm(2)
|
476 |
+
total_norm += param_norm.item() ** 2
|
477 |
+
|
478 |
+
total_norm = total_norm**0.5
|
479 |
+
logger.info(
|
480 |
+
f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
|
481 |
+
)
|
482 |
+
|
483 |
+
lr = lr_scheduler.get_last_lr()[0]
|
484 |
+
result = {
|
485 |
+
"train_loss": loss.item(),
|
486 |
+
"grad_norm": total_norm,
|
487 |
+
"learning_rate": lr,
|
488 |
+
}
|
489 |
+
|
490 |
+
# result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
|
491 |
+
wandb.log(result, step=completed_steps)
|
492 |
+
|
493 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
494 |
+
|
495 |
+
if isinstance(checkpointing_steps, int):
|
496 |
+
if completed_steps % checkpointing_steps == 0:
|
497 |
+
output_dir = f"step_{completed_steps }"
|
498 |
+
if output_dir is not None:
|
499 |
+
output_dir = os.path.join(output_dir, output_dir)
|
500 |
+
accelerator.save_state(output_dir)
|
501 |
+
|
502 |
+
if completed_steps >= args.max_train_steps:
|
503 |
+
break
|
504 |
+
|
505 |
+
model.eval()
|
506 |
+
eval_progress_bar = tqdm(
|
507 |
+
range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
|
508 |
+
)
|
509 |
+
for step, batch in enumerate(eval_dataloader):
|
510 |
+
with accelerator.accumulate(model) and torch.no_grad():
|
511 |
+
device = model.device
|
512 |
+
text, audios, duration, _ = batch
|
513 |
+
|
514 |
+
audio_list = []
|
515 |
+
for audio_path in audios:
|
516 |
+
|
517 |
+
wav = read_wav_file(
|
518 |
+
audio_path, length
|
519 |
+
) ## make sure none of audio exceed 30 sec
|
520 |
+
if (
|
521 |
+
wav.shape[0] == 1
|
522 |
+
): ## If this audio is mono, we repeat the channel so it become "fake stereo"
|
523 |
+
wav = wav.repeat(2, 1)
|
524 |
+
audio_list.append(wav)
|
525 |
+
|
526 |
+
audio_input = torch.stack(audio_list, dim=0)
|
527 |
+
audio_input = audio_input.to(device)
|
528 |
+
duration = torch.tensor(duration, device=device)
|
529 |
+
unwrapped_vae = accelerator.unwrap_model(vae)
|
530 |
+
audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
|
531 |
+
audio_latent = audio_latent.transpose(
|
532 |
+
1, 2
|
533 |
+
) ## Tranpose to (bsz, seq_len, channel)
|
534 |
+
|
535 |
+
val_loss, _, _, _ = model(audio_latent, text, duration=duration)
|
536 |
+
|
537 |
+
total_val_loss += val_loss.detach().float()
|
538 |
+
eval_progress_bar.update(1)
|
539 |
+
|
540 |
+
if accelerator.is_main_process:
|
541 |
+
|
542 |
+
result = {}
|
543 |
+
result["epoch"] = float(epoch + 1)
|
544 |
+
|
545 |
+
result["epoch/train_loss"] = round(
|
546 |
+
total_loss.item() / len(train_dataloader), 4
|
547 |
+
)
|
548 |
+
result["epoch/val_loss"] = round(
|
549 |
+
total_val_loss.item() / len(eval_dataloader), 4
|
550 |
+
)
|
551 |
+
|
552 |
+
wandb.log(result, step=completed_steps)
|
553 |
+
|
554 |
+
result_string = "Epoch: {}, Loss Train: {}, Val: {}\n".format(
|
555 |
+
epoch, result["epoch/train_loss"], result["epoch/val_loss"]
|
556 |
+
)
|
557 |
+
|
558 |
+
accelerator.print(result_string)
|
559 |
+
|
560 |
+
with open("{}/summary.jsonl".format(output_dir), "a") as f:
|
561 |
+
f.write(json.dumps(result) + "\n\n")
|
562 |
+
|
563 |
+
logger.info(result)
|
564 |
+
|
565 |
+
if result["epoch/val_loss"] < best_loss:
|
566 |
+
best_loss = result["epoch/val_loss"]
|
567 |
+
save_checkpoint = True
|
568 |
+
else:
|
569 |
+
save_checkpoint = False
|
570 |
+
|
571 |
+
accelerator.wait_for_everyone()
|
572 |
+
if accelerator.is_main_process and args.checkpointing_steps == "best":
|
573 |
+
if save_checkpoint:
|
574 |
+
accelerator.save_state("{}/{}".format(output_dir, "best"))
|
575 |
+
|
576 |
+
if (epoch + 1) % args.save_every == 0:
|
577 |
+
accelerator.save_state(
|
578 |
+
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
|
579 |
+
)
|
580 |
+
|
581 |
+
if accelerator.is_main_process and args.checkpointing_steps == "epoch":
|
582 |
+
accelerator.save_state(
|
583 |
+
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
|
584 |
+
)
|
585 |
+
|
586 |
+
|
587 |
+
if __name__ == "__main__":
|
588 |
+
main()
|
external_models/TangoFlux/tangoflux/train_dpo.py
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import yaml
|
8 |
+
|
9 |
+
# from tqdm import tqdm
|
10 |
+
import copy
|
11 |
+
from pathlib import Path
|
12 |
+
import diffusers
|
13 |
+
import datasets
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
import wandb
|
17 |
+
import transformers
|
18 |
+
import torch
|
19 |
+
from accelerate import Accelerator
|
20 |
+
from accelerate.logging import get_logger
|
21 |
+
from accelerate.utils import set_seed
|
22 |
+
from datasets import load_dataset
|
23 |
+
from torch.utils.data import Dataset, DataLoader
|
24 |
+
from tqdm.auto import tqdm
|
25 |
+
from transformers import SchedulerType, get_scheduler
|
26 |
+
from tangoflux.model import TangoFlux
|
27 |
+
from datasets import load_dataset, Audio
|
28 |
+
from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset
|
29 |
+
|
30 |
+
from diffusers import AutoencoderOobleck
|
31 |
+
import torchaudio
|
32 |
+
|
33 |
+
logger = get_logger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
def parse_args():
|
37 |
+
parser = argparse.ArgumentParser(
|
38 |
+
description="Rectified flow for text to audio generation task."
|
39 |
+
)
|
40 |
+
|
41 |
+
parser.add_argument(
|
42 |
+
"--num_examples",
|
43 |
+
type=int,
|
44 |
+
default=-1,
|
45 |
+
help="How many examples to use for training and validation.",
|
46 |
+
)
|
47 |
+
|
48 |
+
parser.add_argument(
|
49 |
+
"--text_column",
|
50 |
+
type=str,
|
51 |
+
default="captions",
|
52 |
+
help="The name of the column in the datasets containing the input texts.",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--audio_column",
|
56 |
+
type=str,
|
57 |
+
default="location",
|
58 |
+
help="The name of the column in the datasets containing the audio paths.",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--adam_beta1",
|
62 |
+
type=float,
|
63 |
+
default=0.9,
|
64 |
+
help="The beta1 parameter for the Adam optimizer.",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--adam_beta2",
|
68 |
+
type=float,
|
69 |
+
default=0.95,
|
70 |
+
help="The beta2 parameter for the Adam optimizer.",
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--config",
|
74 |
+
type=str,
|
75 |
+
default="tangoflux_config.yaml",
|
76 |
+
help="Config file defining the model size.",
|
77 |
+
)
|
78 |
+
|
79 |
+
parser.add_argument(
|
80 |
+
"--weight_decay", type=float, default=1e-8, help="Weight decay to use."
|
81 |
+
)
|
82 |
+
|
83 |
+
parser.add_argument(
|
84 |
+
"--max_train_steps",
|
85 |
+
type=int,
|
86 |
+
default=None,
|
87 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
88 |
+
)
|
89 |
+
|
90 |
+
parser.add_argument(
|
91 |
+
"--lr_scheduler_type",
|
92 |
+
type=SchedulerType,
|
93 |
+
default="linear",
|
94 |
+
help="The scheduler type to use.",
|
95 |
+
choices=[
|
96 |
+
"linear",
|
97 |
+
"cosine",
|
98 |
+
"cosine_with_restarts",
|
99 |
+
"polynomial",
|
100 |
+
"constant",
|
101 |
+
"constant_with_warmup",
|
102 |
+
],
|
103 |
+
)
|
104 |
+
|
105 |
+
parser.add_argument(
|
106 |
+
"--adam_epsilon",
|
107 |
+
type=float,
|
108 |
+
default=1e-08,
|
109 |
+
help="Epsilon value for the Adam optimizer",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--adam_weight_decay",
|
113 |
+
type=float,
|
114 |
+
default=1e-2,
|
115 |
+
help="Epsilon value for the Adam optimizer",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--seed", type=int, default=None, help="A seed for reproducible training."
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--checkpointing_steps",
|
122 |
+
type=str,
|
123 |
+
default="best",
|
124 |
+
help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--save_every",
|
128 |
+
type=int,
|
129 |
+
default=5,
|
130 |
+
help="Save model after every how many epochs when checkpointing_steps is set to best.",
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
parser.add_argument(
|
136 |
+
"--load_from_checkpoint",
|
137 |
+
type=str,
|
138 |
+
default=None,
|
139 |
+
help="Whether to continue training from a model weight",
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
args = parser.parse_args()
|
144 |
+
|
145 |
+
# Sanity checks
|
146 |
+
# if args.train_file is None and args.validation_file is None:
|
147 |
+
# raise ValueError("Need a training/validation file.")
|
148 |
+
# else:
|
149 |
+
# if args.train_file is not None:
|
150 |
+
# extension = args.train_file.split(".")[-1]
|
151 |
+
# assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
152 |
+
# if args.validation_file is not None:
|
153 |
+
# extension = args.validation_file.split(".")[-1]
|
154 |
+
# assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
155 |
+
|
156 |
+
return args
|
157 |
+
|
158 |
+
|
159 |
+
def main():
|
160 |
+
args = parse_args()
|
161 |
+
accelerator_log_kwargs = {}
|
162 |
+
|
163 |
+
def load_config(config_path):
|
164 |
+
with open(config_path, "r") as file:
|
165 |
+
return yaml.safe_load(file)
|
166 |
+
|
167 |
+
config = load_config(args.config)
|
168 |
+
|
169 |
+
learning_rate = float(config["training"]["learning_rate"])
|
170 |
+
num_train_epochs = int(config["training"]["num_train_epochs"])
|
171 |
+
num_warmup_steps = int(config["training"]["num_warmup_steps"])
|
172 |
+
per_device_batch_size = int(config["training"]["per_device_batch_size"])
|
173 |
+
gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
|
174 |
+
|
175 |
+
output_dir = config["paths"]["output_dir"]
|
176 |
+
|
177 |
+
accelerator = Accelerator(
|
178 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
179 |
+
**accelerator_log_kwargs,
|
180 |
+
)
|
181 |
+
|
182 |
+
# Make one log on every process with the configuration for debugging.
|
183 |
+
logging.basicConfig(
|
184 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
185 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
186 |
+
level=logging.INFO,
|
187 |
+
)
|
188 |
+
logger.info(accelerator.state, main_process_only=False)
|
189 |
+
|
190 |
+
datasets.utils.logging.set_verbosity_error()
|
191 |
+
diffusers.utils.logging.set_verbosity_error()
|
192 |
+
transformers.utils.logging.set_verbosity_error()
|
193 |
+
|
194 |
+
# If passed along, set the training seed now.
|
195 |
+
if args.seed is not None:
|
196 |
+
set_seed(args.seed)
|
197 |
+
|
198 |
+
# Handle output directory creation and wandb tracking
|
199 |
+
if accelerator.is_main_process:
|
200 |
+
if output_dir is None or output_dir == "":
|
201 |
+
output_dir = "saved/" + str(int(time.time()))
|
202 |
+
|
203 |
+
if not os.path.exists("saved"):
|
204 |
+
os.makedirs("saved")
|
205 |
+
|
206 |
+
os.makedirs(output_dir, exist_ok=True)
|
207 |
+
|
208 |
+
elif output_dir is not None:
|
209 |
+
os.makedirs(output_dir, exist_ok=True)
|
210 |
+
|
211 |
+
os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
|
212 |
+
with open("{}/summary.jsonl".format(output_dir), "a") as f:
|
213 |
+
f.write(json.dumps(dict(vars(args))) + "\n\n")
|
214 |
+
|
215 |
+
accelerator.project_configuration.automatic_checkpoint_naming = False
|
216 |
+
|
217 |
+
wandb.init(
|
218 |
+
project="Text to Audio Flow matching DPO",
|
219 |
+
settings=wandb.Settings(_disable_stats=True),
|
220 |
+
)
|
221 |
+
|
222 |
+
accelerator.wait_for_everyone()
|
223 |
+
|
224 |
+
# Get the datasets
|
225 |
+
data_files = {}
|
226 |
+
# if args.train_file is not None:
|
227 |
+
if config["paths"]["train_file"] != "":
|
228 |
+
data_files["train"] = config["paths"]["train_file"]
|
229 |
+
# if args.validation_file is not None:
|
230 |
+
if config["paths"]["val_file"] != "":
|
231 |
+
data_files["validation"] = config["paths"]["val_file"]
|
232 |
+
if config["paths"]["test_file"] != "":
|
233 |
+
data_files["test"] = config["paths"]["test_file"]
|
234 |
+
else:
|
235 |
+
data_files["test"] = config["paths"]["val_file"]
|
236 |
+
|
237 |
+
extension = "json"
|
238 |
+
train_dataset = load_dataset(extension, data_files=data_files["train"])
|
239 |
+
data_files.pop("train")
|
240 |
+
raw_datasets = load_dataset(extension, data_files=data_files)
|
241 |
+
text_column, audio_column = args.text_column, args.audio_column
|
242 |
+
|
243 |
+
model = TangoFlux(config=config["model"], initialize_reference_model=True)
|
244 |
+
vae = AutoencoderOobleck.from_pretrained(
|
245 |
+
"stabilityai/stable-audio-open-1.0", subfolder="vae"
|
246 |
+
)
|
247 |
+
|
248 |
+
## Freeze vae
|
249 |
+
for param in vae.parameters():
|
250 |
+
vae.requires_grad = False
|
251 |
+
vae.eval()
|
252 |
+
|
253 |
+
## Freeze text encoder param
|
254 |
+
for param in model.text_encoder.parameters():
|
255 |
+
param.requires_grad = False
|
256 |
+
model.text_encoder.eval()
|
257 |
+
|
258 |
+
prefix = ""
|
259 |
+
|
260 |
+
with accelerator.main_process_first():
|
261 |
+
train_dataset = DPOText2AudioDataset(
|
262 |
+
train_dataset["train"],
|
263 |
+
prefix,
|
264 |
+
text_column,
|
265 |
+
"chosen",
|
266 |
+
"reject",
|
267 |
+
"duration",
|
268 |
+
args.num_examples,
|
269 |
+
)
|
270 |
+
eval_dataset = Text2AudioDataset(
|
271 |
+
raw_datasets["validation"],
|
272 |
+
prefix,
|
273 |
+
text_column,
|
274 |
+
audio_column,
|
275 |
+
"duration",
|
276 |
+
args.num_examples,
|
277 |
+
)
|
278 |
+
test_dataset = Text2AudioDataset(
|
279 |
+
raw_datasets["test"],
|
280 |
+
prefix,
|
281 |
+
text_column,
|
282 |
+
audio_column,
|
283 |
+
"duration",
|
284 |
+
args.num_examples,
|
285 |
+
)
|
286 |
+
|
287 |
+
accelerator.print(
|
288 |
+
"Num instances in train: {}, validation: {}, test: {}".format(
|
289 |
+
train_dataset.get_num_instances(),
|
290 |
+
eval_dataset.get_num_instances(),
|
291 |
+
test_dataset.get_num_instances(),
|
292 |
+
)
|
293 |
+
)
|
294 |
+
|
295 |
+
train_dataloader = DataLoader(
|
296 |
+
train_dataset,
|
297 |
+
shuffle=True,
|
298 |
+
batch_size=config["training"]["per_device_batch_size"],
|
299 |
+
collate_fn=train_dataset.collate_fn,
|
300 |
+
)
|
301 |
+
eval_dataloader = DataLoader(
|
302 |
+
eval_dataset,
|
303 |
+
shuffle=True,
|
304 |
+
batch_size=config["training"]["per_device_batch_size"],
|
305 |
+
collate_fn=eval_dataset.collate_fn,
|
306 |
+
)
|
307 |
+
test_dataloader = DataLoader(
|
308 |
+
test_dataset,
|
309 |
+
shuffle=False,
|
310 |
+
batch_size=config["training"]["per_device_batch_size"],
|
311 |
+
collate_fn=test_dataset.collate_fn,
|
312 |
+
)
|
313 |
+
|
314 |
+
# Optimizer
|
315 |
+
|
316 |
+
optimizer_parameters = list(model.transformer.parameters()) + list(
|
317 |
+
model.fc.parameters()
|
318 |
+
)
|
319 |
+
num_trainable_parameters = sum(
|
320 |
+
p.numel() for p in model.parameters() if p.requires_grad
|
321 |
+
)
|
322 |
+
accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
|
323 |
+
|
324 |
+
if args.load_from_checkpoint:
|
325 |
+
from safetensors.torch import load_file
|
326 |
+
|
327 |
+
w1 = load_file(args.load_from_checkpoint)
|
328 |
+
model.load_state_dict(w1, strict=False)
|
329 |
+
logger.info("Weights loaded from{}".format(args.load_from_checkpoint))
|
330 |
+
|
331 |
+
import copy
|
332 |
+
|
333 |
+
model.ref_transformer = copy.deepcopy(model.transformer)
|
334 |
+
model.ref_transformer.requires_grad_ = False
|
335 |
+
model.ref_transformer.eval()
|
336 |
+
for param in model.ref_transformer.parameters():
|
337 |
+
param.requires_grad = False
|
338 |
+
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
optimizer = torch.optim.AdamW(
|
343 |
+
optimizer_parameters,
|
344 |
+
lr=learning_rate,
|
345 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
346 |
+
weight_decay=args.adam_weight_decay,
|
347 |
+
eps=args.adam_epsilon,
|
348 |
+
)
|
349 |
+
|
350 |
+
# Scheduler and math around the number of training steps.
|
351 |
+
overrode_max_train_steps = False
|
352 |
+
num_update_steps_per_epoch = math.ceil(
|
353 |
+
len(train_dataloader) / gradient_accumulation_steps
|
354 |
+
)
|
355 |
+
if args.max_train_steps is None:
|
356 |
+
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
357 |
+
overrode_max_train_steps = True
|
358 |
+
|
359 |
+
lr_scheduler = get_scheduler(
|
360 |
+
name=args.lr_scheduler_type,
|
361 |
+
optimizer=optimizer,
|
362 |
+
num_warmup_steps=num_warmup_steps
|
363 |
+
* gradient_accumulation_steps
|
364 |
+
* accelerator.num_processes,
|
365 |
+
num_training_steps=args.max_train_steps * gradient_accumulation_steps,
|
366 |
+
)
|
367 |
+
|
368 |
+
# Prepare everything with our `accelerator`.
|
369 |
+
vae, model, optimizer, lr_scheduler = accelerator.prepare(
|
370 |
+
vae, model, optimizer, lr_scheduler
|
371 |
+
)
|
372 |
+
|
373 |
+
train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
|
374 |
+
train_dataloader, eval_dataloader, test_dataloader
|
375 |
+
)
|
376 |
+
|
377 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
378 |
+
num_update_steps_per_epoch = math.ceil(
|
379 |
+
len(train_dataloader) / gradient_accumulation_steps
|
380 |
+
)
|
381 |
+
if overrode_max_train_steps:
|
382 |
+
args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
383 |
+
# Afterwards we recalculate our number of training epochs
|
384 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
385 |
+
|
386 |
+
# Figure out how many steps we should save the Accelerator states
|
387 |
+
checkpointing_steps = args.checkpointing_steps
|
388 |
+
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
389 |
+
checkpointing_steps = int(checkpointing_steps)
|
390 |
+
|
391 |
+
|
392 |
+
|
393 |
+
# Train!
|
394 |
+
total_batch_size = (
|
395 |
+
per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
396 |
+
)
|
397 |
+
|
398 |
+
logger.info("***** Running training *****")
|
399 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
400 |
+
logger.info(f" Num Epochs = {num_train_epochs}")
|
401 |
+
logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
|
402 |
+
logger.info(
|
403 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
404 |
+
)
|
405 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
406 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
407 |
+
|
408 |
+
# Only show the progress bar once on each machine.
|
409 |
+
progress_bar = tqdm(
|
410 |
+
range(args.max_train_steps), disable=not accelerator.is_local_main_process
|
411 |
+
)
|
412 |
+
|
413 |
+
completed_steps = 0
|
414 |
+
starting_epoch = 0
|
415 |
+
# Potentially load in the weights and states from a previous save
|
416 |
+
resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
|
417 |
+
if resume_from_checkpoint != "":
|
418 |
+
accelerator.load_state(resume_from_checkpoint)
|
419 |
+
accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")
|
420 |
+
|
421 |
+
# Duration of the audio clips in seconds
|
422 |
+
best_loss = np.inf
|
423 |
+
length = config["training"]["max_audio_duration"]
|
424 |
+
|
425 |
+
for epoch in range(starting_epoch, num_train_epochs):
|
426 |
+
model.train()
|
427 |
+
total_loss, total_val_loss = 0, 0
|
428 |
+
|
429 |
+
for step, batch in enumerate(train_dataloader):
|
430 |
+
optimizer.zero_grad()
|
431 |
+
with accelerator.accumulate(model):
|
432 |
+
optimizer.zero_grad()
|
433 |
+
device = accelerator.device
|
434 |
+
text, audio_w, audio_l, duration, _ = batch
|
435 |
+
|
436 |
+
with torch.no_grad():
|
437 |
+
audio_list_w = []
|
438 |
+
audio_list_l = []
|
439 |
+
for audio_path in audio_w:
|
440 |
+
|
441 |
+
wav = read_wav_file(
|
442 |
+
audio_path, length
|
443 |
+
) ## Only read the first 30 seconds of audio
|
444 |
+
if (
|
445 |
+
wav.shape[0] == 1
|
446 |
+
): ## If this audio is mono, we repeat the channel so it become "fake stereo"
|
447 |
+
wav = wav.repeat(2, 1)
|
448 |
+
audio_list_w.append(wav)
|
449 |
+
|
450 |
+
for audio_path in audio_l:
|
451 |
+
wav = read_wav_file(
|
452 |
+
audio_path, length
|
453 |
+
) ## Only read the first 30 seconds of audio
|
454 |
+
if (
|
455 |
+
wav.shape[0] == 1
|
456 |
+
): ## If this audio is mono, we repeat the channel so it become "fake stereo"
|
457 |
+
wav = wav.repeat(2, 1)
|
458 |
+
audio_list_l.append(wav)
|
459 |
+
|
460 |
+
audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
|
461 |
+
audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
|
462 |
+
# audio_input_ = audio_input.to(device)
|
463 |
+
unwrapped_vae = accelerator.unwrap_model(vae)
|
464 |
+
|
465 |
+
duration = torch.tensor(duration, device=device)
|
466 |
+
duration = torch.clamp(
|
467 |
+
duration, max=length
|
468 |
+
) ## max duration is 30 sec
|
469 |
+
|
470 |
+
audio_latent_w = unwrapped_vae.encode(
|
471 |
+
audio_input_w
|
472 |
+
).latent_dist.sample()
|
473 |
+
audio_latent_l = unwrapped_vae.encode(
|
474 |
+
audio_input_l
|
475 |
+
).latent_dist.sample()
|
476 |
+
audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
|
477 |
+
audio_latent = audio_latent.transpose(
|
478 |
+
1, 2
|
479 |
+
) ## Tranpose to (bsz, seq_len, channel)
|
480 |
+
|
481 |
+
loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
|
482 |
+
audio_latent, text, duration=duration, sft=False
|
483 |
+
)
|
484 |
+
|
485 |
+
total_loss += loss.detach().float()
|
486 |
+
accelerator.backward(loss)
|
487 |
+
optimizer.step()
|
488 |
+
lr_scheduler.step()
|
489 |
+
# if accelerator.sync_gradients:
|
490 |
+
if accelerator.sync_gradients:
|
491 |
+
# accelerator.clip_grad_value_(model.parameters(),1.0)
|
492 |
+
progress_bar.update(1)
|
493 |
+
completed_steps += 1
|
494 |
+
|
495 |
+
if completed_steps % 10 == 0 and accelerator.is_main_process:
|
496 |
+
|
497 |
+
total_norm = 0.0
|
498 |
+
for p in model.parameters():
|
499 |
+
if p.grad is not None:
|
500 |
+
param_norm = p.grad.data.norm(2)
|
501 |
+
total_norm += param_norm.item() ** 2
|
502 |
+
|
503 |
+
total_norm = total_norm**0.5
|
504 |
+
logger.info(
|
505 |
+
f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
|
506 |
+
)
|
507 |
+
|
508 |
+
lr = lr_scheduler.get_last_lr()[0]
|
509 |
+
|
510 |
+
result = {
|
511 |
+
"train_loss": loss.item(),
|
512 |
+
"grad_norm": total_norm,
|
513 |
+
"learning_rate": lr,
|
514 |
+
"raw_model_loss": raw_model_loss,
|
515 |
+
"raw_ref_loss": raw_ref_loss,
|
516 |
+
"implicit_acc": implicit_acc,
|
517 |
+
}
|
518 |
+
|
519 |
+
# result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
|
520 |
+
wandb.log(result, step=completed_steps)
|
521 |
+
|
522 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
523 |
+
|
524 |
+
if isinstance(checkpointing_steps, int):
|
525 |
+
if completed_steps % checkpointing_steps == 0:
|
526 |
+
output_dir = f"step_{completed_steps }"
|
527 |
+
if output_dir is not None:
|
528 |
+
output_dir = os.path.join(output_dir, output_dir)
|
529 |
+
accelerator.save_state(output_dir)
|
530 |
+
|
531 |
+
if completed_steps >= args.max_train_steps:
|
532 |
+
break
|
533 |
+
|
534 |
+
model.eval()
|
535 |
+
eval_progress_bar = tqdm(
|
536 |
+
range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
|
537 |
+
)
|
538 |
+
for step, batch in enumerate(eval_dataloader):
|
539 |
+
with accelerator.accumulate(model) and torch.no_grad():
|
540 |
+
device = model.device
|
541 |
+
text, audios, duration, _ = batch
|
542 |
+
|
543 |
+
audio_list = []
|
544 |
+
for audio_path in audios:
|
545 |
+
|
546 |
+
wav = read_wav_file(
|
547 |
+
audio_path, length
|
548 |
+
) ## Only read the first 30 seconds of audio
|
549 |
+
if (
|
550 |
+
wav.shape[0] == 1
|
551 |
+
): ## If this audio is mono, we repeat the channel so it become "fake stereo"
|
552 |
+
wav = wav.repeat(2, 1)
|
553 |
+
audio_list.append(wav)
|
554 |
+
|
555 |
+
audio_input = torch.stack(audio_list, dim=0)
|
556 |
+
audio_input = audio_input.to(device)
|
557 |
+
duration = torch.tensor(duration, device=device)
|
558 |
+
unwrapped_vae = accelerator.unwrap_model(vae)
|
559 |
+
audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
|
560 |
+
audio_latent = audio_latent.transpose(
|
561 |
+
1, 2
|
562 |
+
) ## Tranpose to (bsz, seq_len, channel)
|
563 |
+
|
564 |
+
val_loss, _, _, _ = model(
|
565 |
+
audio_latent, text, duration=duration, sft=True
|
566 |
+
)
|
567 |
+
|
568 |
+
total_val_loss += val_loss.detach().float()
|
569 |
+
eval_progress_bar.update(1)
|
570 |
+
|
571 |
+
if accelerator.is_main_process:
|
572 |
+
|
573 |
+
result = {}
|
574 |
+
result["epoch"] = float(epoch + 1)
|
575 |
+
|
576 |
+
result["epoch/train_loss"] = round(
|
577 |
+
total_loss.item() / len(train_dataloader), 4
|
578 |
+
)
|
579 |
+
result["epoch/val_loss"] = round(
|
580 |
+
total_val_loss.item() / len(eval_dataloader), 4
|
581 |
+
)
|
582 |
+
|
583 |
+
wandb.log(result, step=completed_steps)
|
584 |
+
|
585 |
+
with open("{}/summary.jsonl".format(output_dir), "a") as f:
|
586 |
+
f.write(json.dumps(result) + "\n\n")
|
587 |
+
|
588 |
+
logger.info(result)
|
589 |
+
|
590 |
+
save_checkpoint = True
|
591 |
+
accelerator.wait_for_everyone()
|
592 |
+
if accelerator.is_main_process and args.checkpointing_steps == "best":
|
593 |
+
if save_checkpoint:
|
594 |
+
accelerator.save_state("{}/{}".format(output_dir, "best"))
|
595 |
+
|
596 |
+
if (epoch + 1) % args.save_every == 0:
|
597 |
+
accelerator.save_state(
|
598 |
+
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
|
599 |
+
)
|
600 |
+
|
601 |
+
if accelerator.is_main_process and args.checkpointing_steps == "epoch":
|
602 |
+
accelerator.save_state(
|
603 |
+
"{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
|
604 |
+
)
|
605 |
+
|
606 |
+
|
607 |
+
if __name__ == "__main__":
|
608 |
+
main()
|
external_models/TangoFlux/tangoflux/utils.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
import torchaudio
|
7 |
+
import random
|
8 |
+
import itertools
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
|
15 |
+
def normalize_wav(waveform):
|
16 |
+
waveform = waveform - torch.mean(waveform)
|
17 |
+
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
|
18 |
+
return waveform * 0.5
|
19 |
+
|
20 |
+
|
21 |
+
def pad_wav(waveform, segment_length):
|
22 |
+
waveform_length = len(waveform)
|
23 |
+
|
24 |
+
if segment_length is None or waveform_length == segment_length:
|
25 |
+
return waveform
|
26 |
+
elif waveform_length > segment_length:
|
27 |
+
return waveform[:segment_length]
|
28 |
+
else:
|
29 |
+
padded_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
|
30 |
+
waveform = torch.cat([waveform, padded_wav])
|
31 |
+
return waveform
|
32 |
+
|
33 |
+
|
34 |
+
def read_wav_file(filename, duration_sec):
|
35 |
+
info = torchaudio.info(filename)
|
36 |
+
sample_rate = info.sample_rate
|
37 |
+
|
38 |
+
# Calculate the number of frames corresponding to the desired duration
|
39 |
+
num_frames = int(sample_rate * duration_sec)
|
40 |
+
|
41 |
+
waveform, sr = torchaudio.load(filename, num_frames=num_frames) # Faster!!!
|
42 |
+
|
43 |
+
if waveform.shape[0] == 2: ## Stereo audio
|
44 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
|
45 |
+
resampled_waveform = resampler(waveform)
|
46 |
+
# print(resampled_waveform.shape)
|
47 |
+
padded_left = pad_wav(
|
48 |
+
resampled_waveform[0], int(44100 * duration_sec)
|
49 |
+
) ## We pad left and right seperately
|
50 |
+
padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))
|
51 |
+
|
52 |
+
return torch.stack([padded_left, padded_right])
|
53 |
+
else:
|
54 |
+
waveform = torchaudio.functional.resample(
|
55 |
+
waveform, orig_freq=sr, new_freq=44100
|
56 |
+
)[0]
|
57 |
+
waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)
|
58 |
+
|
59 |
+
return waveform
|
60 |
+
|
61 |
+
|
62 |
+
class DPOText2AudioDataset(Dataset):
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
dataset,
|
66 |
+
prefix,
|
67 |
+
text_column,
|
68 |
+
audio_w_column,
|
69 |
+
audio_l_column,
|
70 |
+
duration,
|
71 |
+
num_examples=-1,
|
72 |
+
):
|
73 |
+
|
74 |
+
inputs = list(dataset[text_column])
|
75 |
+
self.inputs = [prefix + inp for inp in inputs]
|
76 |
+
self.audios_w = list(dataset[audio_w_column])
|
77 |
+
self.audios_l = list(dataset[audio_l_column])
|
78 |
+
self.durations = list(dataset[duration])
|
79 |
+
self.indices = list(range(len(self.inputs)))
|
80 |
+
|
81 |
+
self.mapper = {}
|
82 |
+
for index, audio_w, audio_l, duration, text in zip(
|
83 |
+
self.indices, self.audios_w, self.audios_l, self.durations, inputs
|
84 |
+
):
|
85 |
+
self.mapper[index] = [audio_w, audio_l, duration, text]
|
86 |
+
|
87 |
+
if num_examples != -1:
|
88 |
+
self.inputs, self.audios_w, self.audios_l, self.durations = (
|
89 |
+
self.inputs[:num_examples],
|
90 |
+
self.audios_w[:num_examples],
|
91 |
+
self.audios_l[:num_examples],
|
92 |
+
self.durations[:num_examples],
|
93 |
+
)
|
94 |
+
self.indices = self.indices[:num_examples]
|
95 |
+
|
96 |
+
def __len__(self):
|
97 |
+
return len(self.inputs)
|
98 |
+
|
99 |
+
def get_num_instances(self):
|
100 |
+
return len(self.inputs)
|
101 |
+
|
102 |
+
def __getitem__(self, index):
|
103 |
+
s1, s2, s3, s4, s5 = (
|
104 |
+
self.inputs[index],
|
105 |
+
self.audios_w[index],
|
106 |
+
self.audios_l[index],
|
107 |
+
self.durations[index],
|
108 |
+
self.indices[index],
|
109 |
+
)
|
110 |
+
return s1, s2, s3, s4, s5
|
111 |
+
|
112 |
+
def collate_fn(self, data):
|
113 |
+
dat = pd.DataFrame(data)
|
114 |
+
return [dat[i].tolist() for i in dat]
|
115 |
+
|
116 |
+
|
117 |
+
class Text2AudioDataset(Dataset):
|
118 |
+
def __init__(
|
119 |
+
self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
|
120 |
+
):
|
121 |
+
|
122 |
+
inputs = list(dataset[text_column])
|
123 |
+
self.inputs = [prefix + inp for inp in inputs]
|
124 |
+
self.audios = list(dataset[audio_column])
|
125 |
+
self.durations = list(dataset[duration])
|
126 |
+
self.indices = list(range(len(self.inputs)))
|
127 |
+
|
128 |
+
self.mapper = {}
|
129 |
+
for index, audio, duration, text in zip(
|
130 |
+
self.indices, self.audios, self.durations, inputs
|
131 |
+
):
|
132 |
+
self.mapper[index] = [audio, text, duration]
|
133 |
+
|
134 |
+
if num_examples != -1:
|
135 |
+
self.inputs, self.audios, self.durations = (
|
136 |
+
self.inputs[:num_examples],
|
137 |
+
self.audios[:num_examples],
|
138 |
+
self.durations[:num_examples],
|
139 |
+
)
|
140 |
+
self.indices = self.indices[:num_examples]
|
141 |
+
|
142 |
+
def __len__(self):
|
143 |
+
return len(self.inputs)
|
144 |
+
|
145 |
+
def get_num_instances(self):
|
146 |
+
return len(self.inputs)
|
147 |
+
|
148 |
+
def __getitem__(self, index):
|
149 |
+
s1, s2, s3, s4 = (
|
150 |
+
self.inputs[index],
|
151 |
+
self.audios[index],
|
152 |
+
self.durations[index],
|
153 |
+
self.indices[index],
|
154 |
+
)
|
155 |
+
return s1, s2, s3, s4
|
156 |
+
|
157 |
+
def collate_fn(self, data):
|
158 |
+
dat = pd.DataFrame(data)
|
159 |
+
return [dat[i].tolist() for i in dat]
|
external_models/TangoFlux/train.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
CUDA_VISISBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
|
external_models/depth-fm/.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*__pycache__*
|
2 |
+
sandbox
|
3 |
+
*.ckpt
|
4 |
+
*-depth.png
|
5 |
+
evaluation
|