FQiao committed · verified
Commit 3324de2 · 1 Parent(s): 8f70a8a

Upload 70 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +6 -0
  2. DepthEstimator.py +67 -0
  3. GenerateAudio.py +184 -0
  4. GenerateCaptions.py +494 -0
  5. README.md +50 -13
  6. SoundMapper.py +438 -0
  7. app.py +182 -0
  8. audio_mixer.py +428 -0
  9. config.py +16 -0
  10. environment.yml +8 -0
  11. external_models/TangoFlux/.gitignore +175 -0
  12. external_models/TangoFlux/Demo.ipynb +117 -0
  13. external_models/TangoFlux/Inference.ipynb +0 -0
  14. external_models/TangoFlux/LICENSE.md +51 -0
  15. external_models/TangoFlux/Notice +1 -0
  16. external_models/TangoFlux/README.md +188 -0
  17. external_models/TangoFlux/STABILITY_AI_COMMUNITY_LICENSE.md +57 -0
  18. external_models/TangoFlux/__init__.py +4 -0
  19. external_models/TangoFlux/assets/tangoflux.png +3 -0
  20. external_models/TangoFlux/assets/tf_opener.png +3 -0
  21. external_models/TangoFlux/assets/tf_teaser.png +3 -0
  22. external_models/TangoFlux/comfyui/README.md +78 -0
  23. external_models/TangoFlux/comfyui/__init__.py +6 -0
  24. external_models/TangoFlux/comfyui/example_workflow.json +168 -0
  25. external_models/TangoFlux/comfyui/install.py +79 -0
  26. external_models/TangoFlux/comfyui/nodes.py +328 -0
  27. external_models/TangoFlux/comfyui/requirements.txt +9 -0
  28. external_models/TangoFlux/comfyui/server.py +64 -0
  29. external_models/TangoFlux/comfyui/teacache.py +283 -0
  30. external_models/TangoFlux/comfyui/web/js/playAudio.js +59 -0
  31. external_models/TangoFlux/configs/__init__.py +0 -0
  32. external_models/TangoFlux/configs/accelerator_config.yaml +17 -0
  33. external_models/TangoFlux/configs/tangoflux_config.yaml +36 -0
  34. external_models/TangoFlux/crpo.sh +2 -0
  35. external_models/TangoFlux/inference.py +7 -0
  36. external_models/TangoFlux/replicate_demo/cog.yaml +31 -0
  37. external_models/TangoFlux/replicate_demo/predict.py +92 -0
  38. external_models/TangoFlux/requirements.txt +12 -0
  39. external_models/TangoFlux/setup.py +30 -0
  40. external_models/TangoFlux/tangoflux/__init__.py +60 -0
  41. external_models/TangoFlux/tangoflux/cli.py +29 -0
  42. external_models/TangoFlux/tangoflux/demo.py +63 -0
  43. external_models/TangoFlux/tangoflux/generate_crpo_dataset.py +204 -0
  44. external_models/TangoFlux/tangoflux/label_crpo.py +153 -0
  45. external_models/TangoFlux/tangoflux/model.py +556 -0
  46. external_models/TangoFlux/tangoflux/train.py +588 -0
  47. external_models/TangoFlux/tangoflux/train_dpo.py +608 -0
  48. external_models/TangoFlux/tangoflux/utils.py +159 -0
  49. external_models/TangoFlux/train.sh +2 -0
  50. external_models/depth-fm/.gitignore +5 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ external_models/depth-fm/assets/dog.png filter=lfs diff=lfs merge=lfs -text
37
+ external_models/depth-fm/assets/figures/dfm-cover.png filter=lfs diff=lfs merge=lfs -text
38
+ external_models/depth-fm/assets/figures/radio.png filter=lfs diff=lfs merge=lfs -text
39
+ external_models/TangoFlux/assets/tangoflux.png filter=lfs diff=lfs merge=lfs -text
40
+ external_models/TangoFlux/assets/tf_opener.png filter=lfs diff=lfs merge=lfs -text
41
+ external_models/TangoFlux/assets/tf_teaser.png filter=lfs diff=lfs merge=lfs -text
DepthEstimator.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ from accelerate.test_utils.testing import get_backend
3
+ from PIL import Image
4
+ import os
5
+ import sys
6
+ from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR
7
+ sys.path.append(DEPTH_FM_DIR + '/depthfm')
8
+ from dfm import DepthFM
9
+ from unet import UNetModel
10
+ import einops
11
+ import numpy as np
12
+ from torchvision import transforms
13
+
14
+
15
+ class DepthEstimator:
16
+ def __init__(self, image_dir = LOGS_DIR):
17
+ self.device,_,_ = get_backend()
18
+ self.image_dir = image_dir
19
+ self.model = None
20
+
21
+ def _load_model(self):
22
+ if self.model is None:
23
+ self.model = DepthFM(DEPTH_FM_CHECKPOINT).to(self.device).eval()
24
+ else:
25
+ self.model = self.model.to(self.device).eval()
26
+
27
+ def _unload_model(self):
28
+ if self.model is not None:
29
+ self.model = self.model.to("cpu")
30
+ torch.cuda.empty_cache()
31
+
32
+
33
+ def estimate_depth(self, image_path : str) -> list:
34
+ print("Estimating depth...")
35
+ predictions_list = []
36
+ self._load_model()
37
+ for img in os.listdir(image_path):
38
+ if img.endswith(".jpg") or img.endswith(".jpeg") or img.endswith(".png"):
39
+ image = Image.open(os.path.join(image_path, img))
40
+ x = np.array(image)
41
+ x = einops.rearrange(x, 'h w c -> c h w')
42
+ x = x / 127.5 - 1
43
+ x = torch.tensor(x, dtype=torch.float32)[None]
44
+ with torch.no_grad():
45
+ depth = self.model.predict_depth(x.to(self.device), num_steps=2, ensemble_size=4) # returns a tensor
46
+ depth = depth.cpu()
47
+ to_pil = transforms.ToPILImage()
48
+ PIL_image = to_pil(depth.squeeze())
49
+ predictions_list.append({"depth": PIL_image})
50
+ del x, depth
51
+ torch.cuda.empty_cache()
52
+ self._unload_model()
53
+ print("Depth estimation complete.")
54
+ return predictions_list
55
+
56
+ def visualize(self, predictions_list : list) -> None:
57
+ for (i, prediction) in enumerate(predictions_list):
58
+ prediction["depth"].save(f"depth_{i}.png")
59
+
60
+
61
+ # Estimator = DepthEstimator()
62
+ # predictions = Estimator.estimate_depth(Estimator.image_dir)
63
+ # Estimator.visualize(predictions)
64
+
65
+
66
+
67
+
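For reference, a minimal usage sketch of the estimator above (it mirrors the commented-out example at the bottom of the file; the checkpoint path comes from `config.py` and the images are read from `LOGS_DIR`, so both are assumed to exist):

```
# Minimal sketch, not part of the commit: run depth estimation over LOGS_DIR.
from DepthEstimator import DepthEstimator
import numpy as np

estimator = DepthEstimator()                       # image_dir defaults to LOGS_DIR
predictions = estimator.estimate_depth(estimator.image_dir)
# Each entry is {"depth": PIL.Image}; normalize to [0, 1] for downstream zoning.
depth_arrays = [np.array(p["depth"]) / 255.0 for p in predictions]
estimator.visualize(predictions)                   # writes depth_0.png, depth_1.png, ...
```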
GenerateAudio.py ADDED
@@ -0,0 +1,184 @@
1
+ import torchaudio
2
+ import sys
3
+ import torch
4
+ import random
5
+ from config import TANGO_FLUX_DIR
6
+ sys.path.append(TANGO_FLUX_DIR)
7
+ from tangoflux import TangoFluxInference
8
+ from transformers import AutoTokenizer, T5EncoderModel
9
+ from collections import Counter
10
+
11
+ class GenerateAudio():
12
+ def __init__(self):
13
+ self.device = "cuda"
14
+ self.model = None
15
+ self.text_encoder = None
16
+
17
+ # Basic categories for object classification
18
+ self.categories = {
19
+ 'vehicle': ['car', 'bus', 'truck', 'motorcycle', 'bicycle', 'train', 'vehicle'],
20
+ 'nature': ['tree', 'bird', 'water', 'river', 'lake', 'ocean', 'rain', 'wind', 'forest'],
21
+ 'urban': ['traffic', 'building', 'street', 'signal', 'construction'],
22
+ 'animal': ['dog', 'cat', 'bird', 'insect', 'frog', 'squirrel'],
23
+ 'human': ['person', 'people', 'crowd', 'child', 'footstep', 'voice'],
24
+ 'indoor': ['door', 'window', 'chair', 'table', 'fan', 'appliance', 'tv', 'radio']
25
+ }
26
+
27
+ # Suffixes and prefixes for pattern matching
28
+ self.suffixes = {
29
+ 'tree': 'nature',
30
+ 'bird': 'animal',
31
+ 'car': 'vehicle',
32
+ 'truck': 'vehicle',
33
+ 'signal': 'urban'
34
+ }
35
+
36
+ def _load_model(self):
37
+ if self.model is None:
38
+ self.model = TangoFluxInference(name='declare-lab/TangoFlux')
39
+ if self.text_encoder is None:
40
+ self.text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").to(self.device).eval()
41
+ else:
42
+ self.text_encoder = self.text_encoder.to(self.device)
43
+
44
+ def generate_sound(self, prompt, steps=25, duration=10, guidance_scale=4.5, disable_progress=True):
45
+ self._load_model()
46
+ with torch.no_grad():
47
+ latents = self.model.model.inference_flow(
48
+ prompt,
49
+ duration=duration,
50
+ num_inference_steps=steps,
51
+ guidance_scale=guidance_scale,
52
+ disable_progress=disable_progress
53
+ )
54
+ wave = self.model.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
55
+ waveform_end = int(duration * self.model.vae.config.sampling_rate)
56
+ wave = wave[:, :waveform_end]
57
+
58
+ return wave
59
+
60
+ def _categorize_object(self, object_name):
61
+ """Categorize an object based on keywords or patterns"""
62
+ object_lower = object_name.lower()
63
+
64
+ # Check if the object contains any category keywords
65
+ for category, keywords in self.categories.items():
66
+ for keyword in keywords:
67
+ if keyword in object_lower:
68
+ return category
69
+
70
+ # Check suffix/prefix patterns
71
+ words = object_lower.split()
72
+ for word in words:
73
+ for suffix, category in self.suffixes.items():
74
+ if word.endswith(suffix):
75
+ return category
76
+
77
+ return "unknown"
78
+
79
+ def _describe_object_sound(self, object_name, zone):
80
+ """Generate an appropriate sound description based on object type and distance"""
81
+ category = self._categorize_object(object_name)
82
+
83
+ # Volume descriptor based on zone
84
+ volume_descriptors = {
85
+ "near": ["prominent", "clear", "loud", "distinct"],
86
+ "medium": ["moderate", "audible", "present"],
87
+ "far": ["subtle", "distant", "faint", "soft"]
88
+ }
89
+
90
+ volume = random.choice(volume_descriptors[zone])
91
+
92
+ # Sound descriptors based on category
93
+ sound_templates = {
94
+ "vehicle": [
95
+ "{volume} engine sounds from the {object}",
96
+ "{volume} mechanical noise of the {object}",
97
+ "the {object} creating {volume} road noise",
98
+ "{volume} sounds of the {object} in motion"
99
+ ],
100
+ "nature": [
101
+ "{volume} rustling of the {object}",
102
+ "the {object} making {volume} natural sounds",
103
+ "{volume} environmental sounds from the {object}",
104
+ "the {object} with {volume} movement in the wind"
105
+ ],
106
+ "urban": [
107
+ "{volume} urban sounds around the {object}",
108
+ "the {object} with {volume} city ambience",
109
+ "{volume} noise from the {object}",
110
+ "the {object} contributing to {volume} street sounds"
111
+ ],
112
+ "animal": [
113
+ "{volume} calls from the {object}",
114
+ "the {object} making {volume} animal sounds",
115
+ "{volume} sounds of the {object}",
116
+ "the {object} with its {volume} presence"
117
+ ],
118
+ "human": [
119
+ "{volume} voices from the {object}",
120
+ "the {object} creating {volume} human sounds",
121
+ "{volume} movement sounds from the {object}",
122
+ "the {object} with {volume} activity"
123
+ ],
124
+ "indoor": [
125
+ "{volume} ambient sounds around the {object}",
126
+ "the {object} making {volume} indoor noises",
127
+ "{volume} mechanical sounds from the {object}",
128
+ "the {object} with its {volume} presence"
129
+ ],
130
+ "unknown": [
131
+ "{volume} sounds from the {object}",
132
+ "the {object} creating {volume} audio",
133
+ "{volume} noises associated with the {object}",
134
+ "the {object} with its {volume} acoustic presence"
135
+ ]
136
+ }
137
+
138
+ # Select a template for this category
139
+ templates = sound_templates.get(category, sound_templates["unknown"])
140
+ template = random.choice(templates)
141
+
142
+ # Fill in the template
143
+ description = template.format(volume=volume, object=object_name)
144
+ return description
145
+
146
+ def create_audio_prompt(self, object_depths):
147
+ if not object_depths:
148
+ return "Environmental ambient sounds."
149
+
150
+ for obj in object_depths:
151
+ if obj.get("sound_description") and len(obj["sound_description"]) > 5:
152
+ return obj["sound_description"]
153
+ return f"Sounds of {object_depths[0]['original_label']}."
154
+
155
+ def process_and_generate_audio(self, object_depths, output_path=None, duration=10, steps=25, guidance_scale=4.5):
156
+ self._load_model()
157
+
158
+ if not object_depths:
159
+ prompt = "Environmental ambient sounds."
160
+ else:
161
+ # Sort objects by depth to prioritize closer objects
162
+ sorted_objects = sorted(object_depths, key=lambda x: x["mean_depth"])
163
+ prompt = self.create_audio_prompt(sorted_objects)
164
+
165
+ print(f"Generated audio prompt: {prompt}")
166
+
167
+ wave = self.generate_sound(
168
+ prompt,
169
+ steps=steps,
170
+ duration=duration,
171
+ guidance_scale=guidance_scale
172
+ )
173
+
174
+ sample_rate = self.model.vae.config.sampling_rate
175
+
176
+ if output_path:
177
+ torchaudio.save(
178
+ output_path,
179
+ wave.unsqueeze(0),
180
+ sample_rate
181
+ )
182
+ print(f"Audio saved to: {output_path}")
183
+
184
+ return wave, sample_rate
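The `object_depths` list consumed by `process_and_generate_audio` is produced by `SoundMapper.analyze_object_depths` (added later in this commit). A hedged sketch of the expected entry format, with made-up values:

```
# Sketch only: keys mirror SoundMapper.analyze_object_depths output; values are examples.
from GenerateAudio import GenerateAudio

object_depths = [{
    "original_label": "car",
    "mean_depth": 0.22,                  # normalized depth in [0, 1]; lower = closer
    "depth_zone": 0,
    "zone_description": "near",
    "sound_description": "engine humming with occasional honking",
}]

generator = GenerateAudio()              # TangoFlux is downloaded/loaded on first use
wave, sr = generator.process_and_generate_audio(
    object_depths, output_path="output/sound_front.wav", duration=10
)
```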
GenerateCaptions.py ADDED
@@ -0,0 +1,494 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ GenerateCaptions.py - A pipeline that downloads Google Street View panoramas,
4
+ extracts perspective views, and analyzes them for sound information.
5
+ """
6
+
7
+ import os
8
+ import requests
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+ import time
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ from config import LOGS_DIR
16
+ import torchvision.transforms as T
17
+ from torchvision.transforms.functional import InterpolationMode
18
+ from transformers import AutoModel, AutoTokenizer
19
+ from utils import sample_perspective_img
20
+ import cv2
21
+
22
+ log_dir = LOGS_DIR
23
+ os.makedirs(log_dir, exist_ok=True) # Creates the directory if it doesn't exist
24
+
25
+ # soundscape_query = "<image>\nWhat can we expect to hear from the location captured in this image? Name the around five nouns. Avoid speculation and provide a concise response including sound sources visible in the image."
26
+ soundscape_query = """<image>
27
+ Identify 5 potential sound sources visible in this image. For each source, provide both the noun and a brief description of its typical sound.
28
+
29
+ Format your response exactly like these examples (do not include the word "Noun:" in your response):
30
+ Car: engine humming with occasional honking.
31
+ River: gentle flowing water with subtle splashing sounds.
32
+ Trees: rustling leaves moved by the wind.
33
+ """
34
+ # Constants
35
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
36
+ IMAGENET_STD = (0.229, 0.224, 0.225)
37
+
38
+ # Model Leaderboard Paths
39
+ MODEL_LEADERBOARD = {
40
+ "intern_2_5-8B": "OpenGVLab/InternVL2_5-8B-MPO",
41
+ "intern_2_5-4B": "OpenGVLab/InternVL2_5-4B-MPO",
42
+ }
43
+
44
+ class StreetViewDownloader:
45
+ """Downloads panoramic images from Google Street View"""
46
+
47
+ def __init__(self):
48
+ # URLs for API requests
49
+ # https://www.google.ca/maps/rpc/photo/listentityphotos?authuser=0&hl=en&gl=us&pb=!1e3!5m45!2m2!1i203!2i100!3m3!2i4!3sCAEIBAgFCAYgAQ!5b1!7m33!1m3!1e1!2b0!3e3!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e10!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e4!1m3!1e9!2b1!3e2!2b1!8m0!9b0!11m1!4b1!6m3!1sI63QZ8b4BcSli-gPvPHf-Qc!7e81!15i11021!9m2!2d-90.30324219145255!3d38.636242944711036!10d91.37627840655999
50
+ #self.panoid_req = 'https://www.google.com/maps/preview/reveal?authuser=0&hl=en&gl=us&pb=!2m9!1m3!1d82597.14038230096!2d{}!3d{}!2m0!3m2!1i1523!2i1272!4f13.1!3m2!2d{}!3d{}!4m2!1syPETZOjwLvCIptQPiJum-AQ!7e81!5m5!2m4!1i96!2i64!3i1!4i8'
51
+ self.panoid_req = 'https://www.google.ca/maps/rpc/photo/listentityphotos?authuser=0&hl=en&gl=us&pb=!1e3!5m45!2m2!1i203!2i100!3m3!2i4!3sCAEIBAgFCAYgAQ!5b1!7m33!1m3!1e1!2b0!3e3!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e10!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e4!1m3!1e9!2b1!3e2!2b1!8m0!9b0!11m1!4b1!6m3!1sI63QZ8b4BcSli-gPvPHf-Qc!7e81!15i11021!9m2!2d{}!3d{}!10d25'
52
+ # https://www.google.com/maps/photometa/v1?authuser=0&hl=en&gl=us&pb=!1m4!1smaps_sv.tactile!11m2!2m1!1b1!2m2!1sen!2sus!3m3!1m2!1e2!2s{}!4m61!1e1!1e2!1e3!1e4!1e5!1e6!1e8!1e12!1e17!2m1!1e1!4m1!1i48!5m1!1e1!5m1!1e2!6m1!1e1!6m1!1e2!9m36!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e3!2b1!3e2!1m3!1e3!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e1!2b0!3e3!1m3!1e4!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e3!11m2!3m1!4b1 # vmSzE7zkK2eETwAP_r8UdQ
53
+ # https://www.google.ca/maps/photometa/v1?authuser=0&hl=en&gl=us&pb=!1m4!1smaps_sv.tactile!11m2!2m1!1b1!2m2!1sen!2sus!3m3!1m2!1e2!2s{}!4m61!1e1!1e2!1e3!1e4!1e5!1e6!1e8!1e12!1e17!2m1!1e1!4m1!1i48!5m1!1e1!5m1!1e2!6m1!1e1!6m1!1e2!9m36!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e3!2b1!3e2!1m3!1e3!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e1!2b0!3e3!1m3!1e4!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e3!11m2!3m1!4b1 # -9HfuNFUDOw_IP5SA5IspA
54
+ self.photometa_req = 'https://www.google.com/maps/photometa/v1?authuser=0&hl=en&gl=us&pb=!1m4!1smaps_sv.tactile!11m2!2m1!1b1!2m2!1sen!2sus!3m5!1m2!1e2!2s{}!2m1!5s0x87d8b49f53fc92e9:0x6ecb6e520c6f4d9f!4m57!1e1!1e2!1e3!1e4!1e5!1e6!1e8!1e12!2m1!1e1!4m1!1i48!5m1!1e1!5m1!1e2!6m1!1e1!6m1!1e2!9m36!1m3!1e2!2b1!3e2!1m3!1e2!2b0!3e3!1m3!1e3!2b1!3e2!1m3!1e3!2b0!3e3!1m3!1e8!2b0!3e3!1m3!1e1!2b0!3e3!1m3!1e4!2b0!3e3!1m3!1e10!2b1!3e2!1m3!1e10!2b0!3e3'
55
+ self.panimg_req = 'https://streetviewpixels-pa.googleapis.com/v1/tile?cb_client=maps_sv.tactile&panoid={}&x={}&y={}&zoom={}'
56
+ def get_image_id(self, lat, lon):
57
+ """Get Street View panorama ID for given coordinates"""
58
+ null = None
59
+ pr_response = requests.get(self.panoid_req.format(lon, lat, lon, lat))
60
+ if pr_response.status_code != 200:
61
+ error_message = f"Error fetching panorama ID: HTTP {pr_response.status_code}"
62
+ if pr_response.status_code == 400:
63
+ error_message += " - Bad request. Check coordinates format."
64
+ elif pr_response.status_code == 401 or pr_response.status_code == 403:
65
+ error_message += " - Authentication error. Check API key and permissions."
66
+ elif pr_response.status_code == 404:
67
+ error_message += " - No panorama found at these coordinates."
68
+ elif pr_response.status_code == 429:
69
+ error_message += " - Rate limit exceeded. Try again later."
70
+ elif pr_response.status_code >= 500:
71
+ error_message += " - Server error. Try again later."
72
+ return None
73
+
74
+ pr = BytesIO(pr_response.content).getvalue().decode('utf-8')
75
+ pr = eval(pr[pr.index('\n'):])
76
+ try:
77
+ panoid = pr[0][0][0]
78
+ except (IndexError, TypeError):
79
+ return None
80
+
81
+ return panoid
82
+
83
+ def download_image(self, lat, lon, zoom=1):
84
+ """Download Street View panorama and metadata"""
85
+ null = None
86
+ panoid = self.get_image_id(lat, lon)
87
+ if panoid is None:
88
+ raise ValueError(f"get_image_id() failed at coordinates: {lat}, {lon}")
89
+
90
+ # Get metadata
91
+ pm_response = requests.get(self.photometa_req.format(panoid))
92
+ pm = BytesIO(pm_response.content).getvalue().decode('utf-8')
93
+ pm = eval(pm[pm.index('\n'):])
94
+ pan_list = pm[1][0][5][0][3][0]
95
+
96
+ # Extract relevant info
97
+ pid = pan_list[0][0][1]
98
+ plat = pan_list[0][2][0][2]
99
+ plon = pan_list[0][2][0][3]
100
+ p_orient = pan_list[0][2][2][0]
101
+
102
+ # Download image tiles and assemble panorama
103
+ img_part_inds = [(x, y) for x in range(2**zoom) for y in range(2**(zoom-1))]
104
+ img = np.zeros((512*(2**(zoom-1)), 512*(2**zoom), 3), dtype=np.uint8)
105
+
106
+ for x, y in img_part_inds:
107
+ sub_img_response = requests.get(self.panimg_req.format(pid, x, y, zoom))
108
+ sub_img = np.array(Image.open(BytesIO(sub_img_response.content)))
109
+ img[512*y:512*(y+1), 512*x:512*(x+1)] = sub_img
110
+
111
+ if (img[-1] == 0).all():
112
+ # raise ValueError("Failed to download complete panorama")
113
+ print("Failed to download complete panorama")
114
+
115
+ return img, pid, plat, plon, p_orient
116
+
117
+
118
+ class PerspectiveExtractor:
119
+ """Extracts perspective views from panoramic images"""
120
+
121
+ def __init__(self, output_shape=(256, 256), fov=(90, 90)):
122
+ self.output_shape = output_shape
123
+ self.fov = fov
124
+
125
+ def extract_views(self, pano_img, face_size=512):
126
+ """Extract front, back, left, and right views based on orientation"""
127
+ # orientations = {
128
+ # "front": (0, p_orient, 0), # Align front with real orientation
129
+ # "back": (0, p_orient + 180, 0), # Behind
130
+ # "left": (0, p_orient - 90, 0), # Left side
131
+ # "right": (0, p_orient + 90, 0), # Right side
132
+ # }
133
+
134
+ # cutouts = {}
135
+ # for view, rot in orientations.items():
136
+ # cutout, fov, applied_rot = sample_perspective_img(
137
+ # pano_img, self.output_shape, fov=self.fov, rot=rot
138
+ # )
139
+ # cutouts[view] = cutout
140
+
141
+ # return cutouts
142
+ """
143
+ Convert ERP panorama to four cubic faces: Front, Left, Back, Right.
144
+ Args:
145
+ erp_img (numpy.ndarray): The input equirectangular image.
146
+ face_size (int): The size of each cubic face.
147
+ Returns:
148
+ dict: A dictionary with the four cube faces.
149
+ """
150
+ # Get ERP dimensions
151
+ h_erp, w_erp, _ = pano_img.shape
152
+ # Define cube face directions (yaw, pitch, roll)
153
+ cube_faces = {
154
+ "front": (0, 0),
155
+ "left": (90, 0),
156
+ "back": (180, 0),
157
+ "right": (-90, 0),
158
+ }
159
+ # Output faces
160
+ faces = {}
161
+ # Generate each face
162
+ for face_name, (yaw, pitch) in cube_faces.items():
163
+ # Create a perspective transformation matrix
164
+ fov = 90 # Field of view
165
+ K = np.array([
166
+ [face_size / (2 * np.tan(np.radians(fov / 2))), 0, face_size / 2],
167
+ [0, face_size / (2 * np.tan(np.radians(fov / 2))), face_size / 2],
168
+ [0, 0, 1]
169
+ ])
170
+ # Generate 3D world coordinates for the cube face
171
+ x, y = np.meshgrid(np.linspace(-1, 1, face_size), np.linspace(-1, 1, face_size))
172
+ z = np.ones_like(x)
173
+ # Normalize 3D points
174
+ points_3d = np.stack((x, y, z), axis=-1) # Shape: (H, W, 3)
175
+ points_3d /= np.linalg.norm(points_3d, axis=-1, keepdims=True)
176
+ # Apply rotation to align with the cube face
177
+ yaw_rad, pitch_rad = np.radians(yaw), np.radians(pitch)
178
+ Ry = np.array([[np.cos(yaw_rad), 0, np.sin(yaw_rad)], [0, 1, 0], [-np.sin(yaw_rad), 0, np.cos(yaw_rad)]])
179
+ Rx = np.array([[1, 0, 0], [0, np.cos(pitch_rad), -np.sin(pitch_rad)], [0, np.sin(pitch_rad), np.cos(pitch_rad)]])
180
+ R = Ry @ Rx
181
+ # Rotate points
182
+ points_3d_rot = np.einsum('ij,hwj->hwi', R, points_3d)
183
+ # Convert 3D to spherical coordinates
184
+ lon = np.arctan2(points_3d_rot[..., 0], points_3d_rot[..., 2])
185
+ lat = np.arcsin(points_3d_rot[..., 1])
186
+ # Map spherical coordinates to ERP image coordinates
187
+ x_erp = (w_erp * (lon / (2 * np.pi) + 0.5)).astype(np.float32)
188
+ y_erp = (h_erp * (0.5 - lat / np.pi)).astype(np.float32)
189
+ # Sample pixels from ERP image
190
+ face_img = cv2.remap(pano_img, x_erp, y_erp, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
191
+ cv2.rotate(face_img, cv2.ROTATE_180, face_img)
192
+ faces[face_name] = face_img
193
+ return faces
194
+
195
+
196
+ class ImageAnalyzer:
197
+ """Analyzes images using Vision-Language Models"""
198
+
199
+ def __init__(self, model_name="intern_2_5-4B", use_cuda=True):
200
+ self.model_name = model_name
201
+ self.use_cuda = use_cuda and torch.cuda.is_available()
202
+ self.model, self.tokenizer, self.device = self._load_model()
203
+
204
+ def _load_model(self):
205
+ """Load selected Vision-Language Model"""
206
+ if self.model_name not in MODEL_LEADERBOARD:
207
+ raise ValueError(f"Model '{self.model_name}' not found. Choose from: {list(MODEL_LEADERBOARD.keys())}")
208
+
209
+ model_path = MODEL_LEADERBOARD[self.model_name]
210
+
211
+ # Configure device and parameters
212
+ if self.use_cuda:
213
+ device = torch.device("cuda")
214
+ torch_dtype = torch.bfloat16
215
+ use_flash_attn = True
216
+ else:
217
+ device = torch.device("cpu")
218
+ torch_dtype = torch.float32
219
+ use_flash_attn = False
220
+
221
+ # Load model and tokenizer
222
+ model = AutoModel.from_pretrained(
223
+ model_path,
224
+ torch_dtype=torch_dtype,
225
+ load_in_8bit=False,
226
+ low_cpu_mem_usage=True,
227
+ use_flash_attn=use_flash_attn,
228
+ trust_remote_code=True,
229
+ ).eval().to(device)
230
+
231
+ tokenizer = AutoTokenizer.from_pretrained(
232
+ model_path,
233
+ trust_remote_code=True,
234
+ use_fast=False
235
+ )
236
+
237
+ return model, tokenizer, device
238
+
239
+ def _build_transform(self, input_size=448):
240
+ """Create image transformation pipeline"""
241
+ transform = T.Compose([
242
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
243
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
244
+ T.ToTensor(),
245
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
246
+ ])
247
+ return transform
248
+
249
+ def _find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
250
+ """Find closest aspect ratio for image tiling"""
251
+ best_ratio_diff = float('inf')
252
+ best_ratio = (1, 1)
253
+ area = width * height
254
+ for ratio in target_ratios:
255
+ target_aspect_ratio = ratio[0] / ratio[1]
256
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
257
+ if ratio_diff < best_ratio_diff:
258
+ best_ratio_diff = ratio_diff
259
+ best_ratio = ratio
260
+ elif ratio_diff == best_ratio_diff:
261
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
262
+ best_ratio = ratio
263
+ return best_ratio
264
+
265
+ def _preprocess_image(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
266
+ """Preprocess image for model input"""
267
+ orig_width, orig_height = image.size
268
+ aspect_ratio = orig_width / orig_height
269
+
270
+ # Calculate possible image aspect ratios
271
+ target_ratios = set(
272
+ (i, j) for n in range(min_num, max_num + 1)
273
+ for i in range(1, n + 1)
274
+ for j in range(1, n + 1)
275
+ if i * j <= max_num and i * j >= min_num
276
+ )
277
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
278
+
279
+ # Find closest aspect ratio
280
+ target_aspect_ratio = self._find_closest_aspect_ratio(
281
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
282
+ )
283
+
284
+ # Calculate target dimensions
285
+ target_width = image_size * target_aspect_ratio[0]
286
+ target_height = image_size * target_aspect_ratio[1]
287
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
288
+
289
+ # Resize and split image
290
+ resized_img = image.resize((target_width, target_height))
291
+ processed_images = []
292
+ for i in range(blocks):
293
+ box = (
294
+ (i % (target_width // image_size)) * image_size,
295
+ (i // (target_width // image_size)) * image_size,
296
+ ((i % (target_width // image_size)) + 1) * image_size,
297
+ ((i // (target_width // image_size)) + 1) * image_size
298
+ )
299
+ split_img = resized_img.crop(box)
300
+ processed_images.append(split_img)
301
+
302
+ assert len(processed_images) == blocks
303
+ if use_thumbnail and len(processed_images) != 1:
304
+ thumbnail_img = image.resize((image_size, image_size))
305
+ processed_images.append(thumbnail_img)
306
+
307
+ return processed_images
308
+
309
+ def load_image(self, image_path, input_size=448, max_num=12):
310
+ """Load and process image for analysis"""
311
+ image = Image.open(image_path).convert('RGB')
312
+ transform = self._build_transform(input_size)
313
+ images = self._preprocess_image(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
314
+ pixel_values = [transform(image) for image in images]
315
+ pixel_values = torch.stack(pixel_values)
316
+ return pixel_values
317
+
318
+ def analyze_image(self, image_path, max_num=12):
319
+ """Analyze image for expected sounds"""
320
+ # Load and process image
321
+ pixel_values = self.load_image(image_path, max_num=max_num)
322
+
323
+ # Move to device with appropriate dtype
324
+ if self.device.type == "cuda":
325
+ pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
326
+ else:
327
+ pixel_values = pixel_values.to(torch.float32).to(self.device)
328
+
329
+ # Create sound-focused query
330
+ query = soundscape_query
331
+
332
+ # Generate response
333
+ generation_config = dict(max_new_tokens=1024, do_sample=True)
334
+ response = self.model.chat(self.tokenizer, pixel_values, query, generation_config)
335
+
336
+ return response
337
+
338
+
339
+ class StreetSoundTextPipeline:
340
+ """Complete pipeline for Street View sound analysis"""
341
+
342
+ def __init__(self, log_dir="logs", model_name="intern_2_5-4B", use_cuda=True):
343
+ # Create log directory if it doesn't exist
344
+ self.log_dir = log_dir
345
+ os.makedirs(log_dir, exist_ok=True)
346
+
347
+ # Initialize components
348
+ self.downloader = StreetViewDownloader()
349
+ self.extractor = PerspectiveExtractor()
350
+ # self.analyzer = ImageAnalyzer(model_name=model_name, use_cuda=use_cuda)
351
+ self.analyzer = None
352
+ self.model_name = model_name
353
+ self.use_cuda = use_cuda
354
+
355
+ def _load_analyzer(self):
356
+ if self.analyzer is None:
357
+ self.analyzer = ImageAnalyzer(model_name=self.model_name, use_cuda=self.use_cuda)
358
+
359
+ def _unload_analyzer(self):
360
+ if self.analyzer is not None:
361
+ if hasattr(self.analyzer, 'model') and self.analyzer.model is not None:
362
+ self.analyzer.model = self.analyzer.model.to("cpu")
363
+ del self.analyzer.model
364
+ self.analyzer.model = None
365
+ torch.cuda.empty_cache()
366
+ self.analyzer = None
367
+
368
+ def process(self, lat, lon, view, panoramic=False):
369
+ """
370
+ Process a location to generate sound description for specified view or all views
371
+
372
+ Args:
373
+ lat (float): Latitude
374
+ lon (float): Longitude
375
+ view (str): Perspective view ('front', 'back', 'left', 'right')
376
+ panoramic (bool): If True, process all views instead of just the specified one
377
+
378
+ Returns:
379
+ dict or list: Results including panorama info and sound description(s)
380
+ """
381
+ if view not in ["front", "back", "left", "right"]:
382
+ raise ValueError(f"Invalid view: {view}. Choose from: front, back, left, right")
383
+
384
+ # Step 1: Download panoramic image
385
+ print(f"Downloading Street View panorama for coordinates: {lat}, {lon}")
386
+
387
+ pano_path = os.path.join(self.log_dir, "panorama.jpg")
388
+ pano_img, pid, plat, plon, p_orient = self.downloader.download_image(lat, lon)
389
+ Image.fromarray(pano_img).save(pano_path)
390
+
391
+ # Step 2: Extract perspective views
392
+ print(f"Extracting perspective views with orientation: {p_orient}°")
393
+ cutouts = self.extractor.extract_views(pano_img, 512)
394
+
395
+ # Save all views
396
+ for v, img in cutouts.items():
397
+ view_path = os.path.join(self.log_dir, f"{v}.jpg")
398
+ Image.fromarray(img).save(view_path)
399
+
400
+ self._load_analyzer()
401
+ print("\n[DEBUG] Current soundscape query:")
402
+ print(soundscape_query)
403
+ print("-" * 50)
404
+ if panoramic:
405
+ # Process all views
406
+ print(f"Analyzing all views for sound information")
407
+ results = []
408
+
409
+ for current_view in ["front", "back", "left", "right"]:
410
+ view_path = os.path.join(self.log_dir, f"{current_view}.jpg")
411
+ sound_description = self.analyzer.analyze_image(view_path)
412
+
413
+ view_result = {
414
+ "panorama_id": pid,
415
+ "coordinates": {"lat": plat, "lon": plon},
416
+ "orientation": p_orient,
417
+ "view": current_view,
418
+ "sound_description": sound_description,
419
+ "files": {
420
+ "panorama": pano_path,
421
+ "view_path": view_path
422
+ }
423
+ }
424
+ results.append(view_result)
425
+
426
+ self._unload_analyzer()
427
+ return results
428
+ else:
429
+ # Process only the selected view
430
+ view_path = os.path.join(self.log_dir, f"{view}.jpg")
431
+ print(f"Analyzing {view} view for sound information")
432
+ sound_description = self.analyzer.analyze_image(view_path)
433
+
434
+ self._unload_analyzer()
435
+
436
+ # Prepare results
437
+ results = {
438
+ "panorama_id": pid,
439
+ "coordinates": {"lat": plat, "lon": plon},
440
+ "orientation": p_orient,
441
+ "view": view,
442
+ "sound_description": sound_description,
443
+ "files": {
444
+ "panorama": pano_path,
445
+ "views": {v: os.path.join(self.log_dir, f"{v}.jpg") for v in cutouts.keys()}
446
+ }
447
+ }
448
+
449
+ return results
450
+
451
+
452
+ def parse_location(location_str):
453
+ """Parse location string in format 'lat,lon' into float tuple"""
454
+ try:
455
+ lat, lon = map(float, location_str.split(','))
456
+ return lat, lon
457
+ except ValueError:
458
+ raise argparse.ArgumentTypeError("Location must be in format 'latitude,longitude'")
459
+
460
+
461
+ def generate_caption(lat, lon, view="front", model="intern_2_5-4B", cpu_only=False, panoramic=False):
462
+ """
463
+ Generate sound captions for one or all views of a street view location
464
+
465
+ Args:
466
+ lat (float/str): Latitude
467
+ lon (float/str): Longitude
468
+ view (str): Perspective view ('front', 'back', 'left', 'right')
469
+ model (str): Model name to use for analysis
470
+ cpu_only (bool): Whether to force CPU usage
471
+ panoramic (bool): If True, process all views instead of just the specified one
472
+
473
+ Returns:
474
+ dict or list: Results with sound descriptions
475
+ """
476
+ pipeline = StreetSoundTextPipeline(
477
+ log_dir=log_dir,
478
+ model_name=model,
479
+ use_cuda=not cpu_only
480
+ )
481
+
482
+ try:
483
+ results = pipeline.process(lat, lon, view, panoramic=panoramic)
484
+
485
+ if panoramic:
486
+ # Process results for all views
487
+ print(f"Generated captions for all views at location: {lat}, {lon}")
488
+ else:
489
+ print(f"Generated caption for {view} view at location: {lat}, {lon}")
490
+
491
+ return results
492
+ except Exception as e:
493
+ print(f"Error: {str(e)}")
494
+ return None
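A short sketch of how `generate_caption` is called elsewhere in this commit (the coordinates are only an example); in the single-view case it returns a dict, in the panoramic case a list of such dicts:

```
# Sketch: caption a single front view; requires network access for Street View tiles.
from GenerateCaptions import generate_caption

result = generate_caption(52.3436723, 4.8529625, view="front",
                          model="intern_2_5-4B", panoramic=False)
if result is not None:
    print(result["sound_description"])        # "Car: engine humming ..." style lines
    print(result["files"]["views"]["front"])  # saved perspective cut-out in LOGS_DIR
```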
README.md CHANGED
@@ -1,13 +1,50 @@
1
- ---
2
- title: SoundingStreet
3
- emoji: 🏢
4
- colorFrom: gray
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.26.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ **A training-free pipeline that uses pre-trained generative models to synthesize sound for any street on Earth with available Street View panoramic imagery.**
2
+
3
+ 1. Change to this directory:
4
+ ```
5
+ cd SoundingStreet
6
+ ```
7
+
8
+ 2. Create the conda environment:
9
+ ```
10
+ conda env create -f environment.yml
11
+ conda activate geosynthsound
12
+ ```
13
+
14
+ 3. Create the necessary directories:
15
+ ```
16
+ mkdir -p logs output
17
+ ```
18
+
19
+ 4. Download the checkpoint for the depth estimation model:
20
+ ```
21
+ wget https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt -P external_models/depth-fm/checkpoints/
22
+ ```
23
+
24
+ 5. Run the `SoundingStreet` demo:
25
+ ```
26
+ python main.py --panoramic --location "52.3436723,4.8529625"
27
+ ```
28
+ Intermediate files such as the downloaded panoramic image and perspective cut-outs are written to `./logs/`. The generated audio for each view is saved in `./output/`, and the composite audio for the location is saved as `./output/panoramic_composition.wav`.
29
+
30
+
31
+ ## Acknowledgements
32
+
33
+ - **InternVL2.5-8B-MPO**
34
+ For vision-language modeling, we employ InternVL2.5-8B-MPO, which is released under the MIT License.
35
+ GitHub: https://github.com/OpenGVLab/InternVL
36
+
37
+ - **Grounding DINO**
38
+ We use Grounding DINO for open-set object detection. Grounding DINO is released under the Apache 2.0 License.
39
+ GitHub: https://github.com/IDEA-Research/GroundingDINO
40
+
41
+ - **DepthFM**
42
+ We utilize the DepthFM model for monocular depth estimation. DepthFM is released under the MIT License.
43
+ GitHub: https://github.com/CompVis/depth-fm
44
+
45
+ - **TangoFlux**
46
+ We incorporate TangoFlux for text-to-audio generation. TangoFlux is available for non-commercial research use only and is subject to the Stability AI Community License, WavCaps license, and the original licenses of the datasets used in training.
47
+ GitHub: https://github.com/declare-lab/TangoFlux
48
+
49
+
50
+ Our repository's license and usage terms adhere to the respective licenses of these models.
SoundMapper.py ADDED
@@ -0,0 +1,438 @@
1
+ from DepthEstimator import DepthEstimator
2
+ import numpy as np
3
+ from PIL import Image
4
+ import os
5
+ from GenerateCaptions import generate_caption
6
+ import re
7
+ from config import LOGS_DIR
8
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
9
+ import torch
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ import spacy
12
+ import gc
13
+
14
+ class SoundMapper:
15
+ def __init__(self):
16
+ self.depth_estimator = DepthEstimator()
17
+ # map_list: list of dicts with a "depth" key holding a PIL.Image depth map (populated lazily)
18
+ self.device = "cuda"
19
+ # self.map_list = self.depth_estimator.estimate_depth(self.depth_estimator.image_dir)
20
+ self.map_list = None
21
+ self.image_dir = self.depth_estimator.image_dir
22
+ # self.nlp = spacy.load("en_core_web_sm")
23
+ self.nlp = None
24
+ self.dino = None
25
+ self.dino_processor = None
26
+ # self.dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(self.device)
27
+ # self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
28
+
29
+ def _load_nlp(self):
30
+ if self.nlp is None:
31
+ self.nlp = spacy.load("en_core_web_sm")
32
+ return self.nlp
33
+
34
+ def _load_depth_maps(self):
35
+ if self.map_list is None:
36
+ self.map_list = self.depth_estimator.estimate_depth(self.depth_estimator.image_dir)
37
+ return self.map_list
38
+
39
+ def process_depth_maps(self) -> list:
40
+ depth_maps = self._load_depth_maps()
41
+ processed_maps = []
42
+ for item in depth_maps:
43
+ depth_map = item["depth"]
44
+ depth_array = np.array(depth_map)
45
+ normalization = depth_array / 255.0
46
+ processed_maps.append({
47
+ "original": depth_map,
48
+ "normalization": normalization
49
+ })
50
+ return processed_maps
51
+
52
+ # def create_depth_zone(self, processed_maps : list, num_zones = 3):
53
+ # zones_data = []
54
+ # for depth_data in processed_maps:
55
+ # normalized = depth_data["normalization"]
56
+ # thresholds = np.linspace(0, 1, num_zones+1)
57
+ # zones = []
58
+ # for i in range(num_zones):
59
+ # zone_mask = (normalized >= thresholds[i]) & (normalized < thresholds[i+1])
60
+ # zone_percentage = zone_mask.sum() / zone_mask.size
61
+ # zones.append({
62
+ # "range": (thresholds[i], thresholds[i+1]),
63
+ # "percentage": zone_percentage,
64
+ # "mask": zone_mask
65
+ # })
66
+ # zones_data.append(zones)
67
+ # return zones_data
68
+
69
+ def detect_sound_sources(self, caption_text: str) -> dict:
70
+ """
71
+ Extract nouns and their sound descriptions from caption text.
72
+ Returns a dictionary mapping nouns to their descriptions.
73
+ """
74
+ sound_sources = {}
75
+ nlp = self._load_nlp()
76
+
77
+ print(f"\n[DEBUG] Beginning sound source detection")
78
+ print(f"Raw caption text length: {len(caption_text)}")
79
+ print(f"First 100 chars: {caption_text[:100]}...")
80
+
81
+ # Split the caption by newlines to separate entries
82
+ lines = caption_text.strip().split('\n')
83
+ print(f"Found {len(lines)} lines after splitting")
84
+
85
+ for i, line in enumerate(lines):
86
+ # Skip empty lines
87
+ if not line.strip():
88
+ continue
89
+
90
+ print(f"Processing line {i}: {line[:50]}{'...' if len(line) > 50 else ''}")
91
+
92
+ # Check if line matches the expected format (Noun: description)
93
+ if ':' in line:
94
+ parts = line.split(':', 1) # Split only on the first colon
95
+
96
+ # Clean up the noun part - remove numbers and leading/trailing whitespace
97
+ noun_part = parts[0].strip().lower()
98
+ # Remove list numbering (e.g., "1. ", "2. ", etc.)
99
+ noun_part = re.sub(r'^\d+\.\s*', '', noun_part)
100
+
101
+ description = parts[1].strip()
102
+
103
+ # Clean any markdown formatting
104
+ noun = re.sub(r'[*()]', '', noun_part).strip()
105
+ description = re.sub(r'[*()]', '', description).strip()
106
+
107
+ # Separate the description at em dash if present
108
+ if ' — ' in description:
109
+ description = description.split(' — ', 1)[0].strip()
110
+ elif ' - ' in description:
111
+ description = description.split(' - ', 1)[0].strip()
112
+
113
+ print(f" - Found potential noun: '{noun}' with description: '{description[:30]}...'")
114
+
115
+ # Skip if noun contains invalid characters or is too short
116
+ if '##' not in noun and len(noun) > 1 and noun[0].isalpha():
117
+ sound_sources[noun] = description
118
+ print(f" √ Added to sound sources")
119
+ else:
120
+ print(f" × Skipped (invalid format)")
121
+
122
+ # If no structured format found, try to extract nouns from the text
123
+ if not sound_sources:
124
+ print("No structured format found, falling back to noun extraction")
125
+ all_nouns = []
126
+ doc = nlp(caption_text)
127
+ for token in doc:
128
+ if token.pos_ == "NOUN" and len(token.text) > 1:
129
+ if token.text[0].isalpha():
130
+ all_nouns.append(token.text.lower())
131
+ print(f" - Extracted noun: '{token.text.lower()}'")
132
+
133
+ for noun in all_nouns:
134
+ sound_sources[noun] = "" # Empty description
135
+
136
+ print(f"[DEBUG] Final detected sound sources: {list(sound_sources.keys())}")
137
+ return sound_sources
138
+
139
+ def map_bbox_to_depth_zone(self, bbox, depth_map, num_zones=3):
140
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
141
+
142
+ height, width = depth_map.shape
143
+ x1, y1 = max(0, x1), max(0, y1)
144
+ x2, y2 = min(width, x2), min(height, y2)
145
+
146
+ depth_roi = depth_map[y1:y2, x1:x2]
147
+
148
+ if depth_roi.size == 0:
149
+ return num_zones - 1
150
+
151
+ mean_depth = np.mean(depth_roi)
152
+
153
+ thresholds = self.create_histogram_depth_zones(depth_map, num_zones)
154
+ for i in range(num_zones):
155
+ if thresholds[i] <= mean_depth < thresholds[i+1]:
156
+ return i
157
+ return num_zones - 1
158
+
159
+ def detect_objects(self, nouns : list, image: Image):
160
+ filtered_nouns = []
161
+ for noun in nouns:
162
+ if '##' not in noun and len(noun) > 1 and noun[0].isalpha():
163
+ filtered_nouns.append(noun)
164
+
165
+ print(f"Detecting objects for nouns: {filtered_nouns}")
166
+
167
+ if self.dino is None:
168
+ self.dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(self.device)
169
+ self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
170
+ else:
171
+ self.dino = self.dino.to(self.device)
172
+
173
+ text_prompt = " . ".join(filtered_nouns)
174
+ inputs = self.dino_processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
175
+
176
+ with torch.no_grad():
177
+ outputs = self.dino(**inputs)
178
+ results = self.dino_processor.post_process_grounded_object_detection(
179
+ outputs,
180
+ inputs.input_ids,
181
+ box_threshold=0.25,
182
+ text_threshold=0.25,
183
+ target_sizes=[image.size[::-1]]
184
+ )
185
+
186
+ result = results[0]
187
+ labels = result["labels"]
188
+ bboxes = result["boxes"]
189
+
190
+ clean_labels = []
191
+ for label in labels:
192
+ clean_label = re.sub(r'##\w+', '', label)
193
+ clean_label = self._split_combined_words(clean_label, filtered_nouns)
194
+ clean_labels.append(clean_label)
195
+
196
+ self.dino = self.dino.to("cpu")
197
+ torch.cuda.empty_cache()
198
+ del inputs, outputs, results
199
+
200
+ print(f"Detected objects: {clean_labels}")
201
+
202
+ return (clean_labels, bboxes)
203
+
204
+ def _split_combined_words(self, text, nouns=None):
205
+ nlp = self._load_nlp()
206
+ if nouns is None:
207
+ known_words = set()
208
+ doc = nlp(text)
209
+ for token in doc:
210
+ if token.pos_ == "NOUN" and len(token.text) > 1:
211
+ known_words.add(token.text.lower())
212
+ else:
213
+ known_words = set(nouns)
214
+
215
+ result = []
216
+ for word in text.split():
217
+ if word in known_words:
218
+ result.append(word)
219
+ continue
220
+
221
+ found = False
222
+ for known in known_words:
223
+ if known in word and len(known) > 2:
224
+ result.append(known)
225
+ found = True
226
+
227
+ if not found:
228
+ result.append(word)
229
+
230
+ return " ".join(result)
231
+
232
+ def process_dino_labels(self, labels):
233
+ processed_labels = []
234
+ nlp = self._load_nlp()
235
+
236
+ for label in labels:
237
+ if label.startswith('##'):
238
+ continue
239
+ label = re.sub(r'[*()]', '', label).strip()
240
+
241
+ parts = label.split()
242
+ for part in parts:
243
+ if part.startswith('##'):
244
+ continue
245
+ doc = nlp(part)
246
+ for token in doc:
247
+ if token.pos_ == "NOUN" and len(token.text) > 1:
248
+ processed_labels.append(token.text.lower())
249
+
250
+ unique_labels = []
251
+ for label in processed_labels:
252
+ if label not in unique_labels:
253
+ unique_labels.append(label)
254
+
255
+ return unique_labels
256
+
257
+
258
+ def create_histogram_depth_zones(self, depth_map, num_zones = 3):
259
+ # using 50 bins because it is faster
260
+ hist, bin_edge = np.histogram(depth_map.flatten(), bins=50, range=(0, 1))
261
+ cumulative = np.cumsum(hist) / np.sum(hist)
262
+ thresholds = [0.0]
263
+ for i in range(1, num_zones):
264
+ target = i / num_zones
265
+ idx = np.argmin(np.abs(cumulative - target))
266
+ thresholds.append(bin_edge[idx + 1])
267
+ thresholds.append(1.0)
268
+
269
+ return thresholds
270
+
271
+
272
+ def analyze_object_depths(self, image_path, depth_map, lat, lon, caption_data=None, all_objects=False):
273
+ image = Image.open(image_path)
274
+
275
+ if caption_data is None:
276
+ caption = generate_caption(lat, lon)
277
+ if not caption:
278
+ print(f"Failed to generate caption for {image_path}")
279
+ return []
280
+ caption_text = caption.get("sound_description", "")
281
+ else:
282
+ caption_text = caption_data.get("sound_description", "")
283
+
284
+ # Debug: Print the raw caption text
285
+ print(f"\n[DEBUG] Raw caption text for {os.path.basename(image_path)}:")
286
+ print(caption_text)
287
+ print("-" * 50)
288
+
289
+ if not caption_text:
290
+ print(f"No caption text available for {image_path}")
291
+ return []
292
+
293
+ # Extract nouns and their sound descriptions
294
+ sound_sources = self.detect_sound_sources(caption_text)
295
+
296
+ # Debug: Print the extracted sound sources
297
+ print(f"[DEBUG] Extracted sound sources:")
298
+ for noun, desc in sound_sources.items():
299
+ print(f" - {noun}: {desc}")
300
+ print("-" * 50)
301
+
302
+ if not sound_sources:
303
+ print(f"No sound sources detected in caption for {image_path}")
304
+ return []
305
+
306
+ # Get list of nouns only for object detection
307
+ nouns = list(sound_sources.keys())
308
+
309
+ # Debug: Print the list of nouns being used for detection
310
+ print(f"[DEBUG] Nouns for object detection: {nouns}")
311
+ print("-" * 50)
312
+
313
+ labels, bboxes = self.detect_objects(nouns, image)
314
+ if len(labels) == 0 or len(bboxes) == 0:
315
+ print(f"No objects detected in {image_path}")
316
+ return []
317
+
318
+ object_data = []
319
+ known_objects = set(nouns) if nouns else set()
320
+
321
+ for i, (label, bbox) in enumerate(zip(labels, bboxes)):
322
+ if '##' in label:
323
+ continue
324
+
325
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
326
+ height, width = depth_map.shape
327
+ x1, y1 = max(0, x1), max(0, y1)
328
+ x2, y2 = min(width, x2), min(height, y2)
329
+
330
+ depth_roi = depth_map[y1:y2, x1:x2]
331
+ if depth_roi.size == 0:
332
+ continue
333
+
334
+ mean_depth = np.mean(depth_roi)
335
+
336
+ matched_noun = None
337
+ matched_desc = None
338
+
339
+ for word in label.split():
340
+ word = word.lower()
341
+ if word in sound_sources:
342
+ matched_noun = word
343
+ matched_desc = sound_sources[word]
344
+ break
345
+ if matched_noun is None:
346
+ for noun in sound_sources:
347
+ if noun in label.lower():
348
+ matched_noun = noun
349
+ matched_desc = sound_sources[noun]
350
+ break
351
+ if matched_noun is None:
352
+ for word in label.split():
353
+ if len(word) > 1 and word[0].isalpha() and '##' not in word:
354
+ matched_noun = word.lower()
355
+ matched_desc = "" # No description available
356
+ break
357
+
358
+ if matched_noun:
359
+ thresholds = self.create_histogram_depth_zones(depth_map, num_zones=3)
360
+ zone = 0 # The default is 0 which is the closest zone
361
+ for i in range(3):
362
+ if thresholds[i] <= mean_depth < thresholds[i+1]:
363
+ zone = i
364
+ break
365
+
366
+ object_data.append({
367
+ "original_label": matched_noun,
368
+ "bbox": bbox.tolist(),
369
+ "depth_zone": zone,
370
+ "zone_description": ["near", "medium", "far"][zone],
371
+ "mean_depth": mean_depth,
372
+ "weight": 1.0 - mean_depth,
373
+ "sound_description": matched_desc
374
+ })
375
+ if all_objects:
376
+ object_data.sort(key=lambda x: x["mean_depth"])
377
+ return object_data
378
+ else:
379
+ if not object_data:
380
+ return []
381
+ closest_object = min(object_data, key=lambda x: x["mean_depth"])
382
+ return [closest_object]
383
+
384
+ def cleanup(self):
385
+ if hasattr(self, 'depth_estimator') and self.depth_estimator is not None:
386
+ del self.depth_estimator
387
+ self.depth_estimator = None
388
+
389
+ if self.map_list is not None:
390
+ del self.map_list
391
+ self.map_list = None
392
+
393
+ if self.dino is not None:
394
+ self.dino = self.dino.to("cpu")
395
+ del self.dino
396
+ self.dino = None
397
+ del self.dino_processor
398
+ self.dino_processor = None
399
+
400
+ if self.nlp is not None:
401
+ del self.nlp
402
+ self.nlp = None
403
+ torch.cuda.empty_cache()
404
+ gc.collect()
405
+
406
+ def test_object_depth_analysis(self):
407
+ """
408
+ Test the object depth analysis on all images in the directory.
409
+ """
410
+ # Process depth maps first
411
+ processed_maps = self.process_depth_maps()
412
+
413
+ # Get list of original image paths
414
+ image_dir = self.depth_estimator.image_dir
415
+ image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".jpg")]
416
+
417
+ results = []
418
+
419
+ # For each image and its corresponding depth map
420
+ for i, (image_path, processed_map) in enumerate(zip(image_paths, processed_maps)):
421
+ # Extract the normalized depth map
422
+ depth_map = processed_map["normalization"]
423
+
424
+ # Analyze objects and their depths
425
+ object_depths = self.analyze_object_depths(image_path, depth_map)
426
+
427
+ # Store results
428
+ results.append({
429
+ "image_path": image_path,
430
+ "object_depths": object_depths
431
+ })
432
+
433
+ # Print some information for debugging
434
+ print(f"Analyzed {image_path}:")
435
+ for obj in object_depths:
436
+ print(f" - {obj['original_label']} (Zone: {obj['zone_description']})")
437
+
438
+ return results
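A small worked example of the histogram-based zoning in `create_histogram_depth_zones` (a sketch; the synthetic depth map is made up, but the thresholds follow directly from the 50-bin histogram logic above):

```
# Toy depth map with three equal clusters of depth values 0.1, 0.5 and 0.9.
import numpy as np
from SoundMapper import SoundMapper

mapper = SoundMapper()
depth_map = np.concatenate([np.full(1000, 0.1),   # near pixels
                            np.full(1000, 0.5),   # mid-range pixels
                            np.full(1000, 0.9)]   # far pixels
                           ).reshape(60, 50)

thresholds = mapper.create_histogram_depth_zones(depth_map, num_zones=3)
print(thresholds)   # approx. [0.0, 0.12, 0.52, 1.0]: each zone holds ~1/3 of the pixels,
                    # so 0.1 maps to "near", 0.5 to "medium" and 0.9 to "far"
```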
app.py ADDED
@@ -0,0 +1,182 @@
1
+ import os
2
+ import gc
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import torchaudio
8
+
9
+ from config import LOGS_DIR, OUTPUT_DIR
10
+ from SoundMapper import SoundMapper
11
+ from GenerateAudio import GenerateAudio
12
+ from GenerateCaptions import generate_caption
13
+ from audio_mixer import compose_audio
14
+
15
+ # Ensure required directories exist
16
+ os.makedirs(LOGS_DIR, exist_ok=True)
17
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
18
+ # Prepare the external model dir and download the DepthFM checkpoint if missing
20
+ depthfm_ckpt = Path('external_models/depth-fm/checkpoints/depthfm-v1.ckpt')
21
+ if not depthfm_ckpt.exists():
22
+ depthfm_ckpt.parent.mkdir(parents=True, exist_ok=True)
23
+ os.system('wget https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt -P external_models/depth-fm/checkpoints/')
24
+
25
+
26
+ # Clear CUDA cache between runs
27
+ def clear_cuda():
28
+ torch.cuda.empty_cache()
29
+ gc.collect()
30
+
31
+
32
+ def process_images(
33
+ image_dir: str,
34
+ output_dir: str,
35
+ panoramic: bool,
36
+ view: str,
37
+ model: str,
38
+ location: str,
39
+ audio_duration: int,
40
+ cpu_only: bool
41
+ ) -> None:
42
+ # Existing processing logic, generates files in OUTPUT_DIR
43
+ lat, lon = location.split(",")
44
+ os.makedirs(output_dir, exist_ok=True)
45
+ sound_mapper = SoundMapper()
46
+ audio_generator = GenerateAudio()
47
+
48
+ if panoramic:
49
+ # Panoramic: generate per-view audio then composition
50
+ view_results = generate_caption(lat, lon, view=view, model=model,
51
+ cpu_only=cpu_only, panoramic=True)
52
+ processed_maps = sound_mapper.process_depth_maps()
53
+ image_paths = sorted(Path(image_dir).glob("*.jpg"))
54
+ audios = {}
55
+ for vr in view_results:
56
+ cv = vr["view"]
57
+ img_file = Path(image_dir) / f"{cv}.jpg"
58
+ if not img_file.exists():
59
+ continue
60
+ idx = [i for i, p in enumerate(image_paths) if p.name == img_file.name]
61
+ if not idx:
62
+ continue
63
+ depth_map = processed_maps[idx[0]]["normalization"]
64
+ obj_depths = sound_mapper.analyze_object_depths(
65
+ str(img_file), depth_map, lat, lon,
66
+ caption_data=vr, all_objects=False
67
+ )
68
+ if not obj_depths:
69
+ continue
70
+ out_wav = Path(output_dir) / f"sound_{cv}.wav"
71
+ audio, sr = audio_generator.process_and_generate_audio(
72
+ obj_depths, duration=audio_duration
73
+ )
74
+ if audio.dim() == 3:
75
+ audio = audio.squeeze(0)
76
+ elif audio.dim() == 1:
77
+ audio = audio.unsqueeze(0)
78
+ torchaudio.save(str(out_wav), audio, sr)
79
+ audios[cv] = str(out_wav)
80
+ # final panoramic composition
81
+ comp = Path(output_dir) / "panoramic_composition.wav"
82
+ compose_audio(list(audios.values()), [1.0]*len(audios), str(comp))
83
+ audios['panorama'] = str(comp)
84
+ clear_cuda()
85
+ return
86
+
87
+ # Single-view: generate one audio
88
+ vr = generate_caption(lat, lon, view=view, model=model,
89
+ cpu_only=cpu_only, panoramic=False)
90
+ img_file = Path(image_dir) / f"{view}.jpg"
91
+ processed_maps = sound_mapper.process_depth_maps()
92
+ image_paths = sorted(Path(image_dir).glob("*.jpg"))
93
+ idx = [i for i, p in enumerate(image_paths) if p.name == img_file.name]
94
+ depth_map = processed_maps[idx[0]]["normalization"]
95
+ obj_depths = sound_mapper.analyze_object_depths(
96
+ str(img_file), depth_map, lat, lon,
97
+ caption_data=vr, all_objects=True
98
+ )
99
+ out_wav = Path(output_dir) / f"sound_{view}.wav"
100
+ audio, sr = audio_generator.process_and_generate_audio(obj_depths, duration=audio_duration)
101
+ if audio.dim() == 3:
102
+ audio = audio.squeeze(0)
103
+ elif audio.dim() == 1:
104
+ audio = audio.unsqueeze(0)
105
+ torchaudio.save(str(out_wav), audio, sr)
106
+ clear_cuda()
107
+
108
+ # Gradio UI
109
+ demo = gr.Blocks(title="Panoramic Audio Generator")
110
+ with demo:
111
+ gr.Markdown("""
112
+ # Panoramic Audio Generator
113
+
114
+ Displays each view with its audio side by side.
115
+ """
116
+ )
117
+
118
+ with gr.Row():
119
+ panoramic = gr.Checkbox(label="Panoramic (multi-view)", value=False)
120
+ view = gr.Dropdown(["front", "back", "left", "right"], value="front", label="View")
121
+ location = gr.Textbox(value="52.3436723,4.8529625", label="Location (lat,lon)")
122
+ # model = gr.Textbox(value="intern_2_5-4B", label="Vision-Language Model")
123
+ model = gr.Textbox(value="intern_2_5-4B", label="Vision-Language Model", visible=False)  # kept as a Gradio component so it can be passed via inputs= below
124
+ audio_duration = gr.Slider(1, 60, value=10, step=1, label="Audio Duration (sec)")
125
+ cpu_only = gr.Checkbox(label="CPU Only", value=False)
126
+ btn = gr.Button("Generate")
127
+
128
+ # Output layout: two rows of two
129
+ with gr.Row():
130
+ with gr.Column():
131
+ img_front = gr.Image(label="Front View", type="filepath")
132
+ aud_front = gr.Audio(label="Front Audio", type="filepath")
133
+ with gr.Column():
134
+ img_back = gr.Image(label="Back View", type="filepath")
135
+ aud_back = gr.Audio(label="Back Audio", type="filepath")
136
+ with gr.Row():
137
+ with gr.Column():
138
+ img_left = gr.Image(label="Left View", type="filepath")
139
+ aud_left = gr.Audio(label="Left Audio", type="filepath")
140
+ with gr.Column():
141
+ img_right = gr.Image(label="Right View", type="filepath")
142
+ aud_right = gr.Audio(label="Right Audio", type="filepath")
143
+ # Panorama at bottom
144
+ img_pan = gr.Image(label="Panorama View", type="filepath")
145
+ aud_pan = gr.Audio(label="Panoramic Audio", type="filepath")
146
+
147
+ # Preview update
148
+ def run_all(pan, vw, loc, mdl, dur, cpu):
149
+ # generate files
150
+ process_images(LOGS_DIR, OUTPUT_DIR, pan, vw, mdl, loc, dur, cpu)
151
+ # collect files
152
+ views = ["front", "back", "left", "right", "panorama"]
153
+ paths = {}
154
+ for v in views:
155
+ img = Path(LOGS_DIR) / f"{v}.jpg"
156
+ audio = Path(OUTPUT_DIR) / ("panoramic_composition.wav" if v == "panorama" else f"sound_{v}.wav")
157
+ paths[v] = {
158
+ 'img': str(img) if img.exists() else None,
159
+ 'aud': str(audio) if audio.exists() else None
160
+ }
161
+ return (
162
+ paths['front']['img'], paths['front']['aud'],
163
+ paths['back']['img'], paths['back']['aud'],
164
+ paths['left']['img'], paths['left']['aud'],
165
+ paths['right']['img'], paths['right']['aud'],
166
+ paths['panorama']['img'], paths['panorama']['aud']
167
+ )
168
+
169
+ btn.click(
170
+ fn=run_all,
171
+ inputs=[panoramic, view, location, model, audio_duration, cpu_only],
172
+ outputs=[
173
+ img_front, aud_front,
174
+ img_back, aud_back,
175
+ img_left, aud_left,
176
+ img_right, aud_right,
177
+ img_pan, aud_pan
178
+ ]
179
+ )
180
+
181
+ if __name__ == "__main__":
182
+ demo.launch(share=True)
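For reference, the pipeline above can also be driven without the Gradio UI by calling `process_images` directly. The snippet below is a minimal, hypothetical sketch based only on the signature shown in this diff; the location string and duration are placeholder values.

```python
# Hypothetical headless invocation of the pipeline defined in app.py above.
from config import LOGS_DIR, OUTPUT_DIR  # created on import by config.py (see below)

process_images(
    image_dir=LOGS_DIR,                 # directory holding front/back/left/right .jpg views
    output_dir=OUTPUT_DIR,              # where sound_<view>.wav files are written
    panoramic=True,                     # per-view audio plus panoramic_composition.wav
    view="front",                       # still forwarded to generate_caption in panoramic mode
    model="intern_2_5-4B",              # vision-language model used for captioning
    location="52.3436723,4.8529625",    # "lat,lon" string, split inside process_images
    audio_duration=10,                  # seconds of audio per view
    cpu_only=False,                     # set True to run without CUDA
)
```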
audio_mixer.py ADDED
@@ -0,0 +1,428 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchaudio
4
+ import torchaudio.transforms as T
5
+ import matplotlib.pyplot as plt
6
+ import os
7
+ from typing import List, Tuple
8
+ from config import LOGS_DIR
9
+
10
+
11
+
12
+ ##Some utils:
13
+ def load_audio_files(file_paths: List[str]) -> List[Tuple[torch.Tensor, int]]:
14
+ """
15
+ Load multiple audio files and ensure they have the same length.
16
+
17
+ Args:
18
+ file_paths: List of paths to audio files
19
+
20
+ Returns:
21
+ List of tuples containing audio data and sample rate
22
+ """
23
+ audio_data = []
24
+
25
+ for path in file_paths:
26
+ # Load audio file
27
+ waveform, sample_rate = torchaudio.load(path)
28
+ # Convert to mono if stereo
29
+ if waveform.shape[0] > 1:
30
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
31
+ audio_data.append((waveform.squeeze(), sample_rate))
32
+
33
+ # Verify all audio files have the same length and sample rate
34
+ lengths = [len(audio) for audio, _ in audio_data]
35
+ sample_rates = [sr for _, sr in audio_data]
36
+
37
+ if len(set(lengths)) > 1:
38
+ raise ValueError(f"Audio files have different lengths: {lengths}")
39
+ if len(set(sample_rates)) > 1:
40
+ raise ValueError(f"Audio files have different sample rates: {sample_rates}")
41
+
42
+ return audio_data
43
+
44
+
45
+ def normalize_audio_volumes(audio_data: List[Tuple[torch.Tensor, int]]) -> List[Tuple[torch.Tensor, int]]:
46
+ """
47
+ Normalize the volume of each audio file to have the same energy level.
48
+
49
+ Args:
50
+ audio_data: List of tuples containing audio data and sample rate
51
+
52
+ Returns:
53
+ List of tuples containing normalized audio data and sample rate
54
+ """
55
+ normalized_data = []
56
+
57
+ # Calculate RMS (Root Mean Square) for each audio
58
+ rms_values = []
59
+ for audio, sr in audio_data:
60
+ # Calculate energy (squared amplitude)
61
+ energy = torch.mean(audio ** 2)
62
+ # Calculate RMS (square root of mean energy)
63
+ rms = torch.sqrt(energy)
64
+ rms_values.append(rms)
65
+
66
+ # Find the target RMS (we'll use the median to avoid outliers)
67
+ target_rms = torch.median(torch.tensor(rms_values))
68
+
69
+ # Normalize each audio to the target RMS
70
+ for (audio, sr), rms in zip(audio_data, rms_values):
71
+ if rms > 0: # Avoid division by zero
72
+ # Calculate scaling factor
73
+ scaling_factor = target_rms / rms
74
+ # Apply scaling
75
+ normalized_audio = audio * scaling_factor
76
+ else:
77
+ normalized_audio = audio
78
+
79
+ normalized_data.append((normalized_audio, sr))
80
+
81
+ return normalized_data
82
+
83
+ def plot_energy_comparison(original_metrics: List[dict], normalized_metrics: List[dict], file_names: List[str], output_path: str = "./logs/energy_comparison.png") -> None:
84
+ """
85
+ Plot a comparison of energy metrics before and after normalization.
86
+
87
+ Args:
88
+ original_metrics: List of dictionaries containing metrics for original audio
89
+ normalized_metrics: List of dictionaries containing metrics for normalized audio
90
+ file_names: List of audio file names
91
+ output_path: Path to save the plot
92
+ """
93
+ fig, axs = plt.subplots(2, 2, figsize=(14, 10))
94
+
95
+ # Extract metrics
96
+ orig_rms = [m['rms'] for m in original_metrics]
97
+ norm_rms = [m['rms'] for m in normalized_metrics]
98
+
99
+ orig_peak = [m['peak'] for m in original_metrics]
100
+ norm_peak = [m['peak'] for m in normalized_metrics]
101
+
102
+ orig_dr = [m['dynamic_range_db'] for m in original_metrics]
103
+ norm_dr = [m['dynamic_range_db'] for m in normalized_metrics]
104
+
105
+ orig_cf = [m['crest_factor'] for m in original_metrics]
106
+ norm_cf = [m['crest_factor'] for m in normalized_metrics]
107
+
108
+ # Prepare x-axis
109
+ x = np.arange(len(file_names))
110
+ width = 0.35
111
+
112
+ # Plot RMS (volume)
113
+ axs[0, 0].bar(x - width/2, orig_rms, width, label='Original')
114
+ axs[0, 0].bar(x + width/2, norm_rms, width, label='Normalized')
115
+ axs[0, 0].set_title('RMS Energy (Volume)')
116
+ axs[0, 0].set_xticks(x)
117
+ axs[0, 0].set_xticklabels(file_names, rotation=45, ha='right')
118
+ axs[0, 0].set_ylabel('RMS Value')
119
+ axs[0, 0].legend()
120
+
121
+ # Plot Peak Amplitude
122
+ axs[0, 1].bar(x - width/2, orig_peak, width, label='Original')
123
+ axs[0, 1].bar(x + width/2, norm_peak, width, label='Normalized')
124
+ axs[0, 1].set_title('Peak Amplitude')
125
+ axs[0, 1].set_xticks(x)
126
+ axs[0, 1].set_xticklabels(file_names, rotation=45, ha='right')
127
+ axs[0, 1].set_ylabel('Peak Value')
128
+ axs[0, 1].legend()
129
+
130
+ # Plot Dynamic Range
131
+ axs[1, 0].bar(x - width/2, orig_dr, width, label='Original')
132
+ axs[1, 0].bar(x + width/2, norm_dr, width, label='Normalized')
133
+ axs[1, 0].set_title('Dynamic Range (dB)')
134
+ axs[1, 0].set_xticks(x)
135
+ axs[1, 0].set_xticklabels(file_names, rotation=45, ha='right')
136
+ axs[1, 0].set_ylabel('dB')
137
+ axs[1, 0].legend()
138
+
139
+ # Plot Crest Factor
140
+ axs[1, 1].bar(x - width/2, orig_cf, width, label='Original')
141
+ axs[1, 1].bar(x + width/2, norm_cf, width, label='Normalized')
142
+ axs[1, 1].set_title('Crest Factor (Peak-to-RMS Ratio)')
143
+ axs[1, 1].set_xticks(x)
144
+ axs[1, 1].set_xticklabels(file_names, rotation=45, ha='right')
145
+ axs[1, 1].set_ylabel('Ratio')
146
+ axs[1, 1].legend()
147
+
148
+ plt.tight_layout()
149
+
150
+ # Create directory if it doesn't exist
151
+ os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
152
+
153
+ # Save the plot
154
+ plt.savefig(output_path)
155
+ plt.close()
156
+
157
+ def calculate_audio_metrics(audio_data: List[Tuple[torch.Tensor, int]]) -> List[dict]:
158
+ """
159
+ Calculate various audio metrics for each audio file.
160
+
161
+ Args:
162
+ audio_data: List of tuples containing audio data and sample rate
163
+
164
+ Returns:
165
+ List of dictionaries containing metrics
166
+ """
167
+ metrics = []
168
+
169
+ for audio, sr in audio_data:
170
+ # Calculate RMS (Root Mean Square)
171
+ energy = torch.mean(audio ** 2)
172
+ rms = torch.sqrt(energy)
173
+
174
+ # Calculate peak amplitude
175
+ peak = torch.max(torch.abs(audio))
176
+
177
+ # Calculate dynamic range
178
+ if torch.any(audio != 0):  # guard against all-zero audio before taking the min below
179
+ min_non_zero = torch.min(torch.abs(audio[audio != 0]))
180
+ dynamic_range_db = 20 * torch.log10(peak / min_non_zero)
181
+ else:
182
+ dynamic_range_db = torch.tensor(float('inf'))
183
+
184
+ # Calculate crest factor (peak to RMS ratio)
185
+ crest_factor = peak / rms if rms > 0 else torch.tensor(float('inf'))
186
+
187
+ metrics.append({
188
+ 'rms': rms.item(),
189
+ 'peak': peak.item(),
190
+ 'dynamic_range_db': dynamic_range_db.item() if not torch.isinf(dynamic_range_db) else float('inf'),
191
+ 'crest_factor': crest_factor.item() if not torch.isinf(crest_factor) else float('inf')
192
+ })
193
+
194
+ return metrics
195
+
196
+
197
+ def create_weighted_composite(
198
+ audio_data: List[Tuple[torch.Tensor, int]],
199
+ weights: List[float]
200
+ ) -> torch.Tensor:
201
+ """
202
+ Create a weighted composite of multiple audio files.
203
+
204
+ Args:
205
+ audio_data: List of tuples containing audio data and sample rate
206
+ weights: List of weights for each audio file
207
+
208
+ Returns:
209
+ Weighted composite audio data
210
+ """
211
+ if len(audio_data) != len(weights):
212
+ raise ValueError("Number of audio files and weights must match")
213
+
214
+ # Normalize weights to sum to 1
215
+ weights = torch.tensor(weights) / sum(weights)
216
+
217
+ # Initialize composite audio with zeros
218
+ composite = torch.zeros_like(audio_data[0][0])
219
+
220
+ # Add weighted audio data
221
+ for (audio, _), weight in zip(audio_data, weights):
222
+ composite += audio * weight
223
+
224
+ # Normalize to prevent clipping
225
+ max_val = torch.max(torch.abs(composite))
226
+ if max_val > 1.0:
227
+ composite = composite / max_val
228
+
229
+ return composite
230
+
231
+
232
+ def create_melspectrograms(
233
+ audio_data: List[Tuple[torch.Tensor, int]],
234
+ composite: torch.Tensor,
235
+ sr: int
236
+ ) -> List[torch.Tensor]:
237
+ """
238
+ Create melspectrograms for individual audio files and the composite.
239
+
240
+ Args:
241
+ audio_data: List of tuples containing audio data and sample rate
242
+ composite: Composite audio data
243
+ sr: Sample rate
244
+
245
+ Returns:
246
+ List of melspectrogram data
247
+ """
248
+ specs = []
249
+
250
+ # Create mel spectrogram transform
251
+ mel_transform = T.MelSpectrogram(
252
+ sample_rate=sr,
253
+ n_fft=2048,
254
+ win_length=2048,
255
+ hop_length=512,
256
+ n_mels=128,
257
+ f_max=8000
258
+ )
259
+
260
+ # Generate spectrograms for individual audio files
261
+ for audio, _ in audio_data:
262
+ melspec = mel_transform(audio)
263
+ specs.append(melspec)
264
+
265
+ # Generate spectrogram for composite audio
266
+ composite_melspec = mel_transform(composite)
267
+ specs.append(composite_melspec)
268
+
269
+ return specs
270
+
271
+
272
+ def plot_melspectrograms(
273
+ specs: List[torch.Tensor],
274
+ sr: int,
275
+ file_names: List[str],
276
+ weights: List[float],
277
+ output_path: str = "melspectrograms.png"
278
+ ) -> None:
279
+ """
280
+ Plot melspectrograms for individual audio files and the composite.
281
+
282
+ Args:
283
+ specs: List of melspectrogram data
284
+ sr: Sample rate
285
+ file_names: List of audio file names
286
+ weights: List of weights for each audio file
287
+ output_path: Path to save the plot
288
+ """
289
+ fig, axs = plt.subplots(len(specs), 1, figsize=(12, 4 * len(specs)))
290
+
291
+ # Create labels for the plots
292
+ labels = [f"{name} (weight: {weight:.2f})" for name, weight in zip(file_names, weights)]
293
+ labels.append("Composite.wav")
294
+
295
+ # Convert to dB scale (similar to librosa's power_to_db)
296
+ def power_to_db(spec):
297
+ return 10 * torch.log10(spec + 1e-10)
298
+
299
+ # Plot each melspectrogram
300
+ for i, (spec, label) in enumerate(zip(specs, labels)):
301
+ spec_db = power_to_db(spec).numpy().squeeze()
302
+
303
+ # For single subplot case
304
+ if len(specs) == 1:
305
+ ax = axs
306
+ else:
307
+ ax = axs[i]
308
+
309
+ img = ax.imshow(
310
+ spec_db,
311
+ aspect='auto',
312
+ origin='lower',
313
+ interpolation='none',
314
+ extent=[0, spec_db.shape[1], 0, sr/2]
315
+ )
316
+ ax.set_title(label)
317
+ ax.set_ylabel('Frequency (Hz)')
318
+ ax.set_xlabel('Time Frames')
319
+
320
+ # No colorbar as requested
321
+
322
+ plt.tight_layout()
323
+
324
+ # Create directory if it doesn't exist
325
+ os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
326
+ # Save the plot
327
+ plt.savefig(output_path,dpi=300)
328
+ plt.close()
329
+
330
+
331
+ def compose_audio(
332
+ file_paths: List[str],
333
+ weights: List[float],
334
+ output_audio_path: str = os.path.join(LOGS_DIR, "composite.wav"),
335
+ output_plot_path: str = os.path.join(LOGS_DIR, "plot/melspectrograms.png"),
336
+ energy_plot_path: str = os.path.join(LOGS_DIR, "plot/energy_comparison.png")
337
+ ) -> None:
338
+ """
339
+ Main function to process audio files and create visualizations.
340
+
341
+ Args:
342
+ file_paths: List of paths to audio files (supports 4 audio files)
343
+ weights: List of weights for each audio file
344
+ output_audio_path: Path to save the composite audio
345
+ output_plot_path: Path to save the melspectrogram plot
346
+ energy_plot_path: Path to save the energy comparison plot
347
+ """
348
+ # Load audio files
349
+ audio_data = load_audio_files(file_paths)
350
+
351
+ # # Calculate metrics for original audio
352
+ print("Calculating metrics for original audio...")
353
+ original_metrics = calculate_audio_metrics(audio_data)
354
+
355
+ # Normalize audio volumes to have same energy level
356
+ print("Normalizing audio volumes...")
357
+ normalized_audio_data = normalize_audio_volumes(audio_data)
358
+
359
+ # Calculate metrics for normalized audio
360
+ print("Calculating metrics for normalized audio...")
361
+ normalized_metrics = calculate_audio_metrics(normalized_audio_data)
362
+
363
+ # Print energy comparison
364
+ print("\nAudio Energy Comparison (RMS values):")
365
+ print("-" * 50)
366
+ print(f"{'File':<20} {'Original':<15} {'Normalized':<15} {'Scaling Factor':<15}")
367
+ print("-" * 50)
368
+ for i, path in enumerate(file_paths):
369
+ file_name = path.split("/")[-1]
370
+ orig_rms = original_metrics[i]['rms']
371
+ norm_rms = normalized_metrics[i]['rms']
372
+ scaling = norm_rms / orig_rms if orig_rms > 0 else float('inf')
373
+ print(f"{file_name[:20]:<20} {orig_rms:<15.6f} {norm_rms:<15.6f} {scaling:<15.6f}")
374
+
375
+ # Create energy comparison plot
376
+ print("\nCreating energy comparison plot...")
377
+ file_names = [path.split("/")[-1] for path in file_paths]
378
+ plot_energy_comparison(original_metrics, normalized_metrics, file_names, energy_plot_path)
379
+
380
+ # Get sample rate (all files have the same sample rate)
381
+ sr = normalized_audio_data[0][1]
382
+
383
+ # Create weighted composite
384
+ print("\nCreating weighted composite...")
385
+ composite = create_weighted_composite(normalized_audio_data, weights)
386
+
387
+ # Create directory if it doesn't exist
388
+ os.makedirs(os.path.dirname(output_audio_path) or '.', exist_ok=True)
389
+
390
+ # Save composite audio
391
+ print("Saving composite audio...")
392
+ torchaudio.save(output_audio_path, composite.unsqueeze(0), sr)
393
+
394
+ # Create melspectrograms for normalized audio (not original)
395
+ print("Creating melspectrograms for normalized audio...")
396
+ specs = create_melspectrograms(normalized_audio_data, composite, sr)
397
+
398
+ # Get file names without path
399
+ labeled_file_names = [path.split("/")[-1] for path in file_paths]
400
+
401
+ # Plot melspectrograms
402
+ print("Plotting melspectrograms...")
403
+ plot_melspectrograms(specs, sr, labeled_file_names, weights, output_plot_path)
404
+
405
+ print(f"\nComposite audio saved to {output_audio_path}")
406
+ print(f"Melspectrograms saved to {output_plot_path}")
407
+ print(f"Energy comparison saved to {energy_plot_path}")
408
+
411
+
412
+
413
+ # if __name__ == "__main__":
414
+ # import argparse
415
+
416
+ # parser = argparse.ArgumentParser(description="Mix audio files with weights and create melspectrograms")
417
+ # parser.add_argument("--files", nargs="+", required=True, help="Paths to audio files")
418
+ # parser.add_argument("--weights", nargs="+", type=float, required=True, help="Weights for each audio file")
419
+ # parser.add_argument("--output-audio", default="./logs/composite.wav", help="Path to save the composite audio")
420
+ # parser.add_argument("--output-plot", default="./logs/melspectrograms.png", help="Path to save the melspectrogram plot")
421
+
422
+ # args = parser.parse_args()
423
+ # os.makedirs("./logs", exist_ok=True)
424
+ # compose_audio(args.files, args.weights, args.output_audio, args.output_plot)
425
+
426
+
427
+ # Example usage:
428
+ # python audio_mixer.py --files audio1.wav audio2.wav audio3.wav audio4.wav --weights 0.4 0.3 0.2 0.1
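As a usage illustration (not part of the committed file), `compose_audio` can also be called directly from Python with the four per-view clips produced by `app.py`; the file names below are placeholders.

```python
# Hypothetical example: mix four per-view clips into one weighted composite.
import os

from audio_mixer import compose_audio
from config import OUTPUT_DIR

views = ["front", "back", "left", "right"]
files = [os.path.join(OUTPUT_DIR, f"sound_{v}.wav") for v in views]
weights = [0.4, 0.3, 0.2, 0.1]  # relative emphasis; normalized internally to sum to 1

# Writes composite.wav plus the melspectrogram and energy-comparison plots
# to the default paths under LOGS_DIR.
compose_audio(files, weights)
```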
config.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+
3
+ # Base directories
4
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
5
+ LOGS_DIR = os.path.join(BASE_DIR, "logs")
6
+ OUTPUT_DIR = os.path.join(BASE_DIR, "output")
7
+
8
+ # Model paths
9
+ EXTERNAL_MODELS_DIR = os.path.join(BASE_DIR, "external_models")
10
+ DEPTH_FM_DIR = os.path.join(EXTERNAL_MODELS_DIR, "depth-fm")
11
+ DEPTH_FM_CHECKPOINT = os.path.join(DEPTH_FM_DIR, "checkpoints/depthfm-v1.ckpt") # You will need to download the checkpoint manually. Here is the link: https://github.com/CompVis/depth-fm/tree/main/checkpoints
12
+ TANGO_FLUX_DIR = os.path.join(EXTERNAL_MODELS_DIR, "TangoFlux")
13
+
14
+ # Create required directories
15
+ os.makedirs(LOGS_DIR, exist_ok=True)
16
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
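Other modules in this diff import these paths rather than hard-coding them; a minimal usage sketch:

```python
# Example of how the rest of the codebase consumes these constants.
from config import LOGS_DIR, OUTPUT_DIR, DEPTH_FM_CHECKPOINT

print(LOGS_DIR)             # <repo>/logs, created when config is imported
print(DEPTH_FM_CHECKPOINT)  # expected location of the manually downloaded DepthFM weights
```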
environment.yml ADDED
@@ -0,0 +1,8 @@
1
+ name: geosynthsound
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - python=3.10
7
+ - pip:
8
+ - -r requirements.txt
external_models/TangoFlux/.gitignore ADDED
@@ -0,0 +1,175 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+
5
+ # C extensions
6
+ *.so
7
+
8
+ # Distribution / packaging
9
+ .Python
10
+ build/
11
+ develop-eggs/
12
+ dist/
13
+ downloads/
14
+ eggs/
15
+ .eggs/
16
+ lib/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ share/python-wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ *.py,cover
49
+ .hypothesis/
50
+ .pytest_cache/
51
+ cover/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+ db.sqlite3-journal
62
+
63
+ # Flask stuff:
64
+ instance/
65
+ .webassets-cache
66
+
67
+ # Scrapy stuff:
68
+ .scrapy
69
+
70
+ # Sphinx documentation
71
+ docs/_build/
72
+
73
+ # PyBuilder
74
+ .pybuilder/
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ # For a library or package, you might want to ignore these files since the code is
86
+ # intended to run in multiple environments; otherwise, check them in:
87
+ # .python-version
88
+
89
+ # pipenv
90
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
92
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
93
+ # install all needed dependencies.
94
+ #Pipfile.lock
95
+
96
+ # UV
97
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
98
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
99
+ # commonly ignored for libraries.
100
+ #uv.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
115
+ .pdm.toml
116
+ .pdm-python
117
+ .pdm-build/
118
+
119
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120
+ __pypackages__/
121
+
122
+ # Celery stuff
123
+ celerybeat-schedule
124
+ celerybeat.pid
125
+
126
+ # SageMath parsed files
127
+ *.sage.py
128
+
129
+ # Environments
130
+ .env
131
+ .venv
132
+ env/
133
+ venv/
134
+ ENV/
135
+ env.bak/
136
+ venv.bak/
137
+
138
+ # Spyder project settings
139
+ .spyderproject
140
+ .spyproject
141
+
142
+ # Rope project settings
143
+ .ropeproject
144
+
145
+ # mkdocs documentation
146
+ /site
147
+
148
+ # mypy
149
+ .mypy_cache/
150
+ .dmypy.json
151
+ dmypy.json
152
+
153
+ # Pyre type checker
154
+ .pyre/
155
+
156
+ # pytype static type analyzer
157
+ .pytype/
158
+
159
+ # Cython debug symbols
160
+ cython_debug/
161
+
162
+ # PyCharm
163
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
166
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167
+ #.idea/
168
+
169
+ # PyPI configuration file
170
+ .pypirc
171
+
172
+
173
+ .DS_Store
174
+
175
+ *.wav
external_models/TangoFlux/Demo.ipynb ADDED
@@ -0,0 +1,117 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "view-in-github",
7
+ "colab_type": "text"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/OIEIEIO/TangoFlux/blob/main/Demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {
17
+ "collapsed": true,
18
+ "id": "xiaRzuzPOP4H"
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "!pip install git+https://github.com/declare-lab/TangoFlux.git"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {
29
+ "collapsed": true,
30
+ "id": "Hfu3zXTDOP4J"
31
+ },
32
+ "outputs": [],
33
+ "source": [
34
+ "import IPython\n",
35
+ "import torchaudio\n",
36
+ "from tangoflux import TangoFluxInference\n",
37
+ "from IPython.display import Audio\n",
38
+ "\n",
39
+ "model = TangoFluxInference(name='declare-lab/TangoFlux')"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {
46
+ "id": "oFiak5QIOP4K"
47
+ },
48
+ "outputs": [],
49
+ "source": [
50
+ "# @title Generate Audio\n",
51
+ "\n",
52
+ "prompt = 'a futuristic space craft with unique engine sound' # @param {type:\"string\"}\n",
53
+ "duration = 10 # @param {type:\"number\"}\n",
54
+ "steps = 50 # @param {type:\"number\"}\n",
55
+ "\n",
56
+ "audio = model.generate(prompt, steps=steps, duration=duration)\n",
57
+ "\n",
58
+ "Audio(data=audio, rate=44100)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "source": [
64
+ "import IPython\n",
65
+ "import torchaudio\n",
66
+ "from tangoflux import TangoFluxInference\n",
67
+ "from IPython.display import Audio\n",
68
+ "\n",
69
+ "model = TangoFluxInference(name='declare-lab/TangoFlux')\n",
70
+ "\n",
71
+ "# @title Generate Audio\n",
72
+ "prompt = 'Melodic human whistling harmonizing with natural birdsong' # @param {type:\"string\"}\n",
73
+ "duration = 10 # @param {type:\"number\"}\n",
74
+ "steps = 50 # @param {type:\"number\"}\n",
75
+ "\n",
76
+ "# Generate the audio\n",
77
+ "audio = model.generate(prompt, steps=steps, duration=duration)\n",
78
+ "\n",
79
+ "# Ensure audio is in the correct format (2D Tensor: [channels, samples])\n",
80
+ "if len(audio.shape) == 1: # If mono audio (1D tensor)\n",
81
+ " audio_tensor = audio.unsqueeze(0) # Add channel dimension to make it [1, samples]\n",
82
+ "elif len(audio.shape) == 2: # Stereo audio (2D tensor)\n",
83
+ " audio_tensor = audio # Already in correct format\n",
84
+ "else:\n",
85
+ " raise ValueError(f\"Unexpected audio tensor shape: {audio.shape}\")\n",
86
+ "\n",
87
+ "# Save the audio as a .wav file\n",
88
+ "torchaudio.save('generated_audio.wav', audio_tensor, sample_rate=44100)\n",
89
+ "\n",
90
+ "# Optionally play the audio in the notebook\n",
91
+ "Audio(data=audio.numpy(), rate=44100)\n"
92
+ ],
93
+ "metadata": {
94
+ "id": "_Z8elHyOHOQ1"
95
+ },
96
+ "execution_count": null,
97
+ "outputs": []
98
+ }
99
+ ],
100
+ "metadata": {
101
+ "language_info": {
102
+ "name": "python"
103
+ },
104
+ "colab": {
105
+ "provenance": [],
106
+ "machine_shape": "hm",
107
+ "private_outputs": true,
108
+ "include_colab_link": true
109
+ },
110
+ "kernelspec": {
111
+ "name": "python3",
112
+ "display_name": "Python 3"
113
+ }
114
+ },
115
+ "nbformat": 4,
116
+ "nbformat_minor": 0
117
+ }
external_models/TangoFlux/Inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
external_models/TangoFlux/LICENSE.md ADDED
@@ -0,0 +1,51 @@
1
+ # LICENSE
2
+
3
+ ## 1. Model & License Summary
4
+
5
+ This repository contains **TangoFlux** (the “Model”) created for **non-commercial, research-only** purposes under the **UK data copyright exemption**. The Model is subject to:
6
+
7
+ 1. The **Stability AI Community License Agreement**, provided in the file ```STABILITY_AI_COMMUNITY_LICENSE.md```.
8
+ 2. The **WavCaps** license requirement: **only academic uses** are permitted for data sourced from WavCaps.
9
+ 3. The **original licenses** of the datasets used in training.
10
+
11
+ By using or distributing this Model, you **agree** to adhere to all applicable licenses and restrictions, as summarized below.
12
+
13
+ ---
14
+
15
+ ## 2. Stability AI Community License Requirements
16
+
17
+ - You must comply with the **Stability AI Community License Agreement** (the “Agreement”) for any usage, distribution, or modification of this Model.
18
+ - **Non-Commercial Use**: This Model is for research and academic purposes only. Any commercial usage requires registering with Stability AI or obtaining a separate commercial license.
19
+ - **Attribution & Notice**:
20
+ - Retain the notice:
21
+ ```
22
+ This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved.
23
+ ```
24
+ - Clearly display “Powered by Stability AI” if you build upon or showcase this Model.
25
+ - **Disclaimer & Liability**: This Model is provided **“AS IS”** with **no warranties**. Neither we nor Stability AI will be liable for any claim or damages related to Model use.
26
+
27
+ See ```STABILITY_AI_COMMUNITY_LICENSE.md``` for the full text.
28
+
29
+ ---
30
+
31
+ ## 3. WavCaps & Dataset Usage
32
+
33
+ - **Academic-Only for WavCaps**: By accessing any WavCaps-sourced data (including audio clips via provided links), you agree to use them **strictly for non-commercial, academic research** in accordance with WavCaps’ terms.
34
+ - **WavCaps Audio**: Each WavCaps audio subset has its own license terms. **You** are responsible for reviewing and complying with those licenses, including attribution requirements on your end.
35
+
36
+ ---
37
+
38
+ ## 4. UK Data Copyright Exemption
39
+
40
+ This Model was developed under the **UK data copyright exemption for non-commercial research**. Distribution or use outside these bounds must **not** violate that exemption or infringe on any underlying dataset’s license.
41
+
42
+ ---
43
+
44
+ ## 5. Further Information
45
+
46
+ - **Stability AI License Terms**: <https://stability.ai/community-license>
47
+ - **WavCaps License**: <https://github.com/XinhaoMei/WavCaps?tab=readme-ov-file#license>
48
+
49
+ ---
50
+
51
+ **End of License**.
external_models/TangoFlux/Notice ADDED
@@ -0,0 +1 @@
1
+ This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved
external_models/TangoFlux/README.md ADDED
@@ -0,0 +1,188 @@
1
+ <div align="center">
2
+ <img src="assets/tf_opener.png" alt="TangoFluxOpener" width="1000" />
3
+
4
+ <br/>
5
+
6
+ [![arXiv](https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf)](https://arxiv.org/abs/2412.21037) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat)](https://tangoflux.github.io/) [![Static Badge](https://img.shields.io/badge/TangoFlux-Hugging_Face_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux) [![Static Badge](https://img.shields.io/badge/TangoFlux_Dataset-Hugging_Face-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/datasets/declare-lab/CRPO) [![Replicate](https://replicate.com/declare-lab/tangoflux/badge)](https://replicate.com/declare-lab/tangoflux)
7
+
8
+ <img src="assets/tf_teaser.png" alt="TangoFlux" width="1000" />
9
+ <br/>
10
+
11
+ </div>
12
+
13
+ * Powered by **Stability AI**
14
+ ## News
15
+ > 📣 1/3/25: We have released the CRPO dataset as well as the script to perform CRPO dataset generation!
16
+
17
+ ## Demos
18
+
19
+ [![Hugging Face Space](https://img.shields.io/badge/Hugging_Face_Space-TangoFlux-blue?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux)](https://huggingface.co/spaces/declare-lab/TangoFlux)
20
+
21
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/declare-lab/TangoFlux/blob/main/Demo.ipynb)
22
+
23
+ ## Overall Pipeline
24
+
25
+ TangoFlux consists of FluxTransformer blocks, which are Diffusion Transformers (DiT) and Multimodal Diffusion Transformers (MMDiT) conditioned on a textual prompt and a duration embedding to generate 44.1kHz audio up to 30 seconds long. TangoFlux learns a rectified-flow trajectory to an audio latent representation encoded by a variational autoencoder (VAE). The TangoFlux training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization with CRPO. In particular, CRPO iteratively generates new synthetic data and constructs preference pairs for preference optimization using a DPO loss for flow matching.
26
+
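To make the rectified-flow objective above concrete, here is a minimal, illustrative sketch of a flow-matching training step on VAE audio latents. It is not code from this repository; the `model` callable, tensor shapes, and conditioning arguments are placeholders.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x1, text_emb, duration_emb):
    """One rectified-flow step: regress the predicted velocity toward (x1 - x0).

    x1: clean VAE audio latents, e.g. shape [batch, time, channels] (placeholder).
    """
    x0 = torch.randn_like(x1)                      # Gaussian noise endpoint
    t = torch.rand(x1.shape[0], device=x1.device)  # uniform time in [0, 1]
    t_ = t.view(-1, 1, 1)
    xt = (1 - t_) * x0 + t_ * x1                   # straight-line interpolant between noise and data
    target = x1 - x0                               # constant velocity along that line
    pred = model(xt, t, text_emb, duration_emb)    # DiT/MMDiT conditioned on text + duration
    return F.mse_loss(pred, target)
```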
27
+ ![cover-photo](assets/tangoflux.png)
28
+
29
+ 🚀 **TangoFlux can generate 44.1kHz stereo audio up to 30 seconds in ~3 seconds on a single A40 GPU.**
30
+
31
+ ## Installation
32
+
33
+ ```bash
34
+ pip install git+https://github.com/declare-lab/TangoFlux
35
+ ```
36
+
37
+ ## Inference
38
+
39
+ TangoFlux can generate audio up to 30 seconds long. You must pass a duration to the `model.generate` function when using the Python API. Please note that duration should be between 1 and 30.
40
+
41
+ ### Web Interface
42
+
43
+ Run the following command to start the web interface:
44
+
45
+ ```bash
46
+ tangoflux-demo
47
+ ```
48
+
49
+ ### CLI
50
+
51
+ Use the CLI to generate audio from text.
52
+
53
+ ```bash
54
+ tangoflux "Hammer slowly hitting the wooden table" output.wav --duration 10 --steps 50
55
+ ```
56
+
57
+ ### Python API
58
+
59
+ ```python
60
+ import torchaudio
61
+ from tangoflux import TangoFluxInference
62
+
63
+ model = TangoFluxInference(name='declare-lab/TangoFlux')
64
+ audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)
65
+
66
+ torchaudio.save('output.wav', audio, 44100)
67
+ ```
68
+
69
+ ### [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
70
+
71
+ > This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface.
72
+
73
+ Check [this](https://github.com/LucipherDev/ComfyUI-TangoFlux) repo for the TangoFlux custom node for *ComfyUI*. (Thanks to [LucipherDev](https://github.com/LucipherDev))
74
+
75
+ Our evaluation shows that inference with 50 steps yields the best results. A CFG scale of 3.5, 4, and 4.5 yield similar quality output. Inference with 25 steps yields similar audio quality at a faster speed.
76
+
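For example, with the Python API shown above, the number of sampling steps can be traded off against speed (a sketch using only the arguments already documented here):

```python
from tangoflux import TangoFluxInference

model = TangoFluxInference(name='declare-lab/TangoFlux')

# 50 steps: best quality per the evaluation below.
audio_hq = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)

# 25 steps: similar audio quality at roughly half the sampling time.
audio_fast = model.generate('Hammer slowly hitting the wooden table', steps=25, duration=10)
```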
77
+ ## Training
78
+
79
+ We use the `accelerate` package from Hugging Face for multi-GPU training. Run `accelerate config` to set up your run configuration. The default accelerate config is in the `configs` folder. Please specify the path to your training files in `configs/tangoflux_config.yaml`. Samples of `train.json` and `val.json` have been provided; replace them with your own audio.
80
+
81
+ `tangoflux_config.yaml` defines the training file paths and model hyperparameters:
82
+
83
+ ```bash
84
+ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
85
+ ```
86
+
87
+ To perform DPO training, modify the training files such that each data point contains "chosen", "reject", "caption" and "duration" fields. Please specify the path to your training files in `configs/tangoflux_config.yaml`. An example has been provided in `train_dpo.json`. Replace it with your own audio.
88
+
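For illustration, a single record in such a DPO training file might look like the following; the paths and values are placeholders, and only the four field names come from the description above.

```python
# One hypothetical entry of the DPO training data (a train_dpo.json-style record).
dpo_example = {
    "chosen": "audio/chosen_0001.wav",    # preferred generation for this caption
    "reject": "audio/reject_0001.wav",    # dispreferred generation
    "caption": "Hammer slowly hitting the wooden table",
    "duration": 10,                       # seconds
}
```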
89
+ ```bash
90
+ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train_dpo.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
91
+ ```
92
+
93
+ ## Evaluation
94
+
95
+ ### TangoFlux vs. Other Audio Generation Models
96
+
97
+ The key comparison metrics include:
98
+
99
+ - **Output Length**: Represents the duration of the generated audio.
100
+ - **FD**<sub>openl3</sub>: Fréchet Distance.
101
+ - **KL**<sub>passt</sub>: KL divergence.
102
+ - **CLAP**<sub>score</sub>: Alignment score.
103
+
104
+
105
+ All inference times were measured on the same A40 GPU. Trainable parameter counts are reported in the **Params** column.
106
+
107
+ | Model | Params | Duration | Steps | FD<sub>openl3</sub> ↓ | KL<sub>passt</sub> ↓ | CLAP<sub>score</sub> ↑ | IS ↑ | Inference Time (s) |
108
+ |---|---|---|---|---|---|---|---|---|
109
+ | **AudioLDM 2 (Large)** | 712M | 10 sec | 200 | 108.3 | 1.81 | 0.419 | 7.9 | 24.8 |
110
+ | **Stable Audio Open** | 1056M | 47 sec | 100 | 89.2 | 2.58 | 0.291 | 9.9 | 8.6 |
111
+ | **Tango 2** | 866M | 10 sec | 200 | 108.4 | 1.11 | 0.447 | 9.0 | 22.8 |
112
+ | **TangoFlux (Base)** | 515M | 30 sec | 50 | 80.2 | 1.22 | 0.431 | 11.7 | 3.7 |
113
+ | **TangoFlux** | 515M | 30 sec | 50 | 75.1 | 1.15 | 0.480 | 12.2 | 3.7 |
114
+
115
+ ## CRPO dataset generation
116
+
117
+ There are two Python scripts for CRPO dataset generation:
118
+ `tangoflux/generate_crpo.py` generates the CRPO dataset given paths to a prompt bank and model weights; the sample size and the number of samples per prompt can be set via its arguments.
119
+ `tangoflux/label_crpo.py` labels the generated audio and constructs preference pairs. It also creates a `train.json` in the output directory that can be passed to `train_dpo.py`.
120
+
121
+ You can follow the example in `crpo.sh`, which generates the CRPO dataset and then performs reward labelling to produce `train.json`.
122
+
123
+ To run CRPO for multiple iterations, simply repeat the above process, pointing each iteration at the correct model weights.
124
+ ## Citation
125
+
126
+ ```bibtex
127
+ @misc{hung2024tangofluxsuperfastfaithful,
128
+ title={TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization},
129
+ author={Chia-Yu Hung and Navonil Majumder and Zhifeng Kong and Ambuj Mehrish and Amir Zadeh and Chuan Li and Rafael Valle and Bryan Catanzaro and Soujanya Poria},
130
+ year={2024},
131
+ eprint={2412.21037},
132
+ archivePrefix={arXiv},
133
+ primaryClass={cs.SD},
134
+ url={https://arxiv.org/abs/2412.21037},
135
+ }
136
+ ```
137
+
138
+ ## LICENSE
139
+
140
+ ### 1. Model & License Summary
141
+
142
+ This repository contains **TangoFlux** (the “Model”) created for **non-commercial, research-only** purposes under the **UK data copyright exemption**. The Model is subject to:
143
+
144
+ 1. The **Stability AI Community License Agreement**, provided in the file ```STABILITY_AI_COMMUNITY_LICENSE.md```.
145
+ 2. The **WavCaps** license requirement: **only academic uses** are permitted for data sourced from WavCaps.
146
+ 3. The **original licenses** of the datasets used in training.
147
+
148
+ By using or distributing this Model, you **agree** to adhere to all applicable licenses and restrictions, as summarized below.
149
+
150
+ ---
151
+
152
+ ### 2. Stability AI Community License Requirements
153
+
154
+ - You must comply with the **Stability AI Community License Agreement** (the “Agreement”) for any usage, distribution, or modification of this Model.
155
+ - **Non-Commercial Use**: This Model is for research and academic purposes only. Any commercial usage requires registering with Stability AI or obtaining a separate commercial license.
156
+ - **Attribution & Notice**:
157
+ - Retain the notice:
158
+ ```
159
+ This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved.
160
+ ```
161
+ - Clearly display “Powered by Stability AI” if you build upon or showcase this Model.
162
+ - **Disclaimer & Liability**: This Model is provided **“AS IS”** with **no warranties**. Neither we nor Stability AI will be liable for any claim or damages related to Model use.
163
+
164
+ See ```STABILITY_AI_COMMUNITY_LICENSE.md``` for the full text.
165
+
166
+ ---
167
+
168
+ ### 3. WavCaps & Dataset Usage
169
+
170
+ - **Academic-Only for WavCaps**: By accessing any WavCaps-sourced data (including audio clips via provided links), you agree to use them **strictly for non-commercial, academic research** in accordance with WavCaps’ terms.
171
+ - **WavCaps Audio**: Each WavCaps audio subset has its own license terms. **You** are responsible for reviewing and complying with those licenses, including attribution requirements on your end.
172
+
173
+ ---
174
+
175
+ ### 4. UK Data Copyright Exemption
176
+
177
+ This Model was developed under the **UK data copyright exemption for non-commercial research**. Distribution or use outside these bounds must **not** violate that exemption or infringe on any underlying dataset’s license.
178
+
179
+ ---
180
+
181
+ ### 5. Further Information
182
+
183
+ - **Stability AI License Terms**: <https://stability.ai/community-license>
184
+ - **WavCaps License**: <https://github.com/XinhaoMei/WavCaps?tab=readme-ov-file#license>
185
+
186
+ ---
187
+
188
+ **End of License**.
external_models/TangoFlux/STABILITY_AI_COMMUNITY_LICENSE.md ADDED
@@ -0,0 +1,57 @@
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+
3
+ Last Updated: July 5, 2024
4
+ 1. INTRODUCTION
5
+
6
+ This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
7
+
8
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
9
+
10
+ By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf.
11
+
12
+ 2. RESEARCH & NON-COMMERCIAL USE LICENSE
13
+
14
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
15
+
16
+ 3. COMMERCIAL USE LICENSE
17
+
18
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations.
19
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
20
+
21
+ 4. GENERAL TERMS
22
+
23
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
24
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified.
25
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
26
+ c. Intellectual Property.
27
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
28
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
29
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
30
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
31
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback.
32
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
33
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
34
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
35
+ g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
36
+
37
+ 5. DEFINITIONS
38
+
39
+ “Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
40
+
41
+ "Agreement" means this Stability AI Community License Agreement.
42
+
43
+ “AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
44
+
45
+ "Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model.
46
+
47
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
48
+
49
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
50
+
51
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
52
+
53
+ "Software" means Stability AI’s proprietary software made available under this Agreement now or in the future.
54
+
55
+ “Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
56
+
57
+ “Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
external_models/TangoFlux/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ try:
2
+ from .comfyui import *
3
+ except:
4
+ pass
external_models/TangoFlux/assets/tangoflux.png ADDED

Git LFS Details

  • SHA256: e8e19e12b3c2c991a29987d7fceaed80aa8ed306827cfaa0894d666b5c250702
  • Pointer size: 131 Bytes
  • Size of remote file: 304 kB
external_models/TangoFlux/assets/tf_opener.png ADDED

Git LFS Details

  • SHA256: 58934ca2300804d67bc73c7116c3a0d956d770e0bd6e816aa9dbe9034f5b32fe
  • Pointer size: 131 Bytes
  • Size of remote file: 465 kB
external_models/TangoFlux/assets/tf_teaser.png ADDED

Git LFS Details

  • SHA256: 475a101c58ee8cb7481172d24763fddcc1da59f578aaeccf9d8052f5a86401b6
  • Pointer size: 131 Bytes
  • Size of remote file: 778 kB
external_models/TangoFlux/comfyui/README.md ADDED
@@ -0,0 +1,78 @@
1
+ # ComfyUI-TangoFlux
2
+ ComfyUI Custom Nodes for ["TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching"](https://arxiv.org/abs/2412.21037). These nodes, adapted from [the official implementations](https://github.com/declare-lab/TangoFlux/), generate high-quality 44.1kHz audio up to 30 seconds long from just a text prompt.
3
+
4
+ ## Installation
5
+
6
+ 1. Navigate to your ComfyUI's custom_nodes directory:
7
+ ```bash
8
+ cd ComfyUI/custom_nodes
9
+ ```
10
+
11
+ 2. Clone this repository:
12
+ ```bash
13
+ git clone https://github.com/declare-lab/TangoFlux ComfyUI-TangoFlux
14
+ ```
15
+
16
+ 3. Install requirements:
17
+ ```bash
18
+ cd ComfyUI-TangoFlux/comfyui
19
+ python install.py
20
+ ```
21
+
22
+ ### Or Install via ComfyUI Manager
23
+
24
+ #### Check out some demos from [the official demo page](https://tangoflux.github.io/)
25
+
26
+ ## Example Workflow
27
+
28
+ ![example_workflow](https://github.com/user-attachments/assets/afbf7b53-d712-4c9c-a538-53f0dc001f45)
29
+
30
+ ## Usage
31
+
32
+ **All the necessary models should be automatically downloaded when the TangoFluxLoader node is used for the first time.**
33
+
34
+ **Models can also be downloaded using the `install.py` script**
35
+
36
+ ![models_folder_structure](https://github.com/user-attachments/assets/94d8a54a-10d6-4f90-bb4d-3ee181dee3a2)
37
+
38
+ **Manual Download:**
39
+ - Download TangoFlux from [here](https://huggingface.co/declare-lab/TangoFlux/tree/main) into `models/tangoflux`
40
+ - Download text encoders from [here](https://huggingface.co/google/flan-t5-large/tree/main) into `models/text_encoders/google-flan-t5-large`
41
+
42
+ *(Include Everything as shown in the screenshot above. Do Not Rename Anything)*
43
+
44
+ The nodes can be found in "TangoFlux" category as `TangoFluxLoader`, `TangoFluxSampler`, `TangoFluxVAEDecodeAndPlay`.
45
+
46
+ ![teacache_options](https://github.com/user-attachments/assets/29e676d9-902b-4ea2-9f72-18d3607996e8)
47
+
48
+ > [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up TangoFlux 2x without much audio quality degradation, in a training-free manner.
49
+ >
50
+ >
51
+ > ## 📈 Inference Latency Comparisons on a Single A800
52
+ >
53
+ >
54
+ > | TangoFlux | TeaCache (0.25) | TeaCache (0.4) |
55
+ > |:-------------------:|:----------------------------:|:--------------------:|
56
+ > | ~4.08 s | ~2.42 s | ~1.95 s |
57
+
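
To make the speedup concrete, here is a minimal sketch of the skip rule applied at each denoising step by `comfyui/teacache.py` (added later in this commit). The function name `should_recompute` and the plain NumPy-array inputs are illustrative only, not part of the node code.

```python
# Minimal sketch (not the node itself) of the TeaCache skip rule from comfyui/teacache.py.
import numpy as np

# Polynomial fitted for TangoFlux in teacache.py; it rescales the relative L1 change.
_RESCALE = np.poly1d([4.98651651e02, -2.83781631e02, 5.58554382e01,
                      -3.82021401e00, 2.64230861e-01])

def should_recompute(modulated, previous, accumulated, rel_l1_thresh, step, num_steps):
    """Return (recompute?, new_accumulated); when False, the cached residual is reused."""
    if step == 0 or step == num_steps - 1:       # always compute the first and last steps
        return True, 0.0
    rel_l1 = np.abs(modulated - previous).mean() / np.abs(previous).mean()
    accumulated += float(_RESCALE(rel_l1))
    if accumulated < rel_l1_thresh:              # change is still small: skip the transformer blocks
        return False, accumulated
    return True, 0.0                             # recompute and reset the accumulator
```

Lowering `rel_l1_thresh` recomputes more often (closer to baseline quality); raising it skips more steps, which is the trade-off behind the 0.25 and 0.4 presets in the table above.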
58
+ ## Citation
59
+
60
+ ```bibtex
61
+ @misc{hung2024tangofluxsuperfastfaithful,
62
+ title={TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization},
63
+ author={Chia-Yu Hung and Navonil Majumder and Zhifeng Kong and Ambuj Mehrish and Rafael Valle and Bryan Catanzaro and Soujanya Poria},
64
+ year={2024},
65
+ eprint={2412.21037},
66
+ archivePrefix={arXiv},
67
+ primaryClass={cs.SD},
68
+ url={https://arxiv.org/abs/2412.21037},
69
+ }
70
+ ```
71
+ ```
72
+ @article{liu2024timestep,
73
+ title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
74
+ author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
75
+ journal={arXiv preprint arXiv:2411.19108},
76
+ year={2024}
77
+ }
78
+ ```
external_models/TangoFlux/comfyui/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .nodes import NODE_CLASS_MAPPINGS
2
+ from .server import *
3
+
4
+ WEB_DIRECTORY = "./comfyui/web"
5
+
6
+ __all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]
external_models/TangoFlux/comfyui/example_workflow.json ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "last_node_id": 13,
3
+ "last_link_id": 15,
4
+ "nodes": [
5
+ {
6
+ "id": 10,
7
+ "type": "TangoFluxLoader",
8
+ "pos": [
9
+ 380,
10
+ 320
11
+ ],
12
+ "size": [
13
+ 210,
14
+ 102
15
+ ],
16
+ "flags": {},
17
+ "order": 0,
18
+ "mode": 0,
19
+ "inputs": [],
20
+ "outputs": [
21
+ {
22
+ "name": "model",
23
+ "type": "TANGOFLUX_MODEL",
24
+ "links": [
25
+ 11
26
+ ],
27
+ "slot_index": 0
28
+ },
29
+ {
30
+ "name": "vae",
31
+ "type": "TANGOFLUX_VAE",
32
+ "links": [
33
+ 15
34
+ ],
35
+ "slot_index": 1
36
+ }
37
+ ],
38
+ "properties": {
39
+ "Node name for S&R": "TangoFluxLoader"
40
+ },
41
+ "widgets_values": [
42
+ false,
43
+ 0.25
44
+ ]
45
+ },
46
+ {
47
+ "id": 13,
48
+ "type": "TangoFluxVAEDecodeAndPlay",
49
+ "pos": [
50
+ 1060,
51
+ 320
52
+ ],
53
+ "size": [
54
+ 315,
55
+ 126
56
+ ],
57
+ "flags": {},
58
+ "order": 2,
59
+ "mode": 0,
60
+ "inputs": [
61
+ {
62
+ "name": "vae",
63
+ "type": "TANGOFLUX_VAE",
64
+ "link": 15
65
+ },
66
+ {
67
+ "name": "latents",
68
+ "type": "TANGOFLUX_LATENTS",
69
+ "link": 14
70
+ }
71
+ ],
72
+ "outputs": [],
73
+ "properties": {
74
+ "Node name for S&R": "TangoFluxVAEDecodeAndPlay"
75
+ },
76
+ "widgets_values": [
77
+ "TangoFlux",
78
+ "wav",
79
+ true
80
+ ]
81
+ },
82
+ {
83
+ "id": 11,
84
+ "type": "TangoFluxSampler",
85
+ "pos": [
86
+ 620,
87
+ 320
88
+ ],
89
+ "size": [
90
+ 400,
91
+ 220
92
+ ],
93
+ "flags": {},
94
+ "order": 1,
95
+ "mode": 0,
96
+ "inputs": [
97
+ {
98
+ "name": "model",
99
+ "type": "TANGOFLUX_MODEL",
100
+ "link": 11
101
+ }
102
+ ],
103
+ "outputs": [
104
+ {
105
+ "name": "latents",
106
+ "type": "TANGOFLUX_LATENTS",
107
+ "links": [
108
+ 14
109
+ ],
110
+ "slot_index": 0
111
+ }
112
+ ],
113
+ "properties": {
114
+ "Node name for S&R": "TangoFluxSampler"
115
+ },
116
+ "widgets_values": [
117
+ "A dog barking near the ocean, ocean waves crashing.",
118
+ 50,
119
+ 3,
120
+ 10,
121
+ 106139285587780,
122
+ "randomize",
123
+ 1
124
+ ]
125
+ }
126
+ ],
127
+ "links": [
128
+ [
129
+ 11,
130
+ 10,
131
+ 0,
132
+ 11,
133
+ 0,
134
+ "TANGOFLUX_MODEL"
135
+ ],
136
+ [
137
+ 14,
138
+ 11,
139
+ 0,
140
+ 13,
141
+ 1,
142
+ "TANGOFLUX_LATENTS"
143
+ ],
144
+ [
145
+ 15,
146
+ 10,
147
+ 1,
148
+ 13,
149
+ 0,
150
+ "TANGOFLUX_VAE"
151
+ ]
152
+ ],
153
+ "groups": [],
154
+ "config": {},
155
+ "extra": {
156
+ "ds": {
157
+ "scale": 0.9480295566502464,
158
+ "offset": [
159
+ -200.83333333333337,
160
+ -102.2460379319304
161
+ ]
162
+ },
163
+ "node_versions": {
164
+ "comfyui-tangoflux": "1.0.4"
165
+ }
166
+ },
167
+ "version": 0.4
168
+ }
external_models/TangoFlux/comfyui/install.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import logging
4
+ import subprocess
5
+ import traceback
6
+ import json
7
+ import re
8
+
9
+ log = logging.getLogger("TangoFlux")
10
+
11
+ download_models = True
12
+
13
+ EXT_PATH = os.path.dirname(os.path.abspath(__file__))
14
+
15
+ try:
16
+ folder_paths_path = os.path.abspath(os.path.join(EXT_PATH, "..", "..", "..", "folder_paths.py"))
17
+
18
+ sys.path.append(os.path.dirname(folder_paths_path))
19
+
20
+ import folder_paths
21
+
22
+ TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
23
+ TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
24
+ except:
25
+ download_models = False
26
+
27
+ try:
28
+ log.info("Installing requirements")
29
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", f"{EXT_PATH}/requirements.txt", "--no-warn-script-location"])
30
+
31
+ if download_models:
32
+ from huggingface_hub import snapshot_download
33
+
34
+ log.info("Downloading Necessary models")
35
+
36
+ try:
37
+ log.info(f"Downloading TangoFlux models to: {TANGOFLUX_DIR}")
38
+ snapshot_download(
39
+ repo_id="declare-lab/TangoFlux",
40
+ allow_patterns=["*.json", "*.safetensors"],
41
+ local_dir=TANGOFLUX_DIR,
42
+ local_dir_use_symlinks=False,
43
+ )
44
+ except Exception:
45
+ traceback.print_exc()
46
+ log.error("Failed to download TangoFlux models")
47
+
48
+ log.info("Loading config")
49
+
50
+ with open(os.path.join(TANGOFLUX_DIR, "config.json"), "r") as f:
51
+ config = json.load(f)
52
+
53
+ try:
54
+ text_encoder = re.sub(r'[<>:"/\\|?*]', '-', config.get("text_encoder_name", "google/flan-t5-large"))
55
+ text_encoder_path = os.path.join(TEXT_ENCODER_DIR, text_encoder)
56
+
57
+ log.info(f"Downloading text encoders to: {text_encoder_path}")
58
+ snapshot_download(
59
+ repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
60
+ allow_patterns=["*.json", "*.safetensors", "*.model"],
61
+ local_dir=text_encoder_path,
62
+ local_dir_use_symlinks=False,
63
+ )
64
+ except Exception:
65
+ traceback.print_exc()
66
+ log.error("Failed to download text encoders")
67
+
68
+ try:
69
+ log.info("Installing TangoFlux module")
70
+ subprocess.check_call([sys.executable, "-m", "pip", "install", os.path.join(EXT_PATH, "..")])
71
+ except Exception:
72
+ traceback.print_exc()
73
+ log.error("Failed to install TangoFlux module")
74
+
75
+ log.info("TangoFlux Installation completed")
76
+
77
+ except Exception:
78
+ traceback.print_exc()
79
+ log.error("TangoFlux Installation failed")
external_models/TangoFlux/comfyui/nodes.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import json
4
+ import random
5
+ import torch
6
+ import torchaudio
7
+ import re
8
+
9
+ from diffusers import AutoencoderOobleck, FluxTransformer2DModel
10
+ from huggingface_hub import snapshot_download
11
+
12
+ from comfy.utils import load_torch_file, ProgressBar
13
+ import folder_paths
14
+
15
+ from tangoflux.model import TangoFlux
16
+ from .teacache import teacache_forward
17
+
18
+ log = logging.getLogger("TangoFlux")
19
+
20
+ TANGOFLUX_DIR = os.path.join(folder_paths.models_dir, "tangoflux")
21
+ if "tangoflux" not in folder_paths.folder_names_and_paths:
22
+ current_paths = [TANGOFLUX_DIR]
23
+ else:
24
+ current_paths, _ = folder_paths.folder_names_and_paths["tangoflux"]
25
+ folder_paths.folder_names_and_paths["tangoflux"] = (
26
+ current_paths,
27
+ folder_paths.supported_pt_extensions,
28
+ )
29
+ TEXT_ENCODER_DIR = os.path.join(folder_paths.models_dir, "text_encoders")
30
+
31
+
32
+ class TangoFluxLoader:
33
+ @classmethod
34
+ def INPUT_TYPES(cls):
35
+ return {
36
+ "required": {
37
+ "enable_teacache": ("BOOLEAN", {"default": False}),
38
+ "rel_l1_thresh": (
39
+ "FLOAT",
40
+ {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.01},
41
+ ),
42
+ },
43
+ }
44
+
45
+ RETURN_TYPES = ("TANGOFLUX_MODEL", "TANGOFLUX_VAE")
46
+ RETURN_NAMES = ("model", "vae")
47
+ OUTPUT_TOOLTIPS = ("TangoFlux Model", "TangoFlux Vae")
48
+
49
+ CATEGORY = "TangoFlux"
50
+ FUNCTION = "load_tangoflux"
51
+ DESCRIPTION = "Load TangoFlux model"
52
+
53
+ def __init__(self):
54
+ self.model = None
55
+ self.vae = None
56
+ self.enable_teacache = False
57
+ self.rel_l1_thresh = 0.25
58
+ self.original_forward = FluxTransformer2DModel.forward
59
+
60
+ def load_tangoflux(
61
+ self,
62
+ enable_teacache=False,
63
+ rel_l1_thresh=0.25,
64
+ tangoflux_path=TANGOFLUX_DIR,
65
+ text_encoder_path=TEXT_ENCODER_DIR,
66
+ device="cuda",
67
+ ):
68
+ if self.model is None or self.enable_teacache != enable_teacache:
69
+
70
+ pbar = ProgressBar(6)
71
+
72
+ snapshot_download(
73
+ repo_id="declare-lab/TangoFlux",
74
+ allow_patterns=["*.json", "*.safetensors"],
75
+ local_dir=tangoflux_path,
76
+ local_dir_use_symlinks=False,
77
+ )
78
+
79
+ pbar.update(1)
80
+
81
+ log.info("Loading config")
82
+
83
+ with open(os.path.join(tangoflux_path, "config.json"), "r") as f:
84
+ config = json.load(f)
85
+
86
+ pbar.update(1)
87
+
88
+ text_encoder = re.sub(
89
+ r'[<>:"/\\|?*]',
90
+ "-",
91
+ config.get("text_encoder_name", "google/flan-t5-large"),
92
+ )
93
+ text_encoder_path = os.path.join(text_encoder_path, text_encoder)
94
+
95
+ snapshot_download(
96
+ repo_id=config.get("text_encoder_name", "google/flan-t5-large"),
97
+ allow_patterns=["*.json", "*.safetensors", "*.model"],
98
+ local_dir=text_encoder_path,
99
+ local_dir_use_symlinks=False,
100
+ )
101
+
102
+ pbar.update(1)
103
+
104
+ log.info("Loading TangoFlux models")
105
+
106
+ del self.model
107
+ self.model = None
108
+
109
+ model_weights = load_torch_file(
110
+ os.path.join(tangoflux_path, "tangoflux.safetensors"),
111
+ device=torch.device(device),
112
+ )
113
+
114
+ pbar.update(1)
115
+
116
+ if enable_teacache:
117
+ log.info("Enabling TeaCache")
118
+ FluxTransformer2DModel.forward = teacache_forward
119
+ else:
120
+ log.info("Disabling TeaCache")
121
+ FluxTransformer2DModel.forward = self.original_forward
122
+
123
+ model = TangoFlux(config=config, text_encoder_dir=text_encoder_path)
124
+
125
+ model.load_state_dict(model_weights, strict=False)
126
+ model.to(device)
127
+
128
+ if enable_teacache:
129
+ model.transformer.__class__.enable_teacache = True
130
+ model.transformer.__class__.cnt = 0
131
+ model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
132
+ model.transformer.__class__.accumulated_rel_l1_distance = 0
133
+ model.transformer.__class__.previous_modulated_input = None
134
+ model.transformer.__class__.previous_residual = None
135
+
136
+ pbar.update(1)
137
+
138
+ self.model = model
139
+ del model
140
+ self.enable_teacache = enable_teacache
141
+ self.rel_l1_thresh = rel_l1_thresh
142
+
143
+ if self.vae is None:
144
+ log.info("Loading TangoFlux VAE")
145
+
146
+ vae_weights = load_torch_file(
147
+ os.path.join(tangoflux_path, "vae.safetensors")
148
+ )
149
+ self.vae = AutoencoderOobleck()
150
+ self.vae.load_state_dict(vae_weights)
151
+ self.vae.to(device)
152
+
153
+ pbar.update(1)
154
+
155
+ if self.enable_teacache == True and self.rel_l1_thresh != rel_l1_thresh:
156
+ self.model.transformer.__class__.rel_l1_thresh = rel_l1_thresh
157
+
158
+ self.rel_l1_thresh = rel_l1_thresh
159
+
160
+ return (self.model, self.vae)
161
+
162
+
163
+ class TangoFluxSampler:
164
+ @classmethod
165
+ def INPUT_TYPES(cls):
166
+ return {
167
+ "required": {
168
+ "model": ("TANGOFLUX_MODEL",),
169
+ "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
170
+ "steps": ("INT", {"default": 50, "min": 1, "max": 10000, "step": 1}),
171
+ "guidance_scale": (
172
+ "FLOAT",
173
+ {"default": 3, "min": 1, "max": 100, "step": 1},
174
+ ),
175
+ "duration": ("INT", {"default": 10, "min": 1, "max": 30, "step": 1}),
176
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
177
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
178
+ },
179
+ }
180
+
181
+ RETURN_TYPES = ("TANGOFLUX_LATENTS",)
182
+ RETURN_NAMES = ("latents",)
183
+ OUTPUT_TOOLTIPS = "TangoFlux Sample"
184
+
185
+ CATEGORY = "TangoFlux"
186
+ FUNCTION = "sample"
187
+ DESCRIPTION = "Sampler for TangoFlux"
188
+
189
+ def sample(
190
+ self,
191
+ model,
192
+ prompt,
193
+ steps=50,
194
+ guidance_scale=3,
195
+ duration=10,
196
+ seed=0,
197
+ batch_size=1,
198
+ device="cuda",
199
+ ):
200
+ pbar = ProgressBar(steps)
201
+
202
+ with torch.no_grad():
203
+ model.to(device)
204
+
205
+ try:
206
+ if model.transformer.__class__.enable_teacache:
207
+ model.transformer.__class__.num_steps = steps
208
+ except:
209
+ pass
210
+
211
+ log.info("Generating latents with TangoFlux")
212
+
213
+ latents = model.inference_flow(
214
+ prompt,
215
+ duration=duration,
216
+ num_inference_steps=steps,
217
+ guidance_scale=guidance_scale,
218
+ seed=seed,
219
+ num_samples_per_prompt=batch_size,
220
+ callback_on_step_end=lambda: pbar.update(1),
221
+ )
222
+
223
+ return ({"latents": latents, "duration": duration},)
224
+
225
+
226
+ class TangoFluxVAEDecodeAndPlay:
227
+ @classmethod
228
+ def INPUT_TYPES(cls):
229
+ return {
230
+ "required": {
231
+ "vae": ("TANGOFLUX_VAE",),
232
+ "latents": ("TANGOFLUX_LATENTS",),
233
+ "filename_prefix": ("STRING", {"default": "TangoFlux"}),
234
+ "format": (
235
+ ["wav", "mp3", "flac", "aac", "wma"],
236
+ {"default": "wav"},
237
+ ),
238
+ "save_output": ("BOOLEAN", {"default": True}),
239
+ },
240
+ }
241
+
242
+ RETURN_TYPES = ()
243
+ OUTPUT_NODE = True
244
+
245
+ CATEGORY = "TangoFlux"
246
+ FUNCTION = "play"
247
+ DESCRIPTION = "Decoder and Player for TangoFlux"
248
+
249
+ def decode(self, vae, latents):
250
+ results = []
251
+
252
+ for latent in latents:
253
+ decoded = vae.decode(latent.unsqueeze(0).transpose(2, 1)).sample.cpu()
254
+ results.append(decoded)
255
+
256
+ results = torch.cat(results, dim=0)
257
+
258
+ return results
259
+
260
+ def play(
261
+ self,
262
+ vae,
263
+ latents,
264
+ filename_prefix="TangoFlux",
265
+ format="wav",
266
+ save_output=True,
267
+ device="cuda",
268
+ ):
269
+ audios = []
270
+ pbar = ProgressBar(len(latents) + 2)
271
+
272
+ if save_output:
273
+ output_dir = folder_paths.get_output_directory()
274
+ prefix_append = ""
275
+ type = "output"
276
+ else:
277
+ output_dir = folder_paths.get_temp_directory()
278
+ prefix_append = "_temp_" + "".join(
279
+ random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)
280
+ )
281
+ type = "temp"
282
+
283
+ filename_prefix += prefix_append
284
+ full_output_folder, filename, counter, subfolder, _ = (
285
+ folder_paths.get_save_image_path(filename_prefix, output_dir)
286
+ )
287
+
288
+ os.makedirs(full_output_folder, exist_ok=True)
289
+
290
+ pbar.update(1)
291
+
292
+ duration = latents["duration"]
293
+ latents = latents["latents"]
294
+
295
+ vae.to(device)
296
+
297
+ log.info("Decoding Tangoflux latents")
298
+
299
+ waves = self.decode(vae, latents)
300
+
301
+ pbar.update(1)
302
+
303
+ for wave in waves:
304
+ waveform_end = int(duration * vae.config.sampling_rate)
305
+ wave = wave[:, :waveform_end]
306
+
307
+ file = f"{filename}_{counter:05}_.{format}"
308
+
309
+ torchaudio.save(
310
+ os.path.join(full_output_folder, file), wave, sample_rate=44100
311
+ )
312
+
313
+ counter += 1
314
+
315
+ audios.append({"filename": file, "subfolder": subfolder, "type": type})
316
+
317
+ pbar.update(1)
318
+
319
+ return {
320
+ "ui": {"audios": audios},
321
+ }
322
+
323
+
324
+ NODE_CLASS_MAPPINGS = {
325
+ "TangoFluxLoader": TangoFluxLoader,
326
+ "TangoFluxSampler": TangoFluxSampler,
327
+ "TangoFluxVAEDecodeAndPlay": TangoFluxVAEDecodeAndPlay,
328
+ }
external_models/TangoFlux/comfyui/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torchaudio
2
+ torchlibrosa
3
+ torchvision
4
+ diffusers
5
+ accelerate
6
+ datasets
7
+ librosa
8
+ wandb
9
+ tqdm
external_models/TangoFlux/comfyui/server.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import server
3
+ import folder_paths
4
+
5
+ web = server.web
6
+
7
+
8
+ @server.PromptServer.instance.routes.get("/tangoflux/playaudio")
9
+ async def play_audio(request):
10
+ query = request.rel_url.query
11
+
12
+ filename = query.get("filename", None)
13
+
14
+ if filename is None:
15
+ return web.Response(status=404)
16
+
17
+ if filename[0] == "/" or ".." in filename:
18
+ return web.Response(status=403)
19
+
20
+ filename, output_dir = folder_paths.annotated_filepath(filename)
21
+
22
+ if not output_dir:
23
+ file_type = query.get("type", "output")
24
+ output_dir = folder_paths.get_directory_by_type(file_type)
25
+
26
+ if output_dir is None:
27
+ return web.Response(status=400)
28
+
29
+ subfolder = query.get("subfolder", None)
30
+ if subfolder:
31
+ full_output_dir = os.path.join(output_dir, subfolder)
32
+ if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
33
+ return web.Response(status=403)
34
+ output_dir = full_output_dir
35
+
36
+ filename = os.path.basename(filename)
37
+ file_path = os.path.join(output_dir, filename)
38
+
39
+ if not os.path.isfile(file_path):
40
+ return web.Response(status=404)
41
+
42
+ _, ext = os.path.splitext(filename)
43
+ ext = ext.lower()
44
+
45
+ content_types = {
46
+ ".wav": "audio/wav",
47
+ ".mp3": "audio/mpeg",
48
+ ".flac": "audio/flac",
49
+ ".aac": "audio/aac",
50
+ ".wma": "audio/x-ms-wma",
51
+ }
52
+
53
+ content_type = content_types.get(ext, None)
54
+
55
+ if content_type is None:
56
+ return web.Response(status=400)
57
+
58
+ try:
59
+ with open(file_path, "rb") as file:
60
+ data = file.read()
61
+ except:
62
+ return web.Response(status=500)
63
+
64
+ return web.Response(body=data, content_type=content_type)
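
As a usage note, this route simply streams a previously saved clip back to the player widget in `web/js/playAudio.js`. A minimal sketch of calling it directly, assuming a local ComfyUI server on the default `127.0.0.1:8188` and a hypothetical output file named `TangoFlux_00001_.wav`:

```python
# Minimal sketch: fetch a decoded clip through the /tangoflux/playaudio route above.
from urllib.parse import urlencode
from urllib.request import urlopen

params = urlencode({"filename": "TangoFlux_00001_.wav", "subfolder": "", "type": "output"})
with urlopen(f"http://127.0.0.1:8188/tangoflux/playaudio?{params}") as resp:
    audio_bytes = resp.read()   # raw audio bytes, served with content type audio/wav

with open("clip.wav", "wb") as f:
    f.write(audio_bytes)
```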
external_models/TangoFlux/comfyui/teacache.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code from https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4TangoFlux/teacache_tango_flux.py
2
+
3
+ from typing import Any, Dict, Optional, Union
4
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
5
+ from diffusers.utils import (
6
+ USE_PEFT_BACKEND,
7
+ is_torch_version,
8
+ logging,
9
+ scale_lora_layers,
10
+ unscale_lora_layers,
11
+ )
12
+ import torch
13
+ import numpy as np
14
+
15
+
16
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
+
18
+
19
+ def teacache_forward(
20
+ self,
21
+ hidden_states: torch.Tensor,
22
+ encoder_hidden_states: torch.Tensor = None,
23
+ pooled_projections: torch.Tensor = None,
24
+ timestep: torch.LongTensor = None,
25
+ img_ids: torch.Tensor = None,
26
+ txt_ids: torch.Tensor = None,
27
+ guidance: torch.Tensor = None,
28
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
29
+ return_dict: bool = True,
30
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
31
+ """
32
+ The [`FluxTransformer2DModel`] forward method.
33
+
34
+ Args:
35
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
36
+ Input `hidden_states`.
37
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
38
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
39
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
40
+ from the embeddings of input conditions.
41
+ timestep ( `torch.LongTensor`):
42
+ Used to indicate denoising step.
43
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
44
+ A list of tensors that if specified are added to the residuals of transformer blocks.
45
+ joint_attention_kwargs (`dict`, *optional*):
46
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
47
+ `self.processor` in
48
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
49
+ return_dict (`bool`, *optional*, defaults to `True`):
50
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
51
+ tuple.
52
+
53
+ Returns:
54
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
55
+ `tuple` where the first element is the sample tensor.
56
+ """
57
+ if joint_attention_kwargs is not None:
58
+ joint_attention_kwargs = joint_attention_kwargs.copy()
59
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
60
+ else:
61
+ lora_scale = 1.0
62
+
63
+ if USE_PEFT_BACKEND:
64
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
65
+ scale_lora_layers(self, lora_scale)
66
+ else:
67
+ if (
68
+ joint_attention_kwargs is not None
69
+ and joint_attention_kwargs.get("scale", None) is not None
70
+ ):
71
+ logger.warning(
72
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
73
+ )
74
+ hidden_states = self.x_embedder(hidden_states)
75
+
76
+ timestep = timestep.to(hidden_states.dtype) * 1000
77
+ if guidance is not None:
78
+ guidance = guidance.to(hidden_states.dtype) * 1000
79
+ else:
80
+ guidance = None
81
+ temb = (
82
+ self.time_text_embed(timestep, pooled_projections)
83
+ if guidance is None
84
+ else self.time_text_embed(timestep, guidance, pooled_projections)
85
+ )
86
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
87
+
88
+ ids = torch.cat((txt_ids, img_ids), dim=1)
89
+ image_rotary_emb = self.pos_embed(ids)
90
+
91
+ if self.enable_teacache:
92
+ inp = hidden_states.clone()
93
+ temb_ = temb.clone()
94
+ modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
95
+ self.transformer_blocks[0].norm1(inp, emb=temb_)
96
+ )
97
+ if self.cnt == 0 or self.cnt == self.num_steps - 1:
98
+ should_calc = True
99
+ self.accumulated_rel_l1_distance = 0
100
+ else:
101
+ coefficients = [
102
+ 4.98651651e02,
103
+ -2.83781631e02,
104
+ 5.58554382e01,
105
+ -3.82021401e00,
106
+ 2.64230861e-01,
107
+ ]
108
+ rescale_func = np.poly1d(coefficients)
109
+ self.accumulated_rel_l1_distance += rescale_func(
110
+ (
111
+ (modulated_inp - self.previous_modulated_input).abs().mean()
112
+ / self.previous_modulated_input.abs().mean()
113
+ )
114
+ .cpu()
115
+ .item()
116
+ )
117
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
118
+ should_calc = False
119
+ else:
120
+ should_calc = True
121
+ self.accumulated_rel_l1_distance = 0
122
+ self.previous_modulated_input = modulated_inp
123
+ self.cnt += 1
124
+ if self.cnt == self.num_steps:
125
+ self.cnt = 0
126
+
127
+ if self.enable_teacache:
128
+ if not should_calc:
129
+ hidden_states += self.previous_residual
130
+ else:
131
+ ori_hidden_states = hidden_states.clone()
132
+ for index_block, block in enumerate(self.transformer_blocks):
133
+ if self.training and self.gradient_checkpointing:
134
+
135
+ def create_custom_forward(module, return_dict=None):
136
+ def custom_forward(*inputs):
137
+ if return_dict is not None:
138
+ return module(*inputs, return_dict=return_dict)
139
+ else:
140
+ return module(*inputs)
141
+
142
+ return custom_forward
143
+
144
+ ckpt_kwargs: Dict[str, Any] = (
145
+ {"use_reentrant": False}
146
+ if is_torch_version(">=", "1.11.0")
147
+ else {}
148
+ )
149
+ encoder_hidden_states, hidden_states = (
150
+ torch.utils.checkpoint.checkpoint(
151
+ create_custom_forward(block),
152
+ hidden_states,
153
+ encoder_hidden_states,
154
+ temb,
155
+ image_rotary_emb,
156
+ **ckpt_kwargs,
157
+ )
158
+ )
159
+
160
+ else:
161
+ encoder_hidden_states, hidden_states = block(
162
+ hidden_states=hidden_states,
163
+ encoder_hidden_states=encoder_hidden_states,
164
+ temb=temb,
165
+ image_rotary_emb=image_rotary_emb,
166
+ )
167
+
168
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
169
+
170
+ for index_block, block in enumerate(self.single_transformer_blocks):
171
+ if self.training and self.gradient_checkpointing:
172
+
173
+ def create_custom_forward(module, return_dict=None):
174
+ def custom_forward(*inputs):
175
+ if return_dict is not None:
176
+ return module(*inputs, return_dict=return_dict)
177
+ else:
178
+ return module(*inputs)
179
+
180
+ return custom_forward
181
+
182
+ ckpt_kwargs: Dict[str, Any] = (
183
+ {"use_reentrant": False}
184
+ if is_torch_version(">=", "1.11.0")
185
+ else {}
186
+ )
187
+ hidden_states = torch.utils.checkpoint.checkpoint(
188
+ create_custom_forward(block),
189
+ hidden_states,
190
+ temb,
191
+ image_rotary_emb,
192
+ **ckpt_kwargs,
193
+ )
194
+
195
+ else:
196
+ hidden_states = block(
197
+ hidden_states=hidden_states,
198
+ temb=temb,
199
+ image_rotary_emb=image_rotary_emb,
200
+ )
201
+
202
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
203
+ self.previous_residual = hidden_states - ori_hidden_states
204
+ else:
205
+ for index_block, block in enumerate(self.transformer_blocks):
206
+ if self.training and self.gradient_checkpointing:
207
+
208
+ def create_custom_forward(module, return_dict=None):
209
+ def custom_forward(*inputs):
210
+ if return_dict is not None:
211
+ return module(*inputs, return_dict=return_dict)
212
+ else:
213
+ return module(*inputs)
214
+
215
+ return custom_forward
216
+
217
+ ckpt_kwargs: Dict[str, Any] = (
218
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
219
+ )
220
+ encoder_hidden_states, hidden_states = (
221
+ torch.utils.checkpoint.checkpoint(
222
+ create_custom_forward(block),
223
+ hidden_states,
224
+ encoder_hidden_states,
225
+ temb,
226
+ image_rotary_emb,
227
+ **ckpt_kwargs,
228
+ )
229
+ )
230
+
231
+ else:
232
+ encoder_hidden_states, hidden_states = block(
233
+ hidden_states=hidden_states,
234
+ encoder_hidden_states=encoder_hidden_states,
235
+ temb=temb,
236
+ image_rotary_emb=image_rotary_emb,
237
+ )
238
+
239
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
240
+
241
+ for index_block, block in enumerate(self.single_transformer_blocks):
242
+ if self.training and self.gradient_checkpointing:
243
+
244
+ def create_custom_forward(module, return_dict=None):
245
+ def custom_forward(*inputs):
246
+ if return_dict is not None:
247
+ return module(*inputs, return_dict=return_dict)
248
+ else:
249
+ return module(*inputs)
250
+
251
+ return custom_forward
252
+
253
+ ckpt_kwargs: Dict[str, Any] = (
254
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
255
+ )
256
+ hidden_states = torch.utils.checkpoint.checkpoint(
257
+ create_custom_forward(block),
258
+ hidden_states,
259
+ temb,
260
+ image_rotary_emb,
261
+ **ckpt_kwargs,
262
+ )
263
+
264
+ else:
265
+ hidden_states = block(
266
+ hidden_states=hidden_states,
267
+ temb=temb,
268
+ image_rotary_emb=image_rotary_emb,
269
+ )
270
+
271
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
272
+
273
+ hidden_states = self.norm_out(hidden_states, temb)
274
+ output = self.proj_out(hidden_states)
275
+
276
+ if USE_PEFT_BACKEND:
277
+ # remove `lora_scale` from each PEFT layer
278
+ unscale_lora_layers(self, lora_scale)
279
+
280
+ if not return_dict:
281
+ return (output,)
282
+
283
+ return Transformer2DModelOutput(sample=output)
external_models/TangoFlux/comfyui/web/js/playAudio.js ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { app } from "../../../scripts/app.js";
2
+ import { api } from "../../../scripts/api.js";
3
+
4
+ app.registerExtension({
5
+ name: "TangoFlux.playAudio",
6
+ async beforeRegisterNodeDef(nodeType, nodeData, app) {
7
+ if (nodeData.name === "TangoFluxVAEDecodeAndPlay") {
8
+ const originalNodeCreated = nodeType.prototype.onNodeCreated;
9
+
10
+ nodeType.prototype.onNodeCreated = async function () {
11
+ originalNodeCreated?.apply(this, arguments);
12
+ this.widgets_count = this.widgets?.length || 0;
13
+
14
+ this.addAudioWidgets = (audios) => {
15
+ if (this.widgets) {
16
+ for (let i = 0; i < this.widgets.length; i++) {
17
+ if (this.widgets[i].name.startsWith("_playaudio")) {
18
+ this.widgets[i].onRemove?.();
19
+ }
20
+ }
21
+ this.widgets.length = this.widgets_count;
22
+ }
23
+
24
+ let index = 0
25
+ for (const params of audios) {
26
+ const audioElement = document.createElement("audio");
27
+ audioElement.controls = true;
28
+
29
+ this.addDOMWidget("_playaudio" + index, "playaudio", audioElement, {
30
+ serialize: false,
31
+ hideOnZoom: false,
32
+ });
33
+ audioElement.src = api.apiURL(
34
+ `/tangoflux/playaudio?${new URLSearchParams(params)}`
35
+ );
36
+ index++
37
+ }
38
+
39
+ requestAnimationFrame(() => {
40
+ const newSize = this.computeSize();
41
+ newSize[0] = Math.max(newSize[0], this.size[0]);
42
+ newSize[1] = Math.max(newSize[1], this.size[1]);
43
+ this.onResize?.(newSize);
44
+ app.graph.setDirtyCanvas(true, false);
45
+ });
46
+ };
47
+ };
48
+
49
+ const originalNodeExecuted = nodeType.prototype.onExecuted;
50
+
51
+ nodeType.prototype.onExecuted = async function (message) {
52
+ originalNodeExecuted?.apply(this, arguments);
53
+ if (message?.audios) {
54
+ this.addAudioWidgets(message.audios);
55
+ }
56
+ };
57
+ }
58
+ },
59
+ });
external_models/TangoFlux/configs/__init__.py ADDED
File without changes
external_models/TangoFlux/configs/accelerator_config.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "distributed_type": "MULTI_GPU",
4
+ "main_process_port": 29512,
5
+ "downcast_bf16": false,
6
+ "machine_rank": 0,
7
+ "gpu_ids": "0,1",
8
+ "main_training_function": "main",
9
+ "mixed_precision": "no",
10
+ "num_machines": 1,
11
+ "num_processes": 2,
12
+ "rdzv_backend": "static",
13
+ "same_network": true,
14
+ "tpu_use_cluster": false,
15
+ "tpu_use_sudo": false,
16
+ "use_cpu": false
17
+ }
external_models/TangoFlux/configs/tangoflux_config.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Absolute paths for different resources
3
+ paths:
4
+ train_file: "data/train.json"
5
+ val_file: "data/val.json"
6
+ test_file: "data/val.json"
7
+ resume_from_checkpoint: ""
8
+ output_dir: "outputs/"
9
+
10
+ # Training-related parameters
11
+ training:
12
+ per_device_batch_size: 4
13
+ learning_rate: 5e-4
14
+ gradient_accumulation_steps: 1
15
+ num_train_epochs: 80
16
+ num_warmup_steps: 1000
17
+ max_audio_duration: 30
18
+
19
+
20
+ # Model and optimizer parameters,
21
+ model:
22
+ num_layers: 6
23
+ num_single_layers: 18
24
+ in_channels: 64
25
+ attention_head_dim: 128
26
+ joint_attention_dim: 1024
27
+ num_attention_heads: 8
28
+ audio_seq_len: 645
29
+ max_duration: 30
30
+ uncondition: false
31
+ text_encoder_name: "google/flan-t5-large"
32
+
33
+
34
+
35
+
36
+
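
For reference, the `model` block above carries the same keys as the inline config dict built in `tangoflux/generate_crpo_dataset.py` later in this commit. A minimal sketch of loading it with PyYAML and constructing the model, assuming the file sits at `configs/tangoflux_config.yaml` relative to the working directory:

```python
# Minimal sketch: build TangoFlux from the `model` block of the YAML above.
import yaml
from tangoflux.model import TangoFlux  # module added in this commit

with open("configs/tangoflux_config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

# The `model` block maps directly onto the constructor's config dict
# (num_layers, in_channels, audio_seq_len, text_encoder_name, ...).
model = TangoFlux(config=cfg["model"])
```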
external_models/TangoFlux/crpo.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ python3 tangoflux/generate_crpo_dataset.py --json_path='path_to_prompt_bank.json' --sample_size=50 --model='path_to_tangoflux.safetensors' --num_samples=5 --output_dir='outputs'
2
+ python3 tangoflux/label_crpo.py --json_path='outputs/results.json' --output_dir='outputs/crpo_iteration1' --num_samples=5
external_models/TangoFlux/inference.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ from tangoflux import TangoFluxInference
3
+
4
+ model = TangoFluxInference(name="declare-lab/TangoFlux")
5
+ audio = model.generate("Hammer slowly hitting the wooden table", steps=50, duration=10)
6
+
7
+ torchaudio.save("output.wav", audio, sample_rate=44100)
external_models/TangoFlux/replicate_demo/cog.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://cog.run/yaml
3
+
4
+ build:
5
+ # set to true if your model requires a GPU
6
+ gpu: true
7
+
8
+ # a list of ubuntu apt packages to install
9
+ system_packages:
10
+ - "libgl1-mesa-glx"
11
+ - "libglib2.0-0"
12
+
13
+ # python version in the form '3.11' or '3.11.4'
14
+ python_version: "3.11"
15
+
16
+ # a list of packages in the format <package-name>==<version>
17
+ python_packages:
18
+ - torch==2.4.0
19
+ - torchaudio==2.4.0
20
+ - torchlibrosa==0.1.0
21
+ - torchvision==0.19.0
22
+ - transformers==4.44.0
23
+ - diffusers==0.30.0
24
+ - accelerate==0.34.2
25
+ - datasets==2.21.0
26
+ - librosa
27
+ - ipython
28
+
29
+ run:
30
+ - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
31
+ predict: "predict.py:Predictor"
external_models/TangoFlux/replicate_demo/predict.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ import os
5
+ import subprocess
6
+ import time
7
+ import json
8
+ from cog import BasePredictor, Input, Path
9
+ from diffusers import AutoencoderOobleck
10
+ import soundfile as sf
11
+ from safetensors.torch import load_file
12
+ from huggingface_hub import snapshot_download
13
+ from tangoflux.model import TangoFlux
14
+ from tangoflux import TangoFluxInference
15
+
16
+ MODEL_CACHE = "model_cache"
17
+ MODEL_URL = (
18
+ "https://weights.replicate.delivery/default/declare-lab/TangoFlux/model_cache.tar"
19
+ )
20
+
21
+
22
+ class CachedTangoFluxInference(TangoFluxInference):
23
+ ## load the weights from replicate.delivery for faster booting
24
+ def __init__(self, name="declare-lab/TangoFlux", device="cuda", cached_paths=None):
25
+ if cached_paths:
26
+ paths = cached_paths
27
+ else:
28
+ paths = snapshot_download(repo_id=name)
29
+
30
+ self.vae = AutoencoderOobleck()
31
+ vae_weights = load_file(f"{paths}/vae.safetensors")
32
+ self.vae.load_state_dict(vae_weights)
33
+ weights = load_file(f"{paths}/tangoflux.safetensors")
34
+
35
+ with open(f"{paths}/config.json", "r") as f:
36
+ config = json.load(f)
37
+ self.model = TangoFlux(config)
38
+ self.model.load_state_dict(weights, strict=False)
39
+ self.vae.to(device)
40
+ self.model.to(device)
41
+
42
+
43
+ def download_weights(url, dest):
44
+ start = time.time()
45
+ print("downloading url: ", url)
46
+ print("downloading to: ", dest)
47
+ subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
48
+ print("downloading took: ", time.time() - start)
49
+
50
+
51
+ class Predictor(BasePredictor):
52
+ def setup(self) -> None:
53
+ """Load the model into memory to make running multiple predictions efficient"""
54
+
55
+ if not os.path.exists(MODEL_CACHE):
56
+ print("downloading")
57
+ download_weights(MODEL_URL, MODEL_CACHE)
58
+
59
+ self.model = CachedTangoFluxInference(
60
+ cached_paths=f"{MODEL_CACHE}/declare-lab/TangoFlux"
61
+ )
62
+
63
+ def predict(
64
+ self,
65
+ prompt: str = Input(
66
+ description="Input prompt", default="Hammer slowly hitting the wooden table"
67
+ ),
68
+ duration: int = Input(
69
+ description="Duration of the output audio in seconds", default=10
70
+ ),
71
+ steps: int = Input(
72
+ description="Number of inference steps", ge=1, le=200, default=25
73
+ ),
74
+ guidance_scale: float = Input(
75
+ description="Scale for classifier-free guidance", ge=1, le=20, default=4.5
76
+ ),
77
+ ) -> Path:
78
+ """Run a single prediction on the model"""
79
+
80
+ audio = self.model.generate(
81
+ prompt,
82
+ steps=steps,
83
+ guidance_scale=guidance_scale,
84
+ duration=duration,
85
+ )
86
+ audio_numpy = audio.numpy()
87
+ out_path = "/tmp/out.wav"
88
+
89
+ sf.write(
90
+ out_path, audio_numpy.T, samplerate=self.model.vae.config.sampling_rate
91
+ )
92
+ return Path(out_path)
external_models/TangoFlux/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.0
2
+ torchaudio==2.4.0
3
+ torchlibrosa==0.1.0
4
+ torchvision==0.19.0
5
+ transformers==4.44.0
6
+ diffusers==0.30.0
7
+ accelerate==0.34.2
8
+ datasets==2.21.0
9
+ librosa
10
+ tqdm
11
+ wandb
12
+
external_models/TangoFlux/setup.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ setup(
4
+ name="tangoflux",
5
+ description="TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching",
6
+ version="0.1.0",
7
+ packages=["tangoflux"],
8
+ install_requires=[
9
+ "torch==2.4.0",
10
+ "torchaudio==2.4.0",
11
+ "torchlibrosa==0.1.0",
12
+ "torchvision==0.19.0",
13
+ "transformers==4.44.0",
14
+ "diffusers==0.30.0",
15
+ "accelerate==0.34.2",
16
+ "datasets==2.21.0",
17
+ "librosa",
18
+ "tqdm",
19
+ "wandb",
20
+ "click",
21
+ "gradio",
22
+ "torchaudio",
23
+ ],
24
+ entry_points={
25
+ "console_scripts": [
26
+ "tangoflux=tangoflux.cli:main",
27
+ "tangoflux-demo=tangoflux.demo:main",
28
+ ],
29
+ },
30
+ )
external_models/TangoFlux/tangoflux/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoencoderOobleck
2
+ import torch
3
+ from transformers import T5EncoderModel, T5TokenizerFast
4
+ from diffusers import FluxTransformer2DModel
5
+ from torch import nn
6
+ from typing import List
7
+ from diffusers import FlowMatchEulerDiscreteScheduler
8
+ from diffusers.training_utils import compute_density_for_timestep_sampling
9
+ import copy
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ from tangoflux.model import TangoFlux
13
+ from huggingface_hub import snapshot_download
14
+ from tqdm import tqdm
15
+ from typing import Optional, Union, List
16
+ from datasets import load_dataset, Audio
17
+ from math import pi
18
+ import json
19
+ import inspect
20
+ import yaml
21
+ from safetensors.torch import load_file
22
+
23
+
24
+ class TangoFluxInference:
25
+
26
+ def __init__(
27
+ self,
28
+ name="declare-lab/TangoFlux",
29
+ device="cuda" if torch.cuda.is_available() else "cpu",
30
+ ):
31
+
32
+ self.vae = AutoencoderOobleck()
33
+
34
+ paths = snapshot_download(repo_id=name)
35
+ vae_weights = load_file("{}/vae.safetensors".format(paths))
36
+ self.vae.load_state_dict(vae_weights)
37
+ weights = load_file("{}/tangoflux.safetensors".format(paths))
38
+
39
+ with open("{}/config.json".format(paths), "r") as f:
40
+ config = json.load(f)
41
+ self.model = TangoFlux(config)
42
+ self.model.load_state_dict(weights, strict=False)
43
+ # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
44
+ self.vae.to(device)
45
+ self.model.to(device)
46
+
47
+ def generate(self, prompt, steps=25, duration=10, guidance_scale=4.5):
48
+
49
+ with torch.no_grad():
50
+ latents = self.model.inference_flow(
51
+ prompt,
52
+ duration=duration,
53
+ num_inference_steps=steps,
54
+ guidance_scale=guidance_scale,
55
+ )
56
+
57
+ wave = self.vae.decode(latents.transpose(2, 1)).sample.cpu()[0]
58
+ waveform_end = int(duration * self.vae.config.sampling_rate)
59
+ wave = wave[:, :waveform_end]
60
+ return wave
external_models/TangoFlux/tangoflux/cli.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import torchaudio
3
+ from tangoflux import TangoFluxInference
4
+
5
+ @click.command()
6
+ @click.argument('prompt')
7
+ @click.argument('output_file')
8
+ @click.option('--duration', default=10, type=int, help='Duration in seconds (1-30)')
9
+ @click.option('--steps', default=50, type=int, help='Number of inference steps (10-100)')
10
+ def main(prompt: str, output_file: str, duration: int, steps: int):
11
+ """Generate audio from text using TangoFlux.
12
+
13
+ Args:
14
+ prompt: Text description of the audio to generate
15
+ output_file: Path to save the generated audio file
16
+ duration: Duration of generated audio in seconds (default: 10)
17
+ steps: Number of inference steps (default: 50)
18
+ """
19
+ if not 1 <= duration <= 30:
20
+ raise click.BadParameter('Duration must be between 1 and 30 seconds')
21
+ if not 10 <= steps <= 100:
22
+ raise click.BadParameter('Steps must be between 10 and 100')
23
+
24
+ model = TangoFluxInference(name="declare-lab/TangoFlux")
25
+ audio = model.generate(prompt, steps=steps, duration=duration)
26
+ torchaudio.save(output_file, audio, sample_rate=44100)
27
+
28
+ if __name__ == '__main__':
29
+ main()
external_models/TangoFlux/tangoflux/demo.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torchaudio
3
+ import click
4
+ import tempfile
5
+ from tangoflux import TangoFluxInference
6
+
7
+ model = TangoFluxInference(name="declare-lab/TangoFlux")
8
+
9
+
10
+ def generate_audio(prompt, duration, steps):
11
+ audio = model.generate(prompt, steps=steps, duration=duration)
12
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
13
+ torchaudio.save(f.name, audio, sample_rate=44100)
14
+ return f.name
15
+
16
+
17
+ examples = [
18
+ ["Hammer slowly hitting the wooden table", 10, 50],
19
+ ["Gentle rain falling on a tin roof", 15, 50],
20
+ ["Wind chimes tinkling in a light breeze", 10, 50],
21
+ ["Rhythmic wooden table tapping overlaid with steady water pouring sound", 10, 50],
22
+ ]
23
+
24
+ with gr.Blocks(title="TangoFlux Text-to-Audio Generation") as demo:
25
+ gr.Markdown("# TangoFlux Text-to-Audio Generation")
26
+ gr.Markdown("Generate audio from text descriptions using TangoFlux")
27
+
28
+ with gr.Row():
29
+ with gr.Column():
30
+ prompt = gr.Textbox(
31
+ label="Text Prompt", placeholder="Enter your audio description..."
32
+ )
33
+ duration = gr.Slider(
34
+ minimum=1, maximum=30, value=10, step=1, label="Duration (seconds)"
35
+ )
36
+ steps = gr.Slider(
37
+ minimum=10, maximum=100, value=50, step=10, label="Number of Steps"
38
+ )
39
+ generate_btn = gr.Button("Generate Audio")
40
+
41
+ with gr.Column():
42
+ audio_output = gr.Audio(label="Generated Audio")
43
+
44
+ generate_btn.click(
45
+ fn=generate_audio, inputs=[prompt, duration, steps], outputs=audio_output
46
+ )
47
+
48
+ gr.Examples(
49
+ examples=examples,
50
+ inputs=[prompt, duration, steps],
51
+ outputs=audio_output,
52
+ fn=generate_audio,
53
+ )
54
+
55
+ @click.command()
56
+ @click.option('--host', default='127.0.0.1', help='Host to bind to')
57
+ @click.option('--port', default=None, type=int, help='Port to bind to')
58
+ @click.option('--share', is_flag=True, help='Enable sharing via Gradio')
59
+ def main(host, port, share):
60
+ demo.queue().launch(server_name=host, server_port=port, share=share)
61
+
62
+ if __name__ == "__main__":
63
+ main()
external_models/TangoFlux/tangoflux/generate_crpo_dataset.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import torch
5
+ import argparse
6
+ import multiprocessing
7
+ from tqdm import tqdm
8
+ from safetensors.torch import load_file
9
+ from diffusers import AutoencoderOobleck
10
+ import soundfile as sf
11
+ from model import TangoFlux
12
+ import random
13
+
14
+
15
+
16
+
17
+ def generate_audio_chunk(args, chunk, gpu_id, output_dir, samplerate, return_dict, process_id):
18
+ """
19
+ Function to generate audio for a chunk of text prompts on a specific GPU.
20
+ """
21
+ try:
22
+ device = f"cuda:{gpu_id}"
23
+ torch.cuda.set_device(device)
24
+ print(f"Process {process_id}: Using device {device}")
25
+
26
+ # Initialize model
27
+ config = {
28
+ 'num_layers': 6,
29
+ 'num_single_layers': 18,
30
+ 'in_channels': 64,
31
+ 'attention_head_dim': 128,
32
+ 'joint_attention_dim': 1024,
33
+ 'num_attention_heads': 8,
34
+ 'audio_seq_len': 645,
35
+ 'max_duration': 30,
36
+ 'uncondition': False,
37
+ 'text_encoder_name': "google/flan-t5-large"
38
+ }
39
+
40
+ model = TangoFlux(config)
41
+ print(f"Process {process_id}: Loading model from {args.model} on {device}")
42
+ w1 = load_file(args.model)
43
+ model.load_state_dict(w1, strict=False)
44
+ model = model.to(device)
45
+ model.eval()
46
+
47
+ # Initialize VAE
48
+ vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0", subfolder='vae')
49
+ vae = vae.to(device)
50
+ vae.eval()
51
+
52
+ outputs = []
53
+
54
+ # Corrected loop using enumerate properly with tqdm
55
+ for idx, item in tqdm(enumerate(chunk), total=len(chunk), desc=f"GPU {gpu_id}"):
56
+ text = item['captions']
57
+
58
+
59
+ if os.path.exists(os.path.join(output_dir, f"id_{item['id']}_sample1.wav")):
60
+ print("Exist! Skipping!")
61
+ continue
62
+ with torch.no_grad():
63
+ latent = model.inference_flow(
64
+ text,
65
+ num_inference_steps=args.num_steps,
66
+ guidance_scale=args.guidance_scale,
67
+ duration=10,
68
+ num_samples_per_prompt=args.num_samples
69
+ )
70
+
71
+ #waveform_end = int(duration * vae.config.sampling_rate)
72
+ latent = latent[:, :220, :] ## 220 corresponds to the latent length of audiocaps encoded with this vae. You can modify this
73
+ wave = vae.decode(latent.transpose(2, 1)).sample.cpu()
74
+
75
+ for i in range(args.num_samples):
76
+ filename = f"id_{item['id']}_sample{i+1}.wav"
77
+ filepath = os.path.join(output_dir, filename)
78
+
79
+ sf.write(filepath, wave[i].T, samplerate)
80
+ outputs.append({
81
+ "id": item['id'],
82
+ "sample": i + 1,
83
+ "path": filepath,
84
+ "captions": text
85
+ })
86
+
87
+ return_dict[process_id] = outputs
88
+ print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
89
+
90
+ except Exception as e:
91
+ print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
92
+ return_dict[process_id] = []
93
+
94
+ def split_into_chunks(data, num_chunks):
95
+ """
96
+ Splits data into num_chunks approximately equal parts.
97
+ """
98
+ avg = len(data) // num_chunks
99
+ chunks = []
100
+ for i in range(num_chunks):
101
+ start = i * avg
102
+ # Ensure the last chunk takes the remainder
103
+ end = (i + 1) * avg if i != num_chunks - 1 else len(data)
104
+ chunks.append(data[start:end])
105
+ return chunks
106
+
107
+ def main():
108
+ parser = argparse.ArgumentParser(description="Generate audio using multiple GPUs")
109
+ parser.add_argument('--num_steps', type=int, default=50, help='Number of inference steps')
110
+ parser.add_argument('--model', type=str, required=True, help='Path to tangoflux weights')
111
+ parser.add_argument('--num_samples', type=int, default=5, help='Number of samples per prompt')
112
+ parser.add_argument('--output_dir', type=str, default='output', help='Directory to save outputs')
113
+ parser.add_argument('--json_path', type=str, required=True, help='Path to input JSON file')
114
+ parser.add_argument('--sample_size', type=int, default=20000, help='Number of prompts to sample for CRPO')
115
+ parser.add_argument('--guidance_scale', type=float, default=4.5, help='Guidance scale used for generation')
116
+ args = parser.parse_args()
117
+
118
+ # Check GPU availability
119
+ num_gpus = torch.cuda.device_count()
120
+ sample_size = args.sample_size
121
+
122
+
123
+ # Load JSON data
124
+ import json
125
+ try:
126
+ with open(args.json_path, 'r') as f:
127
+ data = json.load(f)
128
+
129
+ except Exception as e:
130
+ print(f"Error loading JSON file {args.json_path}: {e}")
131
+ return
132
+
133
+ if not isinstance(data, list):
134
+ print("Error: JSON data is not a list.")
135
+ return
136
+
137
+ if len(data) < sample_size:
138
+ print(f"Warning: JSON data contains only {len(data)} items. Sampling all available data.")
139
+ sampled = data
140
+ else:
141
+ sampled = random.sample(data, sample_size)
142
+
143
+ # Split data into chunks based on available GPUs
144
+ random.shuffle(sampled)
145
+ chunks = split_into_chunks(sampled, num_gpus)
146
+
147
+ # Prepare output directory
148
+ os.makedirs(args.output_dir, exist_ok=True)
149
+ samplerate = 44100
150
+
151
+ # Manager for inter-process communication
152
+ manager = multiprocessing.Manager()
153
+ return_dict = manager.dict()
154
+
155
+ processes = []
156
+ for i in range(num_gpus):
157
+ p = multiprocessing.Process(
158
+ target=generate_audio_chunk,
159
+ args=(
160
+ args,
161
+ chunks[i],
162
+ i, # GPU ID
163
+ args.output_dir,
164
+ samplerate,
165
+ return_dict,
166
+ i, # Process ID
167
+
168
+ )
169
+ )
170
+ processes.append(p)
171
+ p.start()
172
+ print(f"Started process {i} on GPU {i}")
173
+
174
+ for p in processes:
175
+ p.join()
176
+ print(f"Process {p.pid} has finished.")
177
+
178
+ # Aggregate results
179
+
180
+
181
+
182
+
183
+
184
+
185
+ audio_info_list = [
186
+ [{
187
+ "path": f"{args.output_dir}/id_{sampled[j]['id']}_sample{i}.wav",
188
+ "duration": sampled[j]["duration"],
189
+ "captions": sampled[j]["captions"]
190
+ }
191
+ for i in range(1, args.num_samples+1) ] for j in range(sample_size)
192
+ ]
193
+
194
+ #print(audio_info_list)
195
+
196
+ with open(f'{args.output_dir}/results.json','w') as f:
197
+ json.dump(audio_info_list,f)
198
+
199
+ print(f"All audio samples have been generated and saved to {args.output_dir}")
200
+
201
+
202
+ if __name__ == "__main__":
203
+ multiprocessing.set_start_method('spawn')
204
+ main()
external_models/TangoFlux/tangoflux/label_crpo.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ import torch
5
+ import laion_clap
6
+ import numpy as np
7
+ import multiprocessing
8
+ from tqdm import tqdm
9
+
10
+ def parse_args():
11
+ parser = argparse.ArgumentParser(
12
+ description="Labelling clap score for crpo dataset"
13
+ )
14
+ parser.add_argument(
15
+ "--num_samples", type=int, default=5,
16
+ help="Number of audio samples per prompt"
17
+ )
18
+ parser.add_argument(
19
+ "--json_path", type=str, required=True,
20
+ help="Path to input JSON file"
21
+ )
22
+ parser.add_argument(
23
+ "--output_dir", type=str, required=True,
24
+ help="Directory to save the final JSON with CLAP scores"
25
+ )
26
+ return parser.parse_args()
27
+
28
+ #python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
29
+ @torch.no_grad()
30
+ def compute_clap(model, audio_files, text_data):
31
+ # Compute audio and text embeddings, then compute the dot product (CLAP score)
32
+ audio_embed = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
33
+ text_embed = model.get_text_embedding(text_data, use_tensor=True)
34
+ return audio_embed @ text_embed.T
35
+
36
+ def process_chunk(args, chunk, gpu_id, return_dict, process_id):
37
+ """
38
+ Process a chunk of the data on a specific GPU.
39
+ Loads the CLAP model on the designated device, then for each item in the chunk,
40
+ computes the CLAP scores and attaches them to the data.
41
+ """
42
+ try:
43
+ device = f"cuda:{gpu_id}"
44
+ torch.cuda.set_device(device)
45
+ print(f"Process {process_id}: Using device {device}")
46
+
47
+ # Initialize the CLAP model on this GPU
48
+ model = laion_clap.CLAP_Module(enable_fusion=False)
49
+ model.to(device)
50
+ model.load_ckpt()
51
+ model.eval()
52
+
53
+ for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
54
+ # Each item is assumed to be a list of samples.
55
+ # Skip if already computed.
56
+ if 'clap_score' in item[0]:
57
+ continue
58
+
59
+ # Collect audio file paths and text data (using the first caption)
60
+ audio_files = [item[i]['path'] for i in range(args.num_samples)]
61
+ text_data = [item[0]['captions']]
62
+
63
+ try:
64
+ clap_scores = compute_clap(model, audio_files, text_data)
65
+ except Exception as e:
66
+ print(f"Error processing item index {j} on GPU {gpu_id}: {e}")
67
+ continue
68
+
69
+ # Attach the computed score to each sample in the item
70
+ for k in range(args.num_samples):
71
+ item[k]['clap_score'] = np.round(clap_scores[k].item(), 3)
72
+
73
+ return_dict[process_id] = chunk
74
+ print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
75
+ except Exception as e:
76
+ print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
77
+ return_dict[process_id] = []
78
+
79
+ def split_into_chunks(data, num_chunks):
80
+ """
81
+ Splits data into num_chunks approximately equal parts.
82
+ """
83
+ avg = len(data) // num_chunks
84
+ chunks = []
85
+ for i in range(num_chunks):
86
+ start = i * avg
87
+ # Ensure the last chunk takes the remainder of the data
88
+ end = (i + 1) * avg if i != num_chunks - 1 else len(data)
89
+ chunks.append(data[start:end])
90
+ return chunks
91
+
92
+ def main():
93
+ args = parse_args()
94
+
95
+ # Load data from JSON and slice by start/end if provided
96
+ with open(args.json_path, 'r') as f:
97
+ data = json.load(f)
98
+
99
+ # Check GPU availability and split data accordingly
100
+ num_gpus = torch.cuda.device_count()
101
+
102
+ print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
103
+ chunks = split_into_chunks(data, num_gpus)
104
+
105
+ # Prepare output directory
106
+ os.makedirs(args.output_dir, exist_ok=True)
107
+
108
+ # Create a manager dict to collect results from all processes
109
+ manager = multiprocessing.Manager()
110
+ return_dict = manager.dict()
111
+ processes = []
112
+
113
+ for i in range(num_gpus):
114
+ p = multiprocessing.Process(
115
+ target=process_chunk,
116
+ args=(args, chunks[i], i, return_dict, i)
117
+ )
118
+ processes.append(p)
119
+ p.start()
120
+ print(f"Started process {i} on GPU {i}")
121
+
122
+ for p in processes:
123
+ p.join()
124
+ print(f"Process {p.pid} has finished.")
125
+
126
+ # Aggregate all chunks back into a single list
127
+ combined_data = []
128
+ for i in range(num_gpus):
129
+ combined_data.extend(return_dict[i])
130
+
131
+ # Save the combined results to a single JSON file
132
+ output_file = f"{args.output_dir}/clap_scores.json"
133
+ with open(output_file, 'w') as f:
134
+ json.dump(combined_data, f)
135
+ print(f"All CLAP scores have been computed and saved to {output_file}")
136
+
137
+ max_item = [max(x, key=lambda item: item['clap_score']) for x in combined_data]
138
+ min_item = [min(x, key=lambda item: item['clap_score']) for x in combined_data]
139
+
140
+ crpo_dataset = []
141
+ for chosen,reject in zip(max_item,min_item):
142
+ crpo_dataset.append({"captions": chosen['captions'],
143
+ "duration": chosen['duration'],
144
+ "chosen": chosen['path'],
145
+ "reject": reject['path']})
146
+
147
+ with open(f"{args.output_dir}/train.json",'w') as f:
148
+ json.dump(crpo_dataset,f)
149
+
150
+
151
+ if __name__ == '__main__':
152
+ multiprocessing.set_start_method('spawn')
153
+ main()
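+ # Each record written to train.json above pairs the highest- and lowest-scoring
+ # generation for one caption; the paths below are hypothetical:
+ #
+ #   {"captions": "a dog barking", "duration": 10,
+ #    "chosen": "crpo/gen_3.wav", "reject": "crpo/gen_0.wav"}
+ #
+ # train_dpo.py consumes this file through DPOText2AudioDataset.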
external_models/TangoFlux/tangoflux/model.py ADDED
@@ -0,0 +1,556 @@
1
+ from transformers import T5EncoderModel, T5TokenizerFast
2
+ import torch
3
+ from diffusers import FluxTransformer2DModel
4
+ from torch import nn
5
+ import random
6
+ from typing import List
7
+ from diffusers import FlowMatchEulerDiscreteScheduler
8
+ from diffusers.training_utils import compute_density_for_timestep_sampling
9
+ import copy
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from typing import Optional, Union, List
15
+ from datasets import load_dataset, Audio
16
+ from math import pi
17
+ import inspect
18
+ import yaml
19
+
20
+
21
+ class StableAudioPositionalEmbedding(nn.Module):
22
+ """Used for continuous time
23
+ Adapted from Stable Audio Open.
24
+ """
25
+
26
+ def __init__(self, dim: int):
27
+ super().__init__()
28
+ assert (dim % 2) == 0
29
+ half_dim = dim // 2
30
+ self.weights = nn.Parameter(torch.randn(half_dim))
31
+
32
+ def forward(self, times: torch.Tensor) -> torch.Tensor:
33
+ times = times[..., None]
34
+ freqs = times * self.weights[None] * 2 * pi
35
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
36
+ fouriered = torch.cat((times, fouriered), dim=-1)
37
+ return fouriered
38
+
39
+
40
+ class DurationEmbedder(nn.Module):
41
+ """
42
+ A simple linear projection model to map numbers to a latent space.
43
+
44
+ Code is adapted from
45
+ https://github.com/Stability-AI/stable-audio-tools
46
+
47
+ Args:
48
+ number_embedding_dim (`int`):
49
+ Dimensionality of the number embeddings.
50
+ min_value (`int`):
51
+ The minimum value of the seconds number conditioning modules.
52
+ max_value (`int`):
53
+ The maximum value of the seconds number conditioning modules.
54
+ internal_dim (`int`):
55
+ Dimensionality of the intermediate number hidden states.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ number_embedding_dim,
61
+ min_value,
62
+ max_value,
63
+ internal_dim: Optional[int] = 256,
64
+ ):
65
+ super().__init__()
66
+ self.time_positional_embedding = nn.Sequential(
67
+ StableAudioPositionalEmbedding(internal_dim),
68
+ nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
69
+ )
70
+
71
+ self.number_embedding_dim = number_embedding_dim
72
+ self.min_value = min_value
73
+ self.max_value = max_value
74
+ self.dtype = torch.float32
75
+
76
+ def forward(
77
+ self,
78
+ floats: torch.Tensor,
79
+ ):
80
+ floats = floats.clamp(self.min_value, self.max_value)
81
+
82
+ normalized_floats = (floats - self.min_value) / (
83
+ self.max_value - self.min_value
84
+ )
85
+
86
+ # Cast floats to same type as embedder
87
+ embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
88
+ normalized_floats = normalized_floats.to(embedder_dtype)
89
+
90
+ embedding = self.time_positional_embedding(normalized_floats)
91
+ float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
92
+
93
+ return float_embeds
94
+
95
+
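+ # Shape sketch (illustrative): a batch of durations (B,) is clamped to
+ # [min_value, max_value], Fourier-embedded to internal_dim + 1 features
+ # (sin/cos halves plus the raw time), projected, and returned as
+ # (B, 1, number_embedding_dim) so it can be concatenated along the text
+ # sequence dimension, e.g.
+ #
+ #   emb = DurationEmbedder(1024, min_value=0, max_value=30)
+ #   emb(torch.tensor([10.0, 30.0])).shape  # torch.Size([2, 1, 1024])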
96
+ def retrieve_timesteps(
97
+ scheduler,
98
+ num_inference_steps: Optional[int] = None,
99
+ device: Optional[Union[str, torch.device]] = None,
100
+ timesteps: Optional[List[int]] = None,
101
+ sigmas: Optional[List[float]] = None,
102
+ **kwargs,
103
+ ):
104
+
105
+ if timesteps is not None and sigmas is not None:
106
+ raise ValueError(
107
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
108
+ )
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(
111
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
112
+ )
113
+ if not accepts_timesteps:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" timestep schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ elif sigmas is not None:
122
+ accept_sigmas = "sigmas" in set(
123
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
124
+ )
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
138
+
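+ # Illustrative call, mirroring inference_flow below: a linearly decreasing sigma
+ # schedule is handed to the FlowMatchEulerDiscreteScheduler.
+ #
+ #   sigmas = np.linspace(1.0, 1 / 50, 50)
+ #   timesteps, n_steps = retrieve_timesteps(scheduler, 50, "cuda", None, sigmas)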
139
+ class TangoFlux(nn.Module):
140
+
141
+ def __init__(self, config, text_encoder_dir=None, initialize_reference_model=False,):
142
+
143
+ super().__init__()
144
+
145
+ self.num_layers = config.get("num_layers", 6)
146
+ self.num_single_layers = config.get("num_single_layers", 18)
147
+ self.in_channels = config.get("in_channels", 64)
148
+ self.attention_head_dim = config.get("attention_head_dim", 128)
149
+ self.joint_attention_dim = config.get("joint_attention_dim", 1024)
150
+ self.num_attention_heads = config.get("num_attention_heads", 8)
151
+ self.audio_seq_len = config.get("audio_seq_len", 645)
152
+ self.max_duration = config.get("max_duration", 30)
153
+ self.uncondition = config.get("uncondition", False)
154
+ self.text_encoder_name = config.get("text_encoder_name", "google/flan-t5-large")
155
+
156
+ self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
157
+ self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
158
+ self.max_text_seq_len = 64
159
+ self.text_encoder = T5EncoderModel.from_pretrained(
160
+ text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
161
+ )
162
+ self.tokenizer = T5TokenizerFast.from_pretrained(
163
+ text_encoder_dir if text_encoder_dir is not None else self.text_encoder_name
164
+ )
165
+ self.text_embedding_dim = self.text_encoder.config.d_model
166
+
167
+ self.fc = nn.Sequential(
168
+ nn.Linear(self.text_embedding_dim, self.joint_attention_dim), nn.ReLU()
169
+ )
170
+ self.duration_emebdder = DurationEmbedder(
171
+ self.text_embedding_dim, min_value=0, max_value=self.max_duration
172
+ )
173
+
174
+ self.transformer = FluxTransformer2DModel(
175
+ in_channels=self.in_channels,
176
+ num_layers=self.num_layers,
177
+ num_single_layers=self.num_single_layers,
178
+ attention_head_dim=self.attention_head_dim,
179
+ num_attention_heads=self.num_attention_heads,
180
+ joint_attention_dim=self.joint_attention_dim,
181
+ pooled_projection_dim=self.text_embedding_dim,
182
+ guidance_embeds=False,
183
+ )
184
+
185
+ self.beta_dpo = 2000 ## this is used for dpo training
186
+
187
+ def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
188
+ device = self.text_encoder.device
189
+ sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
190
+
191
+ schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
192
+ timesteps = timesteps.to(device)
193
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
194
+
195
+ sigma = sigmas[step_indices].flatten()
196
+ while len(sigma.shape) < n_dim:
197
+ sigma = sigma.unsqueeze(-1)
198
+ return sigma
199
+
200
+ def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
201
+ device = self.text_encoder.device
202
+ batch = self.tokenizer(
203
+ prompt,
204
+ max_length=self.tokenizer.model_max_length,
205
+ padding=True,
206
+ truncation=True,
207
+ return_tensors="pt",
208
+ )
209
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
210
+ device
211
+ )
212
+
213
+ with torch.no_grad():
214
+ prompt_embeds = self.text_encoder(
215
+ input_ids=input_ids, attention_mask=attention_mask
216
+ )[0]
217
+
218
+ prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
219
+ attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
220
+
221
+ # get unconditional embeddings for classifier free guidance
222
+ uncond_tokens = [""]
223
+
224
+ max_length = prompt_embeds.shape[1]
225
+ uncond_batch = self.tokenizer(
226
+ uncond_tokens,
227
+ max_length=max_length,
228
+ padding="max_length",
229
+ truncation=True,
230
+ return_tensors="pt",
231
+ )
232
+ uncond_input_ids = uncond_batch.input_ids.to(device)
233
+ uncond_attention_mask = uncond_batch.attention_mask.to(device)
234
+
235
+ with torch.no_grad():
236
+ negative_prompt_embeds = self.text_encoder(
237
+ input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
238
+ )[0]
239
+
240
+ negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
241
+ num_samples_per_prompt, 0
242
+ )
243
+ uncond_attention_mask = uncond_attention_mask.repeat_interleave(
244
+ num_samples_per_prompt, 0
245
+ )
246
+
247
+ # For classifier-free guidance we would normally need two forward passes.
248
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
249
+
250
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
251
+ prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
252
+ boolean_prompt_mask = (prompt_mask == 1).to(device)
253
+
254
+ return prompt_embeds, boolean_prompt_mask
255
+
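+ # Note (illustrative): the concatenated [unconditional; conditional] batch returned
+ # here is split again in inference_flow and recombined as
+ #   noise_pred = uncond + guidance_scale * (text - uncond)
+ # which is the standard classifier-free guidance update.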
256
+ @torch.no_grad()
257
+ def encode_text(self, prompt):
258
+ device = self.text_encoder.device
259
+ batch = self.tokenizer(
260
+ prompt,
261
+ max_length=self.max_text_seq_len,
262
+ padding=True,
263
+ truncation=True,
264
+ return_tensors="pt",
265
+ )
266
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
267
+ device
268
+ )
269
+
270
+ encoder_hidden_states = self.text_encoder(
271
+ input_ids=input_ids, attention_mask=attention_mask
272
+ )[0]
273
+
274
+ boolean_encoder_mask = (attention_mask == 1).to(device)
275
+
276
+ return encoder_hidden_states, boolean_encoder_mask
277
+
278
+ def encode_duration(self, duration):
279
+ return self.duration_emebdder(duration)
280
+
281
+ @torch.no_grad()
282
+ def inference_flow(
283
+ self,
284
+ prompt,
285
+ num_inference_steps=50,
286
+ timesteps=None,
287
+ guidance_scale=3,
288
+ duration=10,
289
+ seed=0,
290
+ disable_progress=False,
291
+ num_samples_per_prompt=1,
292
+ callback_on_step_end=None,
293
+ ):
294
+ """Only tested for single inference. Haven't test for batch inference"""
295
+
296
+ torch.manual_seed(seed)
297
+ if torch.cuda.is_available():
298
+ torch.cuda.manual_seed(seed)
299
+ torch.cuda.manual_seed_all(seed)
300
+ torch.backends.cudnn.deterministic = True
301
+
302
+ bsz = num_samples_per_prompt
303
+ device = self.transformer.device
304
+ scheduler = self.noise_scheduler
305
+
306
+ if not isinstance(prompt, list):
307
+ prompt = [prompt]
308
+ if not isinstance(duration, torch.Tensor):
309
+ duration = torch.tensor([duration], device=device)
310
+ classifier_free_guidance = guidance_scale > 1.0
311
+ duration_hidden_states = self.encode_duration(duration)
312
+ if classifier_free_guidance:
313
+ bsz = 2 * num_samples_per_prompt
314
+
315
+ encoder_hidden_states, boolean_encoder_mask = (
316
+ self.encode_text_classifier_free(
317
+ prompt, num_samples_per_prompt=num_samples_per_prompt
318
+ )
319
+ )
320
+ duration_hidden_states = duration_hidden_states.repeat(bsz, 1, 1)
321
+
322
+ else:
323
+
324
+ encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
327
+
328
+ mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
329
+ encoder_hidden_states
330
+ )
331
+ masked_data = torch.where(
332
+ mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
333
+ )
334
+
335
+ pooled = torch.nanmean(masked_data, dim=1)
336
+ pooled_projection = self.fc(pooled)
337
+
338
+ encoder_hidden_states = torch.cat(
339
+ [encoder_hidden_states, duration_hidden_states], dim=1
340
+ ) ## (bs,seq_len,dim)
341
+
342
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
343
+ timesteps, num_inference_steps = retrieve_timesteps(
344
+ scheduler, num_inference_steps, device, timesteps, sigmas
345
+ )
346
+
347
+ latents = torch.randn(num_samples_per_prompt, self.audio_seq_len, 64)
348
+ weight_dtype = latents.dtype
349
+
350
+ progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
351
+
352
+ txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
353
+ audio_ids = (
354
+ torch.arange(self.audio_seq_len)
355
+ .unsqueeze(0)
356
+ .unsqueeze(-1)
357
+ .repeat(bsz, 1, 3)
358
+ .to(device)
359
+ )
360
+
361
+ timesteps = timesteps.to(device)
362
+ latents = latents.to(device)
363
+ encoder_hidden_states = encoder_hidden_states.to(device)
364
+
365
+ for i, t in enumerate(timesteps):
366
+
367
+ latents_input = (
368
+ torch.cat([latents] * 2) if classifier_free_guidance else latents
369
+ )
370
+
371
+ noise_pred = self.transformer(
372
+ hidden_states=latents_input,
373
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep this, but I want to keep the inputs the same for the model for testing)
374
+ timestep=torch.tensor([t / 1000], device=device),
375
+ guidance=None,
376
+ pooled_projections=pooled_projection,
377
+ encoder_hidden_states=encoder_hidden_states,
378
+ txt_ids=txt_ids,
379
+ img_ids=audio_ids,
380
+ return_dict=False,
381
+ )[0]
382
+
383
+ if classifier_free_guidance:
384
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
385
+ noise_pred = noise_pred_uncond + guidance_scale * (
386
+ noise_pred_text - noise_pred_uncond
387
+ )
388
+
389
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
390
+
391
+ progress_bar.update(1)
392
+
393
+ if callback_on_step_end is not None:
394
+ callback_on_step_end()
395
+
396
+ return latents
397
+
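+ # Sketch (assumption, not defined in this file): the latents returned above have
+ # shape (num_samples, audio_seq_len, 64) and are typically decoded with the
+ # Stable Audio Open Oobleck VAE, which expects (batch, channels, seq_len):
+ #
+ #   from diffusers import AutoencoderOobleck
+ #   vae = AutoencoderOobleck.from_pretrained(
+ #       "stabilityai/stable-audio-open-1.0", subfolder="vae")
+ #   wave = vae.decode(latents.transpose(2, 1)).sample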
398
+ def forward(self, latents, prompt, duration=torch.tensor([10]), sft=True):
399
+
400
+ device = latents.device
401
+ audio_seq_length = self.audio_seq_len
402
+ bsz = latents.shape[0]
403
+
404
+ encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
405
+ duration_hidden_states = self.encode_duration(duration)
406
+
407
+ mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(
408
+ encoder_hidden_states
409
+ )
410
+ masked_data = torch.where(
411
+ mask_expanded, encoder_hidden_states, torch.tensor(float("nan"))
412
+ )
413
+ pooled = torch.nanmean(masked_data, dim=1)
414
+ pooled_projection = self.fc(pooled)
415
+
416
+ ## Add duration hidden states to encoder hidden states
417
+ encoder_hidden_states = torch.cat(
418
+ [encoder_hidden_states, duration_hidden_states], dim=1
419
+ ) ## (bs,seq_len,dim)
420
+
421
+ txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
422
+ audio_ids = (
423
+ torch.arange(audio_seq_length)
424
+ .unsqueeze(0)
425
+ .unsqueeze(-1)
426
+ .repeat(bsz, 1, 3)
427
+ .to(device)
428
+ )
429
+
430
+ if sft:
431
+
432
+ if self.uncondition:
433
+ mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
434
+ if len(mask_indices) > 0:
435
+ encoder_hidden_states[mask_indices] = 0
436
+
437
+ noise = torch.randn_like(latents)
438
+
439
+ u = compute_density_for_timestep_sampling(
440
+ weighting_scheme="logit_normal",
441
+ batch_size=bsz,
442
+ logit_mean=0,
443
+ logit_std=1,
444
+ mode_scale=None,
445
+ )
446
+
447
+ indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
448
+ timesteps = self.noise_scheduler_copy.timesteps[indices].to(
449
+ device=latents.device
450
+ )
451
+ sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
452
+
453
+ noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
454
+
455
+ model_pred = self.transformer(
456
+ hidden_states=noisy_model_input,
457
+ encoder_hidden_states=encoder_hidden_states,
458
+ pooled_projections=pooled_projection,
459
+ img_ids=audio_ids,
460
+ txt_ids=txt_ids,
461
+ guidance=None,
462
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep this, but I want to keep the inputs the same for the model for testing)
463
+ timestep=timesteps / 1000,
464
+ return_dict=False,
465
+ )[0]
466
+
467
+ target = noise - latents
468
+ loss = torch.mean(
469
+ ((model_pred.float() - target.float()) ** 2).reshape(
470
+ target.shape[0], -1
471
+ ),
472
+ 1,
473
+ )
474
+ loss = loss.mean()
475
+ raw_model_loss, raw_ref_loss, implicit_acc = (
476
+ 0,
477
+ 0,
478
+ 0,
479
+ ) ## default this to 0 if doing sft
480
+
481
+ else:
482
+ encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
483
+ pooled_projection = pooled_projection.repeat(2, 1)
484
+ noise = (
485
+ torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)
486
+ ) ## Have to sample same noise for preferred and rejected
487
+ u = compute_density_for_timestep_sampling(
488
+ weighting_scheme="logit_normal",
489
+ batch_size=bsz // 2,
490
+ logit_mean=0,
491
+ logit_std=1,
492
+ mode_scale=None,
493
+ )
494
+
495
+ indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
496
+ timesteps = self.noise_scheduler_copy.timesteps[indices].to(
497
+ device=latents.device
498
+ )
499
+ timesteps = timesteps.repeat(2)
500
+ sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
501
+
502
+ noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
503
+
504
+ model_pred = self.transformer(
505
+ hidden_states=noisy_model_input,
506
+ encoder_hidden_states=encoder_hidden_states,
507
+ pooled_projections=pooled_projection,
508
+ img_ids=audio_ids,
509
+ txt_ids=txt_ids,
510
+ guidance=None,
511
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep this, but I want to keep the inputs the same for the model for testing)
512
+ timestep=timesteps / 1000,
513
+ return_dict=False,
514
+ )[0]
515
+ target = noise - latents
516
+
517
+ model_losses = F.mse_loss(
518
+ model_pred.float(), target.float(), reduction="none"
519
+ )
520
+ model_losses = model_losses.mean(
521
+ dim=list(range(1, len(model_losses.shape)))
522
+ )
523
+ model_losses_w, model_losses_l = model_losses.chunk(2)
524
+ model_diff = model_losses_w - model_losses_l
525
+ raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
526
+
527
+ with torch.no_grad():
528
+ ref_preds = self.ref_transformer(
529
+ hidden_states=noisy_model_input,
530
+ encoder_hidden_states=encoder_hidden_states,
531
+ pooled_projections=pooled_projection,
532
+ img_ids=audio_ids,
533
+ txt_ids=txt_ids,
534
+ guidance=None,
535
+ timestep=timesteps / 1000,
536
+ return_dict=False,
537
+ )[0]
538
+
539
+ ref_loss = F.mse_loss(
540
+ ref_preds.float(), target.float(), reduction="none"
541
+ )
542
+ ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
543
+
544
+ ref_losses_w, ref_losses_l = ref_loss.chunk(2)
545
+ ref_diff = ref_losses_w - ref_losses_l
546
+ raw_ref_loss = ref_loss.mean()
547
+
548
+ scale_term = -0.5 * self.beta_dpo
549
+ inside_term = scale_term * (model_diff - ref_diff)
550
+ implicit_acc = (
551
+ scale_term * (model_diff - ref_diff) > 0
552
+ ).sum().float() / inside_term.size(0)
553
+ loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
554
+
555
+ ## raw_model_loss, raw_ref_loss and implicit_acc are returned to help analyze DPO behaviour.
556
+ return loss, raw_model_loss, raw_ref_loss, implicit_acc
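+ # Construction sketch (mirrors train.py; the values are the defaults read above):
+ #
+ #   config = {"num_layers": 6, "num_single_layers": 18, "in_channels": 64,
+ #             "attention_head_dim": 128, "num_attention_heads": 8,
+ #             "joint_attention_dim": 1024, "audio_seq_len": 645, "max_duration": 30}
+ #   model = TangoFlux(config)
+ #   latents = model.inference_flow("a dog barking", duration=10, num_inference_steps=50)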
external_models/TangoFlux/tangoflux/train.py ADDED
@@ -0,0 +1,588 @@
1
+ import time
2
+ import argparse
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import yaml
8
+ from pathlib import Path
9
+ import diffusers
10
+ import datasets
11
+ import numpy as np
12
+ import pandas as pd
13
+ import wandb
14
+ import transformers
15
+ import torch
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from datasets import load_dataset
20
+ from torch.utils.data import Dataset, DataLoader
21
+ from tqdm.auto import tqdm
22
+ from transformers import SchedulerType, get_scheduler
23
+ from model import TangoFlux
24
+ from datasets import load_dataset, Audio
25
+ from utils import Text2AudioDataset, read_wav_file, pad_wav
26
+
27
+ from diffusers import AutoencoderOobleck
28
+ import torchaudio
29
+
30
+ logger = get_logger(__name__)
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser(
35
+ description="Rectified flow for text to audio generation task."
36
+ )
37
+
38
+ parser.add_argument(
39
+ "--num_examples",
40
+ type=int,
41
+ default=-1,
42
+ help="How many examples to use for training and validation.",
43
+ )
44
+
45
+ parser.add_argument(
46
+ "--text_column",
47
+ type=str,
48
+ default="captions",
49
+ help="The name of the column in the datasets containing the input texts.",
50
+ )
51
+ parser.add_argument(
52
+ "--audio_column",
53
+ type=str,
54
+ default="location",
55
+ help="The name of the column in the datasets containing the audio paths.",
56
+ )
57
+ parser.add_argument(
58
+ "--adam_beta1",
59
+ type=float,
60
+ default=0.9,
61
+ help="The beta1 parameter for the Adam optimizer.",
62
+ )
63
+ parser.add_argument(
64
+ "--adam_beta2",
65
+ type=float,
66
+ default=0.95,
67
+ help="The beta2 parameter for the Adam optimizer.",
68
+ )
69
+ parser.add_argument(
70
+ "--config",
71
+ type=str,
72
+ default="tangoflux_config.yaml",
73
+ help="Config file defining the model size as well as other hyper parameter.",
74
+ )
75
+ parser.add_argument(
76
+ "--prefix",
77
+ type=str,
78
+ default="",
79
+ help="Add prefix in text prompts.",
80
+ )
81
+
82
+ parser.add_argument(
83
+ "--learning_rate",
84
+ type=float,
85
+ default=3e-5,
86
+ help="Initial learning rate (after the potential warmup period) to use.",
87
+ )
88
+ parser.add_argument(
89
+ "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--max_train_steps",
94
+ type=int,
95
+ default=None,
96
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--lr_scheduler_type",
101
+ type=SchedulerType,
102
+ default="linear",
103
+ help="The scheduler type to use.",
104
+ choices=[
105
+ "linear",
106
+ "cosine",
107
+ "cosine_with_restarts",
108
+ "polynomial",
109
+ "constant",
110
+ "constant_with_warmup",
111
+ ],
112
+ )
113
+ parser.add_argument(
114
+ "--num_warmup_steps",
115
+ type=int,
116
+ default=0,
117
+ help="Number of steps for the warmup in the lr scheduler.",
118
+ )
119
+ parser.add_argument(
120
+ "--adam_epsilon",
121
+ type=float,
122
+ default=1e-08,
123
+ help="Epsilon value for the Adam optimizer",
124
+ )
125
+ parser.add_argument(
126
+ "--adam_weight_decay",
127
+ type=float,
128
+ default=1e-2,
129
+ help="Epsilon value for the Adam optimizer",
130
+ )
131
+ parser.add_argument(
132
+ "--seed", type=int, default=None, help="A seed for reproducible training."
133
+ )
134
+ parser.add_argument(
135
+ "--checkpointing_steps",
136
+ type=str,
137
+ default="best",
138
+ help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
139
+ )
140
+ parser.add_argument(
141
+ "--save_every",
142
+ type=int,
143
+ default=5,
144
+ help="Save model after every how many epochs when checkpointing_steps is set to best.",
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--resume_from_checkpoint",
149
+ type=str,
150
+ default=None,
151
+ help="If the training should continue from a local checkpoint folder.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--load_from_checkpoint",
156
+ type=str,
157
+ default=None,
158
+ help="Whether to continue training from a model weight",
159
+ )
160
+
161
+ args = parser.parse_args()
162
+
163
+ return args
164
+
165
+
166
+ def main():
167
+ args = parse_args()
168
+ accelerator_log_kwargs = {}
169
+
170
+ def load_config(config_path):
171
+ with open(config_path, "r") as file:
172
+ return yaml.safe_load(file)
173
+
174
+ config = load_config(args.config)
175
+
176
+ learning_rate = float(config["training"]["learning_rate"])
177
+ num_train_epochs = int(config["training"]["num_train_epochs"])
178
+ num_warmup_steps = int(config["training"]["num_warmup_steps"])
179
+ per_device_batch_size = int(config["training"]["per_device_batch_size"])
180
+ gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
181
+
182
+ output_dir = config["paths"]["output_dir"]
183
+
184
+ accelerator = Accelerator(
185
+ gradient_accumulation_steps=gradient_accumulation_steps,
186
+ **accelerator_log_kwargs,
187
+ )
188
+
189
+ # Make one log on every process with the configuration for debugging.
190
+ logging.basicConfig(
191
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
192
+ datefmt="%m/%d/%Y %H:%M:%S",
193
+ level=logging.INFO,
194
+ )
195
+ logger.info(accelerator.state, main_process_only=False)
196
+
197
+ datasets.utils.logging.set_verbosity_error()
198
+ diffusers.utils.logging.set_verbosity_error()
199
+ transformers.utils.logging.set_verbosity_error()
200
+
201
+ # If passed along, set the training seed now.
202
+ if args.seed is not None:
203
+ set_seed(args.seed)
204
+
205
+ # Handle output directory creation and wandb tracking
206
+ if accelerator.is_main_process:
207
+ if output_dir is None or output_dir == "":
208
+ output_dir = "saved/" + str(int(time.time()))
209
+
210
+ if not os.path.exists("saved"):
211
+ os.makedirs("saved")
212
+
213
+ os.makedirs(output_dir, exist_ok=True)
214
+
215
+ elif output_dir is not None:
216
+ os.makedirs(output_dir, exist_ok=True)
217
+
218
+ os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
219
+ with open("{}/summary.jsonl".format(output_dir), "a") as f:
220
+ f.write(json.dumps(dict(vars(args))) + "\n\n")
221
+
222
+ accelerator.project_configuration.automatic_checkpoint_naming = False
223
+
224
+ wandb.init(
225
+ project="Text to Audio Flow matching",
226
+ settings=wandb.Settings(_disable_stats=True),
227
+ )
228
+
229
+ accelerator.wait_for_everyone()
230
+
231
+ # Get the datasets
232
+ data_files = {}
233
+ # if args.train_file is not None:
234
+ if config["paths"]["train_file"] != "":
235
+ data_files["train"] = config["paths"]["train_file"]
236
+ # if args.validation_file is not None:
237
+ if config["paths"]["val_file"] != "":
238
+ data_files["validation"] = config["paths"]["val_file"]
239
+ if config["paths"]["test_file"] != "":
240
+ data_files["test"] = config["paths"]["test_file"]
241
+ else:
242
+ data_files["test"] = config["paths"]["val_file"]
243
+
244
+ extension = "json"
245
+ raw_datasets = load_dataset(extension, data_files=data_files)
246
+ text_column, audio_column = args.text_column, args.audio_column
247
+
248
+ model = TangoFlux(config=config["model"])
249
+ vae = AutoencoderOobleck.from_pretrained(
250
+ "stabilityai/stable-audio-open-1.0", subfolder="vae"
251
+ )
252
+
253
+ ## Freeze vae
254
+ for param in vae.parameters():
255
+ vae.requires_grad = False
256
+ vae.eval()
257
+
258
+ ## Freeze text encoder param
259
+ for param in model.text_encoder.parameters():
260
+ param.requires_grad = False
261
+ model.text_encoder.eval()
262
+
263
+ prefix = args.prefix
264
+
265
+ with accelerator.main_process_first():
266
+ train_dataset = Text2AudioDataset(
267
+ raw_datasets["train"],
268
+ prefix,
269
+ text_column,
270
+ audio_column,
271
+ "duration",
272
+ args.num_examples,
273
+ )
274
+ eval_dataset = Text2AudioDataset(
275
+ raw_datasets["validation"],
276
+ prefix,
277
+ text_column,
278
+ audio_column,
279
+ "duration",
280
+ args.num_examples,
281
+ )
282
+ test_dataset = Text2AudioDataset(
283
+ raw_datasets["test"],
284
+ prefix,
285
+ text_column,
286
+ audio_column,
287
+ "duration",
288
+ args.num_examples,
289
+ )
290
+
291
+ accelerator.print(
292
+ "Num instances in train: {}, validation: {}, test: {}".format(
293
+ train_dataset.get_num_instances(),
294
+ eval_dataset.get_num_instances(),
295
+ test_dataset.get_num_instances(),
296
+ )
297
+ )
298
+
299
+ train_dataloader = DataLoader(
300
+ train_dataset,
301
+ shuffle=True,
302
+ batch_size=config["training"]["per_device_batch_size"],
303
+ collate_fn=train_dataset.collate_fn,
304
+ )
305
+ eval_dataloader = DataLoader(
306
+ eval_dataset,
307
+ shuffle=True,
308
+ batch_size=config["training"]["per_device_batch_size"],
309
+ collate_fn=eval_dataset.collate_fn,
310
+ )
311
+ test_dataloader = DataLoader(
312
+ test_dataset,
313
+ shuffle=False,
314
+ batch_size=config["training"]["per_device_batch_size"],
315
+ collate_fn=test_dataset.collate_fn,
316
+ )
317
+
318
+ # Optimizer
319
+
320
+ optimizer_parameters = list(model.transformer.parameters()) + list(
321
+ model.fc.parameters()
322
+ )
323
+ num_trainable_parameters = sum(
324
+ p.numel() for p in model.parameters() if p.requires_grad
325
+ )
326
+ accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
327
+
328
+ if args.load_from_checkpoint:
329
+ from safetensors.torch import load_file
330
+
331
+ w1 = load_file(args.load_from_checkpoint)
332
+ model.load_state_dict(w1, strict=False)
333
+ logger.info("Weights loaded from{}".format(args.load_from_checkpoint))
334
+
335
+ optimizer = torch.optim.AdamW(
336
+ optimizer_parameters,
337
+ lr=learning_rate,
338
+ betas=(args.adam_beta1, args.adam_beta2),
339
+ weight_decay=args.adam_weight_decay,
340
+ eps=args.adam_epsilon,
341
+ )
342
+
343
+ # Scheduler and math around the number of training steps.
344
+ overrode_max_train_steps = False
345
+ num_update_steps_per_epoch = math.ceil(
346
+ len(train_dataloader) / gradient_accumulation_steps
347
+ )
348
+ if args.max_train_steps is None:
349
+ args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
350
+ overrode_max_train_steps = True
351
+
352
+ lr_scheduler = get_scheduler(
353
+ name=args.lr_scheduler_type,
354
+ optimizer=optimizer,
355
+ num_warmup_steps=num_warmup_steps
356
+ * gradient_accumulation_steps
357
+ * accelerator.num_processes,
358
+ num_training_steps=args.max_train_steps * gradient_accumulation_steps,
359
+ )
360
+
361
+ # Prepare everything with our `accelerator`.
362
+ vae, model, optimizer, lr_scheduler = accelerator.prepare(
363
+ vae, model, optimizer, lr_scheduler
364
+ )
365
+
366
+ train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
367
+ train_dataloader, eval_dataloader, test_dataloader
368
+ )
369
+
370
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
371
+ num_update_steps_per_epoch = math.ceil(
372
+ len(train_dataloader) / gradient_accumulation_steps
373
+ )
374
+ if overrode_max_train_steps:
375
+ args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
376
+ # Afterwards we recalculate our number of training epochs
377
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
378
+
379
+ # Figure out how many steps we should save the Accelerator states
380
+ checkpointing_steps = args.checkpointing_steps
381
+ if checkpointing_steps is not None and checkpointing_steps.isdigit():
382
+ checkpointing_steps = int(checkpointing_steps)
383
+
384
+ # We need to initialize the trackers we use, and also store our configuration.
385
+ # The trackers initializes automatically on the main process.
386
+
387
+ # Train!
388
+ total_batch_size = (
389
+ per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
390
+ )
391
+
392
+ logger.info("***** Running training *****")
393
+ logger.info(f" Num examples = {len(train_dataset)}")
394
+ logger.info(f" Num Epochs = {num_train_epochs}")
395
+ logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
396
+ logger.info(
397
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
398
+ )
399
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
400
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
401
+
402
+ # Only show the progress bar once on each machine.
403
+ progress_bar = tqdm(
404
+ range(args.max_train_steps), disable=not accelerator.is_local_main_process
405
+ )
406
+
407
+ completed_steps = 0
408
+ starting_epoch = 0
409
+ # Potentially load in the weights and states from a previous save
410
+ resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
411
+ if resume_from_checkpoint != "":
412
+ accelerator.load_state(resume_from_checkpoint)
413
+ accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")
414
+
415
+ # Duration of the audio clips in seconds
416
+ best_loss = np.inf
417
+ length = config["training"]["max_audio_duration"]
418
+
419
+ for epoch in range(starting_epoch, num_train_epochs):
420
+ model.train()
421
+ total_loss, total_val_loss = 0, 0
422
+ for step, batch in enumerate(train_dataloader):
423
+
424
+ with accelerator.accumulate(model):
425
+ optimizer.zero_grad()
426
+ device = model.device
427
+ text, audios, duration, _ = batch
428
+
429
+ with torch.no_grad():
430
+ audio_list = []
431
+
432
+ for audio_path in audios:
433
+
434
+ wav = read_wav_file(
435
+ audio_path, length
436
+ ) ## Only read the first 30 seconds of audio
437
+ if (
438
+ wav.shape[0] == 1
439
+ ): ## If this audio is mono, we repeat the channel so it become "fake stereo"
440
+ wav = wav.repeat(2, 1)
441
+ audio_list.append(wav)
442
+
443
+ audio_input = torch.stack(audio_list, dim=0)
444
+ audio_input = audio_input.to(device)
445
+ unwrapped_vae = accelerator.unwrap_model(vae)
446
+
447
+ duration = torch.tensor(duration, device=device)
448
+ duration = torch.clamp(
449
+ duration, max=length
450
+ ) ## clamp duration to max audio length
451
+
452
+ audio_latent = unwrapped_vae.encode(
453
+ audio_input
454
+ ).latent_dist.sample()
455
+ audio_latent = audio_latent.transpose(
456
+ 1, 2
457
+ ) ## Tranpose to (bsz, seq_len, channel)
458
+
459
+ loss, _, _, _ = model(audio_latent, text, duration=duration)
460
+ total_loss += loss.detach().float()
461
+ accelerator.backward(loss)
462
+
463
+ if accelerator.sync_gradients:
464
+ progress_bar.update(1)
465
+ completed_steps += 1
466
+
467
+ optimizer.step()
468
+ lr_scheduler.step()
469
+
470
+ if completed_steps % 10 == 0 and accelerator.is_main_process:
471
+
472
+ total_norm = 0.0
473
+ for p in model.parameters():
474
+ if p.grad is not None:
475
+ param_norm = p.grad.data.norm(2)
476
+ total_norm += param_norm.item() ** 2
477
+
478
+ total_norm = total_norm**0.5
479
+ logger.info(
480
+ f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
481
+ )
482
+
483
+ lr = lr_scheduler.get_last_lr()[0]
484
+ result = {
485
+ "train_loss": loss.item(),
486
+ "grad_norm": total_norm,
487
+ "learning_rate": lr,
488
+ }
489
+
490
+ # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
491
+ wandb.log(result, step=completed_steps)
492
+
493
+ # Checks if the accelerator has performed an optimization step behind the scenes
494
+
495
+ if isinstance(checkpointing_steps, int):
496
+ if completed_steps % checkpointing_steps == 0:
497
+ step_dir = os.path.join(output_dir, f"step_{completed_steps}")
+ accelerator.save_state(step_dir)
501
+
502
+ if completed_steps >= args.max_train_steps:
503
+ break
504
+
505
+ model.eval()
506
+ eval_progress_bar = tqdm(
507
+ range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
508
+ )
509
+ for step, batch in enumerate(eval_dataloader):
510
+ with torch.no_grad():
511
+ device = model.device
512
+ text, audios, duration, _ = batch
513
+
514
+ audio_list = []
515
+ for audio_path in audios:
516
+
517
+ wav = read_wav_file(
518
+ audio_path, length
519
+ ) ## make sure none of audio exceed 30 sec
520
+ if (
521
+ wav.shape[0] == 1
522
+ ): ## If this audio is mono, we repeat the channel so it become "fake stereo"
523
+ wav = wav.repeat(2, 1)
524
+ audio_list.append(wav)
525
+
526
+ audio_input = torch.stack(audio_list, dim=0)
527
+ audio_input = audio_input.to(device)
528
+ duration = torch.tensor(duration, device=device)
529
+ unwrapped_vae = accelerator.unwrap_model(vae)
530
+ audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
531
+ audio_latent = audio_latent.transpose(
532
+ 1, 2
533
+ ) ## Tranpose to (bsz, seq_len, channel)
534
+
535
+ val_loss, _, _, _ = model(audio_latent, text, duration=duration)
536
+
537
+ total_val_loss += val_loss.detach().float()
538
+ eval_progress_bar.update(1)
539
+
540
+ if accelerator.is_main_process:
541
+
542
+ result = {}
543
+ result["epoch"] = float(epoch + 1)
544
+
545
+ result["epoch/train_loss"] = round(
546
+ total_loss.item() / len(train_dataloader), 4
547
+ )
548
+ result["epoch/val_loss"] = round(
549
+ total_val_loss.item() / len(eval_dataloader), 4
550
+ )
551
+
552
+ wandb.log(result, step=completed_steps)
553
+
554
+ result_string = "Epoch: {}, Loss Train: {}, Val: {}\n".format(
555
+ epoch, result["epoch/train_loss"], result["epoch/val_loss"]
556
+ )
557
+
558
+ accelerator.print(result_string)
559
+
560
+ with open("{}/summary.jsonl".format(output_dir), "a") as f:
561
+ f.write(json.dumps(result) + "\n\n")
562
+
563
+ logger.info(result)
564
+
565
+ if result["epoch/val_loss"] < best_loss:
566
+ best_loss = result["epoch/val_loss"]
567
+ save_checkpoint = True
568
+ else:
569
+ save_checkpoint = False
570
+
571
+ accelerator.wait_for_everyone()
572
+ if accelerator.is_main_process and args.checkpointing_steps == "best":
573
+ if save_checkpoint:
574
+ accelerator.save_state("{}/{}".format(output_dir, "best"))
575
+
576
+ if (epoch + 1) % args.save_every == 0:
577
+ accelerator.save_state(
578
+ "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
579
+ )
580
+
581
+ if accelerator.is_main_process and args.checkpointing_steps == "epoch":
582
+ accelerator.save_state(
583
+ "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
584
+ )
585
+
586
+
587
+ if __name__ == "__main__":
588
+ main()
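+ # Warm-start sketch (mirrors the --load_from_checkpoint branch above; the path is
+ # hypothetical):
+ #
+ #   from safetensors.torch import load_file
+ #   state = load_file("outputs/best/model.safetensors")
+ #   missing, unexpected = model.load_state_dict(state, strict=False)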
external_models/TangoFlux/tangoflux/train_dpo.py ADDED
@@ -0,0 +1,608 @@
1
+ import time
2
+ import argparse
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import yaml
8
+
9
+ # from tqdm import tqdm
10
+ import copy
11
+ from pathlib import Path
12
+ import diffusers
13
+ import datasets
14
+ import numpy as np
15
+ import pandas as pd
16
+ import wandb
17
+ import transformers
18
+ import torch
19
+ from accelerate import Accelerator
20
+ from accelerate.logging import get_logger
21
+ from accelerate.utils import set_seed
22
+ from datasets import load_dataset
23
+ from torch.utils.data import Dataset, DataLoader
24
+ from tqdm.auto import tqdm
25
+ from transformers import SchedulerType, get_scheduler
26
+ from tangoflux.model import TangoFlux
27
+ from datasets import load_dataset, Audio
28
+ from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset
29
+
30
+ from diffusers import AutoencoderOobleck
31
+ import torchaudio
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ def parse_args():
37
+ parser = argparse.ArgumentParser(
38
+ description="Rectified flow for text to audio generation task."
39
+ )
40
+
41
+ parser.add_argument(
42
+ "--num_examples",
43
+ type=int,
44
+ default=-1,
45
+ help="How many examples to use for training and validation.",
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--text_column",
50
+ type=str,
51
+ default="captions",
52
+ help="The name of the column in the datasets containing the input texts.",
53
+ )
54
+ parser.add_argument(
55
+ "--audio_column",
56
+ type=str,
57
+ default="location",
58
+ help="The name of the column in the datasets containing the audio paths.",
59
+ )
60
+ parser.add_argument(
61
+ "--adam_beta1",
62
+ type=float,
63
+ default=0.9,
64
+ help="The beta1 parameter for the Adam optimizer.",
65
+ )
66
+ parser.add_argument(
67
+ "--adam_beta2",
68
+ type=float,
69
+ default=0.95,
70
+ help="The beta2 parameter for the Adam optimizer.",
71
+ )
72
+ parser.add_argument(
73
+ "--config",
74
+ type=str,
75
+ default="tangoflux_config.yaml",
76
+ help="Config file defining the model size.",
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--max_train_steps",
85
+ type=int,
86
+ default=None,
87
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--lr_scheduler_type",
92
+ type=SchedulerType,
93
+ default="linear",
94
+ help="The scheduler type to use.",
95
+ choices=[
96
+ "linear",
97
+ "cosine",
98
+ "cosine_with_restarts",
99
+ "polynomial",
100
+ "constant",
101
+ "constant_with_warmup",
102
+ ],
103
+ )
104
+
105
+ parser.add_argument(
106
+ "--adam_epsilon",
107
+ type=float,
108
+ default=1e-08,
109
+ help="Epsilon value for the Adam optimizer",
110
+ )
111
+ parser.add_argument(
112
+ "--adam_weight_decay",
113
+ type=float,
114
+ default=1e-2,
115
+ help="Epsilon value for the Adam optimizer",
116
+ )
117
+ parser.add_argument(
118
+ "--seed", type=int, default=None, help="A seed for reproducible training."
119
+ )
120
+ parser.add_argument(
121
+ "--checkpointing_steps",
122
+ type=str,
123
+ default="best",
124
+ help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
125
+ )
126
+ parser.add_argument(
127
+ "--save_every",
128
+ type=int,
129
+ default=5,
130
+ help="Save model after every how many epochs when checkpointing_steps is set to best.",
131
+ )
132
+
133
+
134
+
135
+ parser.add_argument(
136
+ "--load_from_checkpoint",
137
+ type=str,
138
+ default=None,
139
+ help="Whether to continue training from a model weight",
140
+ )
141
+
142
+
143
+ args = parser.parse_args()
144
+
145
+ # Sanity checks
146
+ # if args.train_file is None and args.validation_file is None:
147
+ # raise ValueError("Need a training/validation file.")
148
+ # else:
149
+ # if args.train_file is not None:
150
+ # extension = args.train_file.split(".")[-1]
151
+ # assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
152
+ # if args.validation_file is not None:
153
+ # extension = args.validation_file.split(".")[-1]
154
+ # assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
155
+
156
+ return args
157
+
158
+
159
+ def main():
160
+ args = parse_args()
161
+ accelerator_log_kwargs = {}
162
+
163
+ def load_config(config_path):
164
+ with open(config_path, "r") as file:
165
+ return yaml.safe_load(file)
166
+
167
+ config = load_config(args.config)
168
+
169
+ learning_rate = float(config["training"]["learning_rate"])
170
+ num_train_epochs = int(config["training"]["num_train_epochs"])
171
+ num_warmup_steps = int(config["training"]["num_warmup_steps"])
172
+ per_device_batch_size = int(config["training"]["per_device_batch_size"])
173
+ gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])
174
+
175
+ output_dir = config["paths"]["output_dir"]
176
+
177
+ accelerator = Accelerator(
178
+ gradient_accumulation_steps=gradient_accumulation_steps,
179
+ **accelerator_log_kwargs,
180
+ )
181
+
182
+ # Make one log on every process with the configuration for debugging.
183
+ logging.basicConfig(
184
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
185
+ datefmt="%m/%d/%Y %H:%M:%S",
186
+ level=logging.INFO,
187
+ )
188
+ logger.info(accelerator.state, main_process_only=False)
189
+
190
+ datasets.utils.logging.set_verbosity_error()
191
+ diffusers.utils.logging.set_verbosity_error()
192
+ transformers.utils.logging.set_verbosity_error()
193
+
194
+ # If passed along, set the training seed now.
195
+ if args.seed is not None:
196
+ set_seed(args.seed)
197
+
198
+ # Handle output directory creation and wandb tracking
199
+ if accelerator.is_main_process:
200
+ if output_dir is None or output_dir == "":
201
+ output_dir = "saved/" + str(int(time.time()))
202
+
203
+ if not os.path.exists("saved"):
204
+ os.makedirs("saved")
205
+
206
+ os.makedirs(output_dir, exist_ok=True)
207
+
208
+ elif output_dir is not None:
209
+ os.makedirs(output_dir, exist_ok=True)
210
+
211
+ os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
212
+ with open("{}/summary.jsonl".format(output_dir), "a") as f:
213
+ f.write(json.dumps(dict(vars(args))) + "\n\n")
214
+
215
+ accelerator.project_configuration.automatic_checkpoint_naming = False
216
+
217
+ wandb.init(
218
+ project="Text to Audio Flow matching DPO",
219
+ settings=wandb.Settings(_disable_stats=True),
220
+ )
221
+
222
+ accelerator.wait_for_everyone()
223
+
224
+ # Get the datasets
225
+ data_files = {}
226
+ # if args.train_file is not None:
227
+ if config["paths"]["train_file"] != "":
228
+ data_files["train"] = config["paths"]["train_file"]
229
+ # if args.validation_file is not None:
230
+ if config["paths"]["val_file"] != "":
231
+ data_files["validation"] = config["paths"]["val_file"]
232
+ if config["paths"]["test_file"] != "":
233
+ data_files["test"] = config["paths"]["test_file"]
234
+ else:
235
+ data_files["test"] = config["paths"]["val_file"]
236
+
237
+ extension = "json"
238
+ train_dataset = load_dataset(extension, data_files=data_files["train"])
239
+ data_files.pop("train")
240
+ raw_datasets = load_dataset(extension, data_files=data_files)
241
+ text_column, audio_column = args.text_column, args.audio_column
242
+
243
+ model = TangoFlux(config=config["model"], initialize_reference_model=True)
244
+ vae = AutoencoderOobleck.from_pretrained(
245
+ "stabilityai/stable-audio-open-1.0", subfolder="vae"
246
+ )
247
+
248
+ ## Freeze vae
249
+ for param in vae.parameters():
250
+ param.requires_grad = False
251
+ vae.eval()
252
+
253
+ ## Freeze text encoder param
254
+ for param in model.text_encoder.parameters():
255
+ param.requires_grad = False
256
+ model.text_encoder.eval()
257
+
258
+ prefix = ""
259
+
260
+ with accelerator.main_process_first():
261
+ train_dataset = DPOText2AudioDataset(
262
+ train_dataset["train"],
263
+ prefix,
264
+ text_column,
265
+ "chosen",
266
+ "reject",
267
+ "duration",
268
+ args.num_examples,
269
+ )
270
+ eval_dataset = Text2AudioDataset(
271
+ raw_datasets["validation"],
272
+ prefix,
273
+ text_column,
274
+ audio_column,
275
+ "duration",
276
+ args.num_examples,
277
+ )
278
+ test_dataset = Text2AudioDataset(
279
+ raw_datasets["test"],
280
+ prefix,
281
+ text_column,
282
+ audio_column,
283
+ "duration",
284
+ args.num_examples,
285
+ )
286
+
287
+ accelerator.print(
288
+ "Num instances in train: {}, validation: {}, test: {}".format(
289
+ train_dataset.get_num_instances(),
290
+ eval_dataset.get_num_instances(),
291
+ test_dataset.get_num_instances(),
292
+ )
293
+ )
294
+
295
+ train_dataloader = DataLoader(
296
+ train_dataset,
297
+ shuffle=True,
298
+ batch_size=config["training"]["per_device_batch_size"],
299
+ collate_fn=train_dataset.collate_fn,
300
+ )
301
+ eval_dataloader = DataLoader(
302
+ eval_dataset,
303
+ shuffle=True,
304
+ batch_size=config["training"]["per_device_batch_size"],
305
+ collate_fn=eval_dataset.collate_fn,
306
+ )
307
+ test_dataloader = DataLoader(
308
+ test_dataset,
309
+ shuffle=False,
310
+ batch_size=config["training"]["per_device_batch_size"],
311
+ collate_fn=test_dataset.collate_fn,
312
+ )
313
+
314
+ # Optimizer
315
+
316
+ optimizer_parameters = list(model.transformer.parameters()) + list(
317
+ model.fc.parameters()
318
+ )
319
+ num_trainable_parameters = sum(
320
+ p.numel() for p in model.parameters() if p.requires_grad
321
+ )
322
+ accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))
323
+
324
+ if args.load_from_checkpoint:
325
+ from safetensors.torch import load_file
326
+
327
+ w1 = load_file(args.load_from_checkpoint)
328
+ model.load_state_dict(w1, strict=False)
329
+ logger.info("Weights loaded from{}".format(args.load_from_checkpoint))
330
+
331
+ import copy
332
+
333
+ model.ref_transformer = copy.deepcopy(model.transformer)
334
+ model.ref_transformer.requires_grad_(False)
335
+ model.ref_transformer.eval()
336
+ for param in model.ref_transformer.parameters():
337
+ param.requires_grad = False
338
+
339
+
340
+
341
+
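+ # DPO sketch (illustrative): the frozen copy above supplies the reference losses
+ # used in TangoFlux.forward(..., sft=False). The preference objective computed there is
+ #   loss = -logsigmoid(-0.5 * beta_dpo * ((model_w - model_l) - (ref_w - ref_l))).mean()
+ #          + model_w.mean()
+ # where *_w / *_l are the flow-matching MSEs on the chosen / rejected latents.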
342
+ optimizer = torch.optim.AdamW(
343
+ optimizer_parameters,
344
+ lr=learning_rate,
345
+ betas=(args.adam_beta1, args.adam_beta2),
346
+ weight_decay=args.adam_weight_decay,
347
+ eps=args.adam_epsilon,
348
+ )
349
+
350
+ # Scheduler and math around the number of training steps.
351
+ overrode_max_train_steps = False
352
+ num_update_steps_per_epoch = math.ceil(
353
+ len(train_dataloader) / gradient_accumulation_steps
354
+ )
355
+ if args.max_train_steps is None:
356
+ args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
357
+ overrode_max_train_steps = True
358
+
359
+ lr_scheduler = get_scheduler(
360
+ name=args.lr_scheduler_type,
361
+ optimizer=optimizer,
362
+ num_warmup_steps=num_warmup_steps
363
+ * gradient_accumulation_steps
364
+ * accelerator.num_processes,
365
+ num_training_steps=args.max_train_steps * gradient_accumulation_steps,
366
+ )
367
+
368
+ # Prepare everything with our `accelerator`.
369
+ vae, model, optimizer, lr_scheduler = accelerator.prepare(
370
+ vae, model, optimizer, lr_scheduler
371
+ )
372
+
373
+ train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
374
+ train_dataloader, eval_dataloader, test_dataloader
375
+ )
376
+
377
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
378
+ num_update_steps_per_epoch = math.ceil(
379
+ len(train_dataloader) / gradient_accumulation_steps
380
+ )
381
+ if overrode_max_train_steps:
382
+ args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
383
+ # Afterwards we recalculate our number of training epochs
384
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
385
+
386
+ # Figure out how many steps we should save the Accelerator states
387
+ checkpointing_steps = args.checkpointing_steps
388
+ if checkpointing_steps is not None and checkpointing_steps.isdigit():
389
+ checkpointing_steps = int(checkpointing_steps)
390
+
391
+
392
+
393
+ # Train!
394
+ total_batch_size = (
395
+ per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
396
+ )
397
+
398
+ logger.info("***** Running training *****")
399
+ logger.info(f" Num examples = {len(train_dataset)}")
400
+ logger.info(f" Num Epochs = {num_train_epochs}")
401
+ logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
402
+ logger.info(
403
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
404
+ )
405
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
406
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
407
+
408
+ # Only show the progress bar once on each machine.
409
+ progress_bar = tqdm(
410
+ range(args.max_train_steps), disable=not accelerator.is_local_main_process
411
+ )
412
+
413
+ completed_steps = 0
414
+ starting_epoch = 0
415
+ # Potentially load in the weights and states from a previous save
416
+ resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
417
+ if resume_from_checkpoint != "":
418
+ accelerator.load_state(resume_from_checkpoint)
419
+ accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")
420
+
421
+ # Duration of the audio clips in seconds
422
+ best_loss = np.inf
423
+ length = config["training"]["max_audio_duration"]
424
+
425
+ for epoch in range(starting_epoch, num_train_epochs):
426
+ model.train()
427
+ total_loss, total_val_loss = 0, 0
428
+
429
+ for step, batch in enumerate(train_dataloader):
430
+ optimizer.zero_grad()
431
+ with accelerator.accumulate(model):
432
+ optimizer.zero_grad()
433
+ device = accelerator.device
434
+ text, audio_w, audio_l, duration, _ = batch
435
+
436
+ with torch.no_grad():
437
+ audio_list_w = []
438
+ audio_list_l = []
439
+ for audio_path in audio_w:
440
+
441
+ wav = read_wav_file(
442
+ audio_path, length
443
+ ) ## Only read the first 30 seconds of audio
444
+ if (
445
+ wav.shape[0] == 1
446
+ ): ## If this audio is mono, we repeat the channel so it becomes "fake stereo"
447
+ wav = wav.repeat(2, 1)
448
+ audio_list_w.append(wav)
449
+
450
+ for audio_path in audio_l:
451
+ wav = read_wav_file(
452
+ audio_path, length
453
+ ) ## Only read the first 30 seconds of audio
454
+ if (
455
+ wav.shape[0] == 1
456
+ ): ## If this audio is mono, we repeat the channel so it become "fake stereo"
457
+ wav = wav.repeat(2, 1)
458
+ audio_list_l.append(wav)
459
+
460
+ audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
461
+ audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
462
+ # audio_input_ = audio_input.to(device)
463
+ unwrapped_vae = accelerator.unwrap_model(vae)
464
+
465
+ duration = torch.tensor(duration, device=device)
466
+ duration = torch.clamp(
467
+ duration, max=length
468
+ ) ## max duration is 30 sec
469
+
470
+ audio_latent_w = unwrapped_vae.encode(
471
+ audio_input_w
472
+ ).latent_dist.sample()
473
+ audio_latent_l = unwrapped_vae.encode(
474
+ audio_input_l
475
+ ).latent_dist.sample()
476
+ audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
477
+ audio_latent = audio_latent.transpose(
478
+ 1, 2
479
+ ) ## Tranpose to (bsz, seq_len, channel)
480
+
481
+ loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
482
+ audio_latent, text, duration=duration, sft=False
483
+ )
484
+
485
+                 total_loss += loss.detach().float()
+                 accelerator.backward(loss)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 if accelerator.sync_gradients:
+                     # accelerator.clip_grad_value_(model.parameters(), 1.0)
+                     progress_bar.update(1)
+                     completed_steps += 1
+
+                     if completed_steps % 10 == 0 and accelerator.is_main_process:
+                         total_norm = 0.0
+                         for p in model.parameters():
+                             if p.grad is not None:
+                                 param_norm = p.grad.data.norm(2)
+                                 total_norm += param_norm.item() ** 2
+                         total_norm = total_norm**0.5
+                         logger.info(
+                             f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
+                         )
+
+                         lr = lr_scheduler.get_last_lr()[0]
+
+                         result = {
+                             "train_loss": loss.item(),
+                             "grad_norm": total_norm,
+                             "learning_rate": lr,
+                             "raw_model_loss": raw_model_loss,
+                             "raw_ref_loss": raw_ref_loss,
+                             "implicit_acc": implicit_acc,
+                         }
+
+                         # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
+                         wandb.log(result, step=completed_steps)
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if isinstance(checkpointing_steps, int):
+                 if completed_steps % checkpointing_steps == 0:
+                     output_dir = f"step_{completed_steps}"
+                     if output_dir is not None:
+                         output_dir = os.path.join(output_dir, output_dir)
+                     accelerator.save_state(output_dir)
+
+             if completed_steps >= args.max_train_steps:
+                 break
+
+         model.eval()
+         eval_progress_bar = tqdm(
+             range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
+         )
+         for step, batch in enumerate(eval_dataloader):
+             # Validation runs in SFT mode on single-audio batches, so only no_grad is needed here
+             with torch.no_grad():
+                 device = model.device
+                 text, audios, duration, _ = batch
+
+                 audio_list = []
+                 for audio_path in audios:
+                     wav = read_wav_file(
+                         audio_path, length
+                     )  ## Only read the first `length` seconds of audio
+                     if (
+                         wav.shape[0] == 1
+                     ):  ## If this audio is mono, repeat the channel so it becomes "fake stereo"
+                         wav = wav.repeat(2, 1)
+                     audio_list.append(wav)
+
+                 audio_input = torch.stack(audio_list, dim=0)
+                 audio_input = audio_input.to(device)
+                 duration = torch.tensor(duration, device=device)
+                 unwrapped_vae = accelerator.unwrap_model(vae)
+                 audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
+                 audio_latent = audio_latent.transpose(
+                     1, 2
+                 )  ## Transpose to (bsz, seq_len, channel)
+
+                 val_loss, _, _, _ = model(
+                     audio_latent, text, duration=duration, sft=True
+                 )
+
+                 total_val_loss += val_loss.detach().float()
+                 eval_progress_bar.update(1)
+
+         if accelerator.is_main_process:
+             result = {}
+             result["epoch"] = float(epoch + 1)
+
+             result["epoch/train_loss"] = round(
+                 total_loss.item() / len(train_dataloader), 4
+             )
+             result["epoch/val_loss"] = round(
+                 total_val_loss.item() / len(eval_dataloader), 4
+             )
+
+             wandb.log(result, step=completed_steps)
+
+             with open("{}/summary.jsonl".format(output_dir), "a") as f:
+                 f.write(json.dumps(result) + "\n\n")
+
+             logger.info(result)
+
+         save_checkpoint = True
+         accelerator.wait_for_everyone()
+         if accelerator.is_main_process and args.checkpointing_steps == "best":
+             if save_checkpoint:
+                 accelerator.save_state("{}/{}".format(output_dir, "best"))
+
+             if (epoch + 1) % args.save_every == 0:
+                 accelerator.save_state(
+                     "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
+                 )
+
+         if accelerator.is_main_process and args.checkpointing_steps == "epoch":
+             accelerator.save_state(
+                 "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
+             )
+
+
+ if __name__ == "__main__":
+     main()
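For orientation, the loop above feeds the model one batch whose first half is the winner latents and whose second half is the loser latents. The following is a minimal, self-contained sketch of a Diffusion-DPO style preference loss over such a batch; the beta value and the per-example error definition are illustrative assumptions, and the actual objective is whatever tangoflux/model.py implements.

import torch
import torch.nn.functional as F

def dpo_preference_loss(model_err, ref_err, beta=2000.0):
    # model_err / ref_err: per-example denoising errors of shape (2 * bsz,),
    # ordered winners first, losers second, mirroring the torch.cat above.
    model_w, model_l = model_err.chunk(2)
    ref_w, ref_l = ref_err.chunk(2)
    diff = (model_w - ref_w) - (model_l - ref_l)
    loss = -F.logsigmoid(-beta * diff).mean()
    implicit_acc = (diff < 0).float().mean()  # fraction of pairs the policy already ranks correctly
    return loss, implicit_acc

# Toy usage with random errors for four preference pairs
loss, acc = dpo_preference_loss(torch.rand(8), torch.rand(8))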
external_models/TangoFlux/tangoflux/utils.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ import pandas as pd
+
+ import torchaudio
+ import random
+ import itertools
+
+
+ def normalize_wav(waveform):
+     waveform = waveform - torch.mean(waveform)
+     waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
+     return waveform * 0.5
+
+
+ def pad_wav(waveform, segment_length):
+     waveform_length = len(waveform)
+
+     if segment_length is None or waveform_length == segment_length:
+         return waveform
+     elif waveform_length > segment_length:
+         return waveform[:segment_length]
+     else:
+         padded_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
+         waveform = torch.cat([waveform, padded_wav])
+         return waveform
+
+
+ def read_wav_file(filename, duration_sec):
+     info = torchaudio.info(filename)
+     sample_rate = info.sample_rate
+
+     # Calculate the number of frames corresponding to the desired duration
+     num_frames = int(sample_rate * duration_sec)
+
+     waveform, sr = torchaudio.load(filename, num_frames=num_frames)  # Faster!!!
+
+     if waveform.shape[0] == 2:  ## Stereo audio
+         resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
+         resampled_waveform = resampler(waveform)
+         # print(resampled_waveform.shape)
+         padded_left = pad_wav(
+             resampled_waveform[0], int(44100 * duration_sec)
+         )  ## We pad left and right separately
+         padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))
+
+         return torch.stack([padded_left, padded_right])
+     else:
+         waveform = torchaudio.functional.resample(
+             waveform, orig_freq=sr, new_freq=44100
+         )[0]
+         waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)
+
+         return waveform
+
+
+ class DPOText2AudioDataset(Dataset):
+     def __init__(
+         self,
+         dataset,
+         prefix,
+         text_column,
+         audio_w_column,
+         audio_l_column,
+         duration,
+         num_examples=-1,
+     ):
+
+         inputs = list(dataset[text_column])
+         self.inputs = [prefix + inp for inp in inputs]
+         self.audios_w = list(dataset[audio_w_column])
+         self.audios_l = list(dataset[audio_l_column])
+         self.durations = list(dataset[duration])
+         self.indices = list(range(len(self.inputs)))
+
+         self.mapper = {}
+         for index, audio_w, audio_l, duration, text in zip(
+             self.indices, self.audios_w, self.audios_l, self.durations, inputs
+         ):
+             self.mapper[index] = [audio_w, audio_l, duration, text]
+
+         if num_examples != -1:
+             self.inputs, self.audios_w, self.audios_l, self.durations = (
+                 self.inputs[:num_examples],
+                 self.audios_w[:num_examples],
+                 self.audios_l[:num_examples],
+                 self.durations[:num_examples],
+             )
+             self.indices = self.indices[:num_examples]
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def get_num_instances(self):
+         return len(self.inputs)
+
+     def __getitem__(self, index):
+         s1, s2, s3, s4, s5 = (
+             self.inputs[index],
+             self.audios_w[index],
+             self.audios_l[index],
+             self.durations[index],
+             self.indices[index],
+         )
+         return s1, s2, s3, s4, s5
+
+     def collate_fn(self, data):
+         dat = pd.DataFrame(data)
+         return [dat[i].tolist() for i in dat]
+
+
+ class Text2AudioDataset(Dataset):
+     def __init__(
+         self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
+     ):
+
+         inputs = list(dataset[text_column])
+         self.inputs = [prefix + inp for inp in inputs]
+         self.audios = list(dataset[audio_column])
+         self.durations = list(dataset[duration])
+         self.indices = list(range(len(self.inputs)))
+
+         self.mapper = {}
+         for index, audio, duration, text in zip(
+             self.indices, self.audios, self.durations, inputs
+         ):
+             self.mapper[index] = [audio, text, duration]
+
+         if num_examples != -1:
+             self.inputs, self.audios, self.durations = (
+                 self.inputs[:num_examples],
+                 self.audios[:num_examples],
+                 self.durations[:num_examples],
+             )
+             self.indices = self.indices[:num_examples]
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def get_num_instances(self):
+         return len(self.inputs)
+
+     def __getitem__(self, index):
+         s1, s2, s3, s4 = (
+             self.inputs[index],
+             self.audios[index],
+             self.durations[index],
+             self.indices[index],
+         )
+         return s1, s2, s3, s4
+
+     def collate_fn(self, data):
+         dat = pd.DataFrame(data)
+         return [dat[i].tolist() for i in dat]
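For reference, a minimal usage sketch of the helpers above; the data file and column names ("captions", "chosen", "reject", "duration") are assumptions for illustration, since the real names are supplied by the training config rather than by this file.

from datasets import load_dataset
from torch.utils.data import DataLoader
from tangoflux.utils import DPOText2AudioDataset, read_wav_file

pairs = load_dataset("json", data_files="crpo_pairs.json")["train"]  # hypothetical preference-pair file
ds = DPOText2AudioDataset(
    dataset=pairs,
    prefix="",
    text_column="captions",
    audio_w_column="chosen",
    audio_l_column="reject",
    duration="duration",
)
loader = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=ds.collate_fn)

text, audio_w, audio_l, duration, idx = next(iter(loader))
wav = read_wav_file(audio_w[0], duration_sec=10)  # (2, 441000) for a stereo file, (1, 441000) for mono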
external_models/TangoFlux/train.sh ADDED
@@ -0,0 +1,2 @@
+
+ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file='configs/accelerator_config.yaml' tangoflux/train.py --checkpointing_steps="best" --save_every=5 --config='configs/tangoflux_config.yaml'
external_models/depth-fm/.gitignore ADDED
@@ -0,0 +1,5 @@
+ *__pycache__*
+ sandbox
+ *.ckpt
+ *-depth.png
+ evaluation