Zeyue7 committed
Commit 8ab1cf8 · 1 Parent(s): 83cf111
Files changed (43)
  1. README copy.md +13 -0
  2. app.py +399 -0
  3. packages.txt +2 -0
  4. requirements.txt +40 -0
  5. stable_audio_tools/__init__.py +2 -0
  6. stable_audio_tools/data/__init__.py +0 -0
  7. stable_audio_tools/data/dataset.py +876 -0
  8. stable_audio_tools/data/utils.py +204 -0
  9. stable_audio_tools/inference/__init__.py +0 -0
  10. stable_audio_tools/inference/generation.py +275 -0
  11. stable_audio_tools/inference/sampling.py +235 -0
  12. stable_audio_tools/inference/utils.py +35 -0
  13. stable_audio_tools/interface/__init__.py +0 -0
  14. stable_audio_tools/interface/gradio.py +495 -0
  15. stable_audio_tools/models/__init__.py +1 -0
  16. stable_audio_tools/models/adp.py +1588 -0
  17. stable_audio_tools/models/autoencoders.py +794 -0
  18. stable_audio_tools/models/blocks.py +321 -0
  19. stable_audio_tools/models/bottleneck.py +355 -0
  20. stable_audio_tools/models/codebook_patterns.py +545 -0
  21. stable_audio_tools/models/conditioners.py +711 -0
  22. stable_audio_tools/models/diffusion.py +704 -0
  23. stable_audio_tools/models/discriminators.py +546 -0
  24. stable_audio_tools/models/dit.py +379 -0
  25. stable_audio_tools/models/factory.py +153 -0
  26. stable_audio_tools/models/lm.py +542 -0
  27. stable_audio_tools/models/local_attention.py +278 -0
  28. stable_audio_tools/models/pqmf.py +393 -0
  29. stable_audio_tools/models/pretrained.py +33 -0
  30. stable_audio_tools/models/pretransforms.py +258 -0
  31. stable_audio_tools/models/temptransformer.py +190 -0
  32. stable_audio_tools/models/transformer.py +812 -0
  33. stable_audio_tools/models/utils.py +92 -0
  34. stable_audio_tools/models/wavelets.py +82 -0
  35. stable_audio_tools/training/__init__.py +1 -0
  36. stable_audio_tools/training/autoencoders.py +476 -0
  37. stable_audio_tools/training/diffusion.py +1656 -0
  38. stable_audio_tools/training/factory.py +240 -0
  39. stable_audio_tools/training/lm.py +267 -0
  40. stable_audio_tools/training/losses/__init__.py +1 -0
  41. stable_audio_tools/training/losses/auraloss.py +607 -0
  42. stable_audio_tools/training/losses/losses.py +101 -0
  43. stable_audio_tools/training/utils.py +111 -0
README copy.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: AudioX
3
+ emoji: 🎧
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.25.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-4.0
11
+ ---
12
+
13
+
app.py ADDED
@@ -0,0 +1,399 @@
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import torchaudio
5
+ import os
6
+ from einops import rearrange
7
+ import gc
8
+ from aeiou.viz import audio_spectrogram_image  # used by the preview-spectrogram callback below
14
+ from stable_audio_tools import get_pretrained_model
15
+ from stable_audio_tools.inference.generation import generate_diffusion_cond
16
+ from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
17
+ import stat
18
+ import platform
19
+ import logging
20
+ from transformers import logging as transformers_logging
21
+
22
+ transformers_logging.set_verbosity_error()
23
+ logging.getLogger("transformers").setLevel(logging.ERROR)
24
+
25
+ model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
26
+ sample_rate = model_config["sample_rate"]
27
+ sample_size = model_config["sample_size"]
28
+
29
+ TEMP_DIR = "tmp/gradio"
30
+ os.makedirs(TEMP_DIR, exist_ok=True)
31
+ os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
32
+
33
+ VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
34
+ os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
35
+ os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
36
+
37
+
38
+
39
+ @spaces.GPU(duration=10)
40
+ def generate_cond(
41
+ prompt,
42
+ negative_prompt=None,
43
+ video_file=None,
44
+ audio_prompt_file=None,
45
+ audio_prompt_path=None,
46
+ seconds_start=0,
47
+ seconds_total=10,
48
+ cfg_scale=7.0,
49
+ steps=100,
50
+ preview_every=0,
51
+ seed=-1,
52
+ sampler_type="dpmpp-3m-sde",
53
+ sigma_min=0.03,
54
+ sigma_max=500,
55
+ cfg_rescale=0.0,
56
+ use_init=False,
57
+ init_audio=None,
58
+ init_noise_level=0.1,
59
+ mask_cropfrom=None,
60
+ mask_pastefrom=None,
61
+ mask_pasteto=None,
62
+ mask_maskstart=None,
63
+ mask_maskend=None,
64
+ mask_softnessL=None,
65
+ mask_softnessR=None,
66
+ mask_marination=None,
67
+ batch_size=1
68
+ ):
69
+ if torch.cuda.is_available():
70
+ torch.cuda.empty_cache()
71
+ gc.collect()
72
+ print(f"Prompt: {prompt}")
73
+ preview_images = []
74
+ if preview_every == 0:
75
+ preview_every = None
76
+
77
+ try:
78
+ has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
79
+ except Exception:
80
+ has_mps = False
81
+ if has_mps:
82
+ device = torch.device("mps")
83
+ elif torch.cuda.is_available():
84
+ device = torch.device("cuda")
85
+ else:
86
+ device = torch.device("cpu")
87
+
88
+ global model
89
+ model = model.to(device)
90
+
91
+ target_fps = model_config.get("video_fps", 5)
92
+ model_type = model_config.get("model_type", "diffusion_cond")
93
+
94
+ if video_file is not None:
95
+ actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
96
+ else:
97
+ actual_video_path = None
98
+
99
+ if audio_prompt_file is not None:
100
+ audio_path = audio_prompt_file.name
101
+ elif audio_prompt_path:
102
+ audio_path = audio_prompt_path.strip()
103
+ else:
104
+ audio_path = None
105
+
106
+ Video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
107
+ audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
108
+
109
+ audio_tensor = audio_tensor.to(device)
110
+ seconds_input = sample_size / sample_rate
111
+
112
+ if not prompt:
113
+ prompt = ""
114
+
115
+ conditioning = [{
116
+ "video_prompt": [Video_tensors.unsqueeze(0)],
117
+ "text_prompt": prompt,
118
+ "audio_prompt": audio_tensor.unsqueeze(0),
119
+ "seconds_start": seconds_start,
120
+ "seconds_total": seconds_input
121
+ }]
122
+ if negative_prompt:
123
+ negative_conditioning = [{
124
+ "video_prompt": [Video_tensors.unsqueeze(0)],
125
+ "text_prompt": negative_prompt,
126
+ "audio_prompt": audio_tensor.unsqueeze(0),
127
+ "seconds_start": seconds_start,
128
+ "seconds_total": seconds_total
129
+ }] * 1
130
+ else:
131
+ negative_conditioning = None
132
+
133
+ seed = int(seed)
134
+ if not use_init:
135
+ init_audio = None
136
+ input_sample_size = sample_size
137
+
138
+ def progress_callback(callback_info):
139
+ nonlocal preview_images
140
+ denoised = callback_info["denoised"]
141
+ current_step = callback_info["i"]
142
+ sigma = callback_info["sigma"]
143
+ if (current_step - 1) % preview_every == 0:
144
+ if model.pretransform is not None:
145
+ denoised = model.pretransform.decode(denoised)
146
+ denoised = rearrange(denoised, "b d n -> d (b n)")
147
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
148
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
149
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
150
+
151
+ if model_type == "diffusion_cond":
152
+ audio = generate_diffusion_cond(
153
+ model,
154
+ conditioning=conditioning,
155
+ negative_conditioning=negative_conditioning,
156
+ steps=steps,
157
+ cfg_scale=cfg_scale,
158
+ batch_size=batch_size,
159
+ sample_size=input_sample_size,
160
+ sample_rate=sample_rate,
161
+ seed=seed,
162
+ device=device,
163
+ sampler_type=sampler_type,
164
+ sigma_min=sigma_min,
165
+ sigma_max=sigma_max,
166
+ init_audio=init_audio,
167
+ init_noise_level=init_noise_level,
168
+ mask_args=None,
169
+ callback=progress_callback if preview_every is not None else None,
170
+ scale_phi=cfg_rescale
171
+ )
172
+
173
+ audio = rearrange(audio, "b d n -> d (b n)")
174
+
175
+ samples_10s = 10 * sample_rate
176
+ audio = audio[:, :samples_10s]
177
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
178
+
179
+ output_dir = "demo_result"
180
+ os.makedirs(output_dir, exist_ok=True)
181
+ output_audio_path = f"{output_dir}/output.wav"
182
+ torchaudio.save(output_audio_path, audio, sample_rate)
183
+
184
+ if actual_video_path:
185
+ output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
186
+ target_width = 1280
187
+ target_height = 720
188
+ merge_video_audio(
189
+ actual_video_path,
190
+ output_audio_path,
191
+ output_video_path,
192
+ seconds_start,
193
+ seconds_total
194
+ )
195
+ else:
196
+ output_video_path = None
197
+
198
+ del actual_video_path
199
+ torch.cuda.empty_cache()
200
+ gc.collect()
201
+
202
+ return output_video_path, output_audio_path
203
+
204
+
205
+ with gr.Blocks() as interface:
206
+ gr.Markdown(
207
+ """
208
+ # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
209
+ **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
210
+ """
211
+ )
212
+
213
+ with gr.Tab("Generation"):
214
+ with gr.Row():
215
+ with gr.Column():
216
+ prompt = gr.Textbox(
217
+ show_label=False,
218
+ placeholder="Enter your prompt"
219
+ )
220
+ negative_prompt = gr.Textbox(
221
+ show_label=False,
222
+ placeholder="Negative prompt",
223
+ visible=False
224
+ )
225
+ video_file = gr.File(label="Upload Video File")
226
+ audio_prompt_file = gr.File(
227
+ label="Upload Audio Prompt File",
228
+ visible=False
229
+ )
230
+ audio_prompt_path = gr.Textbox(
231
+ label="Audio Prompt Path",
232
+ placeholder="Enter audio file path",
233
+ visible=False
234
+ )
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=6):
238
+ with gr.Accordion("Video Params", open=False):
239
+ seconds_start = gr.Slider(
240
+ minimum=0,
241
+ maximum=512,
242
+ step=1,
243
+ value=0,
244
+ label="Video Seconds Start"
245
+ )
246
+ seconds_total = gr.Slider(
247
+ minimum=0,
248
+ maximum=10,
249
+ step=1,
250
+ value=10,
251
+ label="Seconds Total",
252
+ interactive=False
253
+ )
254
+
255
+ with gr.Row():
256
+ with gr.Column(scale=4):
257
+ with gr.Accordion("Sampler Params", open=False):
258
+ steps = gr.Slider(
259
+ minimum=1,
260
+ maximum=500,
261
+ step=1,
262
+ value=100,
263
+ label="Steps"
264
+ )
265
+ preview_every = gr.Slider(
266
+ minimum=0,
267
+ maximum=100,
268
+ step=1,
269
+ value=0,
270
+ label="Preview Every"
271
+ )
272
+ cfg_scale = gr.Slider(
273
+ minimum=0.0,
274
+ maximum=25.0,
275
+ step=0.1,
276
+ value=7.0,
277
+ label="CFG Scale"
278
+ )
279
+ seed = gr.Textbox(
280
+ label="Seed (set to -1 for random seed)",
281
+ value="-1"
282
+ )
283
+ sampler_type = gr.Dropdown(
284
+ choices=[
285
+ "dpmpp-2m-sde",
286
+ "dpmpp-3m-sde",
287
+ "k-heun",
288
+ "k-lms",
289
+ "k-dpmpp-2s-ancestral",
290
+ "k-dpm-2",
291
+ "k-dpm-fast"
292
+ ],
293
+ label="Sampler Type",
294
+ value="dpmpp-3m-sde"
295
+ )
296
+ sigma_min = gr.Slider(
297
+ minimum=0.0,
298
+ maximum=2.0,
299
+ step=0.01,
300
+ value=0.03,
301
+ label="Sigma Min"
302
+ )
303
+ sigma_max = gr.Slider(
304
+ minimum=0.0,
305
+ maximum=1000.0,
306
+ step=0.1,
307
+ value=500,
308
+ label="Sigma Max"
309
+ )
310
+ cfg_rescale = gr.Slider(
311
+ minimum=0.0,
312
+ maximum=1,
313
+ step=0.01,
314
+ value=0.0,
315
+ label="CFG Rescale Amount"
316
+ )
317
+
318
+ with gr.Row():
319
+ with gr.Column(scale=4):
320
+ with gr.Accordion("Init Audio", open=False, visible=False):
321
+ init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
322
+ init_audio_input = gr.Audio(label="Init Audio")
323
+ init_noise_level = gr.Slider(
324
+ minimum=0.1,
325
+ maximum=100.0,
326
+ step=0.01,
327
+ value=0.1,
328
+ label="Init Noise Level"
329
+ )
330
+
331
+ with gr.Row():
332
+ generate_button = gr.Button("Generate", variant="primary")
333
+
334
+ with gr.Row():
335
+ with gr.Column(scale=6):
336
+ video_output = gr.Video(label="Output Video", interactive=False)
337
+ audio_output = gr.Audio(label="Output Audio", interactive=False)
338
+
339
+ inputs = [
340
+ prompt,
341
+ negative_prompt,
342
+ video_file,
343
+ audio_prompt_file,
344
+ audio_prompt_path,
345
+ seconds_start,
346
+ seconds_total,
347
+ cfg_scale,
348
+ steps,
349
+ preview_every,
350
+ seed,
351
+ sampler_type,
352
+ sigma_min,
353
+ sigma_max,
354
+ cfg_rescale,
355
+ init_audio_checkbox,
356
+ init_audio_input,
357
+ init_noise_level
358
+ ]
359
+
360
+ generate_button.click(
361
+ fn=generate_cond,
362
+ inputs=inputs,
363
+ outputs=[video_output, audio_output]
364
+ )
365
+
366
+ gr.Markdown("## Examples")
367
+ with gr.Accordion("Click to show examples", open=False):
368
+ with gr.Row():
369
+ gr.Markdown("**📝 Task: Text-to-Audio**")
370
+ with gr.Column(scale=1.2):
371
+ gr.Markdown("Prompt: *Typing on a keyboard*")
372
+ ex1 = gr.Button("Load Example")
373
+ with gr.Column(scale=1.2):
374
+ gr.Markdown("Prompt: *Ocean waves crashing*")
375
+ ex2 = gr.Button("Load Example")
376
+ with gr.Column(scale=1.2):
377
+ gr.Markdown("Prompt: *Footsteps in snow*")
378
+ ex3 = gr.Button("Load Example")
379
+
380
+ with gr.Row():
381
+ gr.Markdown("**🎶 Task: Text-to-Music**")
382
+ with gr.Column(scale=1.2):
383
+ gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
384
+ ex4 = gr.Button("Load Example")
385
+ with gr.Column(scale=1.2):
386
+ gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
387
+ ex5 = gr.Button("Load Example")
388
+ with gr.Column(scale=1.2):
389
+ gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
390
+ ex6 = gr.Button("Load Example")
391
+
392
+ ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
393
+ ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
394
+ ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
395
+ ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
396
+ ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
397
+ ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
398
+
399
+ interface.queue(5).launch(server_name="0.0.0.0", server_port=7860, share=True)
packages.txt ADDED
@@ -0,0 +1,2 @@
1
+ ffmpeg
2
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,40 @@
1
+ aeiou
2
+ alias-free-torch==0.0.6
3
+ auraloss==0.4.0
4
+ descript-audio-codec==1.0.0
5
+ decord==0.6.0
6
+ einops
7
+ einops_exts
8
+ ema-pytorch==0.2.3
9
+ encodec==0.1.1
10
+ gradio==4.44.1
11
+ gradio_client==1.3.0
12
+ huggingface_hub>=0.16.0
13
+ importlib-resources==5.12.0
14
+ k-diffusion==0.1.1
15
+ laion-clap==1.1.6
16
+ local-attention==1.8.6
17
+ pandas==2.0.2
18
+ pedalboard==0.9.14
19
+ prefigure==0.0.9
20
+ pytorch_lightning==2.1.0
21
+ PyWavelets==1.4.1
22
+ safetensors
23
+ sentencepiece==0.1.99
24
+ torch>=2.1.0
25
+ torchvision>=0.16.0
26
+ torchaudio>=2.1.0
27
+ torchmetrics==1.5.2
28
+ tqdm
29
+ transformers==4.30.0
30
+ v-diffusion-pytorch==0.0.2
31
+ vector-quantize-pytorch==1.9.14
32
+ wandb
33
+ webdataset==0.2.48
34
+ x-transformers==1.42.11
35
+ numpy<2.0.0
36
+ accelerate>=0.20.3
37
+ scipy>=1.10.1
38
+ librosa>=0.10.0
39
+ ffmpeg-python>=0.2.0
40
+ ninja
stable_audio_tools/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .models.factory import create_model_from_config, create_model_from_config_path
2
+ from .models.pretrained import get_pretrained_model
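These two exports are the package's public entry points for loading AudioX. The sketch below shows one way they combine with the inference and data helpers added elsewhere in this commit, mirroring the flow in app.py above. It is a minimal, hedged sketch: the prompt text and output path are illustrative, and it assumes the generate_diffusion_cond keyword arguments used in app.py are the full set needed.

# Minimal text-to-audio sketch mirroring app.py (illustrative prompt and output path).
import torch
import torchaudio
from einops import rearrange

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import read_video, load_and_process_audio

device = "cuda" if torch.cuda.is_available() else "cpu"
model, model_config = get_pretrained_model("HKUSTAudio/AudioX")
model = model.to(device)
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

seconds_total = 10
conditioning = [{
    # With no video or audio prompt, the data helpers return zero tensors of the right shape.
    "video_prompt": [read_video(None, seek_time=0, duration=seconds_total, target_fps=5).unsqueeze(0)],
    "text_prompt": "Typing on a keyboard",
    "audio_prompt": load_and_process_audio(None, sample_rate, 0, seconds_total).to(device).unsqueeze(0),
    "seconds_start": 0,
    "seconds_total": sample_size / sample_rate,
}]

audio = generate_diffusion_cond(
    model,
    conditioning=conditioning,
    negative_conditioning=None,
    steps=100,
    cfg_scale=7.0,
    batch_size=1,
    sample_size=sample_size,
    sample_rate=sample_rate,
    seed=-1,
    device=device,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    init_audio=None,
    init_noise_level=0.1,
    mask_args=None,
    callback=None,
    scale_phi=0.0,
)

# Fold the batch dimension into time, normalize, and write a 16-bit WAV, as app.py does.
audio = rearrange(audio, "b d n -> d (b n)")
audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", audio, sample_rate)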
stable_audio_tools/data/__init__.py ADDED
File without changes
stable_audio_tools/data/dataset.py ADDED
@@ -0,0 +1,876 @@
1
+ import importlib
2
+ import numpy as np
3
+ import io
4
+ import os
5
+ import posixpath
6
+ import random
7
+ import re
8
+ import subprocess
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+ import webdataset as wds
13
+
14
+ from aeiou.core import is_silence
15
+ from os import path
16
+ from pedalboard.io import AudioFile
17
+ from torchaudio import transforms as T
18
+ from typing import Optional, Callable, List
19
+ from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
20
+ from torchdata.datapipes.iter import Prefetcher
21
+
22
+ from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
23
+ import json
24
+
25
+
26
+ import os
27
+ import datetime
28
+ from memory_profiler import profile
29
+
30
+
31
+ AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
32
+
33
+ # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
34
+
35
+ def fast_scandir(
36
+ dir:str, # top-level directory at which to begin scanning
37
+ ext:list, # list of allowed file extensions,
38
+ #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
39
+ ):
40
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
41
+ subfolders, files = [], []
42
+ ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
43
+ try: # hope to avoid 'permission denied' by this try
44
+ for f in os.scandir(dir):
45
+ try: # 'hope to avoid too many levels of symbolic links' error
46
+ if f.is_dir():
47
+ subfolders.append(f.path)
48
+ elif f.is_file():
49
+ file_ext = os.path.splitext(f.name)[1].lower()
50
+ is_hidden = os.path.basename(f.path).startswith(".")
51
+
52
+ if file_ext in ext and not is_hidden:
53
+ files.append(f.path)
54
+ except:
55
+ pass
56
+ except:
57
+ pass
58
+
59
+ for dir in list(subfolders):
60
+ sf, f = fast_scandir(dir, ext)
61
+ subfolders.extend(sf)
62
+ files.extend(f)
63
+ return subfolders, files
64
+
65
+ def extract_audio_paths(jsonl_file, exts):
66
+ audio_paths = []
67
+ video_paths = []
68
+ text_prompts = []
69
+ data_types = []
70
+ with open(jsonl_file, 'r') as file:
71
+ for line in file:
72
+ try:
73
+ data = json.loads(line.strip())
74
+ path = data.get('path', '')
75
+ video_path = data.get('video_path', '')
76
+ text_prompt = data.get('caption', '')
77
+ data_type = data.get('type', None)
78
+ if any(path.endswith(ext) for ext in exts):
79
+ audio_paths.append(path)
80
+ video_paths.append(video_path)
81
+ text_prompts.append(text_prompt)
82
+ data_types.append(data_type)
83
+ except json.JSONDecodeError:
84
+ print(f"Error decoding JSON line: {line.strip()}")
85
+ return audio_paths, video_paths, text_prompts, data_types
86
+
87
+ def keyword_scandir(
88
+ dir: str, # top-level directory at which to begin scanning
89
+ ext: list, # list of allowed file extensions
90
+ keywords: list, # list of keywords to search for in the file name
91
+ ):
92
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
93
+ subfolders, files = [], []
94
+ # make keywords case insensitive
95
+ keywords = [keyword.lower() for keyword in keywords]
96
+ # add starting period to extensions if needed
97
+ ext = ['.'+x if x[0] != '.' else x for x in ext]
98
+ banned_words = ["paxheader", "__macosx"]
99
+ try: # hope to avoid 'permission denied' by this try
100
+ for f in os.scandir(dir):
101
+ try: # 'hope to avoid too many levels of symbolic links' error
102
+ if f.is_dir():
103
+ subfolders.append(f.path)
104
+ elif f.is_file():
105
+ is_hidden = f.name.split("/")[-1][0] == '.'
106
+ has_ext = os.path.splitext(f.name)[1].lower() in ext
107
+ name_lower = f.name.lower()
108
+ has_keyword = any(
109
+ [keyword in name_lower for keyword in keywords])
110
+ has_banned = any(
111
+ [banned_word in name_lower for banned_word in banned_words])
112
+ if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
113
+ files.append(f.path)
114
+ except:
115
+ pass
116
+ except:
117
+ pass
118
+
119
+ for dir in list(subfolders):
120
+ sf, f = keyword_scandir(dir, ext, keywords)
121
+ subfolders.extend(sf)
122
+ files.extend(f)
123
+ return subfolders, files
124
+
125
+ def get_audio_filenames(
126
+ paths: list, # directories in which to search
127
+ keywords=None,
128
+ exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
129
+ ):
130
+
131
+ "recursively get a list of audio filenames"
132
+ filenames = []
133
+ video_filenames = []
134
+ text_prompts = []
135
+ data_types = []
136
+
137
+ if type(paths) is str:
138
+ paths = [paths]
139
+
140
+
141
+ if os.path.isdir(paths[0]):
142
+ for path in paths: # get a list of relevant filenames
143
+ if keywords is not None:
144
+ subfolders, files = keyword_scandir(path, exts, keywords)
145
+ else:
146
+ subfolders, files = fast_scandir(path, exts)
147
+ filenames.extend(files)
148
+ return filenames
149
+
150
+ elif os.path.isfile(paths[0]):
151
+ assert paths[0].endswith('.jsonl')
152
+ for path in paths:
153
+ audio_paths, video_paths, text_prompt, data_type = extract_audio_paths(path, exts)
154
+ filenames.extend(audio_paths)
155
+ video_filenames.extend(video_paths)
156
+ text_prompts.extend(text_prompt)
157
+ data_types.extend(data_type)
158
+
159
+ return filenames, video_filenames, text_prompts, data_types
160
+
161
+
162
+ class LocalDatasetConfig:
163
+ def __init__(
164
+ self,
165
+ id: str,
166
+ path: str,
167
+ video_fps: int,
168
+ custom_metadata_fn: Optional[Callable[[str], str]] = None
169
+ ):
170
+ self.id = id
171
+ self.path = path
172
+ self.video_fps = video_fps
173
+ self.custom_metadata_fn = custom_metadata_fn
174
+
175
+
176
+ # @profile
177
+ class SampleDataset(torch.utils.data.Dataset):
178
+ # @profile
179
+ def __init__(
180
+ self,
181
+ configs,
182
+ sample_size=65536,
183
+ sample_rate=48000,
184
+ keywords=None,
185
+ random_crop=True,
186
+ force_channels="stereo",
187
+ video_fps=5
188
+ ):
189
+ super().__init__()
190
+ self.filenames = []
191
+ self.video_filenames = []
192
+ self.text_prompts = []
193
+ self.data_types = []
194
+
195
+ self.augs = torch.nn.Sequential(
196
+ PhaseFlipper(),
197
+ )
198
+
199
+ self.root_paths = []
200
+
201
+ self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
202
+
203
+ self.force_channels = force_channels
204
+
205
+ self.encoding = torch.nn.Sequential(
206
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
207
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
208
+ )
209
+
210
+ self.sr = sample_rate
211
+
212
+ self.custom_metadata_fns = {}
213
+
214
+ for config in configs:
215
+ self.video_fps = config.video_fps
216
+
217
+ self.root_paths.append(config.path)
218
+ audio_files, video_files, text_prompt, data_types = get_audio_filenames(config.path, keywords)
219
+
220
+ self.filenames.extend(audio_files)
221
+ self.video_filenames.extend(video_files)
222
+ self.text_prompts.extend(text_prompt)
223
+ self.data_types.extend(data_types)
224
+ if config.custom_metadata_fn is not None:
225
+ self.custom_metadata_fns[config.path] = config.custom_metadata_fn
226
+
227
+ print(f'Found {len(self.filenames)} files')
228
+
229
+
230
+ def load_file(self, filename):
231
+ ext = filename.split(".")[-1]
232
+
233
+ if ext == "mp3":
234
+ with AudioFile(filename) as f:
235
+ audio = f.read(f.frames)
236
+ audio = torch.from_numpy(audio)
237
+ in_sr = f.samplerate
238
+ else:
239
+ audio, in_sr = torchaudio.load(filename, format=ext)
240
+
241
+ if in_sr != self.sr:
242
+ resample_tf = T.Resample(in_sr, self.sr)
243
+ audio = resample_tf(audio)
244
+
245
+ return audio
246
+
247
+ def __len__(self):
248
+ return len(self.filenames)
249
+
250
+
251
+ def __getitem__(self, idx):
252
+ audio_filename = self.filenames[idx]
253
+ video_filename = self.video_filenames[idx]
254
+ text_prompt = self.text_prompts[idx]
255
+ data_type = self.data_types[idx]
256
+
257
+ try:
258
+
259
+ start_time = time.time()
260
+ audio = self.load_file(audio_filename)
261
+
262
+
263
+ if data_type in ["text_condition-audio", "text_condition-music",
264
+ "video_condition-audio", "video_condition-music",
265
+ "text+video_condition-audio","text+video_condition-music"]:
266
+ if_audio_condition = False
267
+ audio_prompt = torch.zeros((2, self.sr * 10))
268
+ elif data_type in ["audio_condition-audio", "audio_condition-music",
269
+ "uni_condition-audio", "uni_condition-music"]:
270
+ if_audio_condition = True
271
+
272
+ if if_audio_condition:
273
+ audio_org = audio.clamp(-1, 1)
274
+
275
+
276
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
277
+
278
+ if self.augs is not None:
279
+ audio = self.augs(audio)
280
+
281
+ audio = audio.clamp(-1, 1)
282
+
283
+ if if_audio_condition:
284
+ if data_type.split("-")[-1] == "audio":
285
+ start_index = max(0, int((seconds_start) * self.sr))
286
+ end_index = int((seconds_start+10) * self.sr)
287
+ audio_prompt = audio_org[:, start_index:end_index]
288
+
289
+ elif data_type.split("-")[-1] == "music":
290
+ if seconds_start < 10:
291
+ start_index = 0
292
+ end_index = int(10 * self.sr)
293
+ else:
294
+ start_index = max(0, int((seconds_start - 10) * self.sr))
295
+ end_index = int(seconds_start * self.sr)
296
+ audio_prompt = audio_org[:, start_index:end_index]
297
+
298
+ # Encode the file to assist in prediction
299
+ if self.encoding is not None:
300
+ audio = self.encoding(audio)
301
+
302
+ info = {}
303
+
304
+
305
+ info["path"] = audio_filename
306
+ info["video_path"] = video_filename
307
+ info["text_prompt"] = text_prompt
308
+ info["audio_prompt"] = audio_prompt
309
+ info["data_type"] = data_type
310
+
311
+ for root_path in self.root_paths:
312
+ if root_path in audio_filename:
313
+ info["relpath"] = path.relpath(audio_filename, root_path)
314
+
315
+ info["timestamps"] = (t_start, t_end)
316
+ info["seconds_start"] = seconds_start
317
+ info["seconds_total"] = seconds_total
318
+ info["padding_mask"] = padding_mask
319
+ info["video_fps"] = self.video_fps
320
+ end_time = time.time()
321
+
322
+ info["load_time"] = end_time - start_time
323
+
324
+ for custom_md_path in self.custom_metadata_fns.keys():
325
+ if os.path.isdir(custom_md_path):
326
+ if custom_md_path in audio_filename:
327
+ custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
328
+ custom_metadata = custom_metadata_fn(info, audio)
329
+ info.update(custom_metadata)
330
+ elif os.path.isfile(custom_md_path):
331
+ custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
332
+ custom_metadata = custom_metadata_fn(info, audio)
333
+ info.update(custom_metadata)
334
+
335
+ if "__reject__" in info and info["__reject__"]:
336
+ return self[random.randrange(len(self))]
337
+
338
+ file_name = audio_filename.split('/')[-1]
339
+
340
+ return (audio, info)
341
+ except Exception as e:
342
+ print(f'Couldn\'t load file {audio_filename}: {e}')
343
+ return self[random.randrange(len(self))]
344
+
345
+ def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
346
+ """Return function over iterator that groups key, value pairs into samples.
347
+ :param keys: function that splits the key into key and extension (base_plus_ext)
348
+ :param lcase: convert suffixes to lower case (Default value = True)
349
+ """
350
+ current_sample = None
351
+ for filesample in data:
352
+ assert isinstance(filesample, dict)
353
+ fname, value = filesample["fname"], filesample["data"]
354
+ prefix, suffix = keys(fname)
355
+ if wds.tariterators.trace:
356
+ print(
357
+ prefix,
358
+ suffix,
359
+ current_sample.keys() if isinstance(current_sample, dict) else None,
360
+ )
361
+ if prefix is None:
362
+ continue
363
+ if lcase:
364
+ suffix = suffix.lower()
365
+ if current_sample is None or prefix != current_sample["__key__"]:
366
+ if wds.tariterators.valid_sample(current_sample):
367
+ yield current_sample
368
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
369
+ if suffix in current_sample:
370
+ print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
371
+ if suffixes is None or suffix in suffixes:
372
+ current_sample[suffix] = value
373
+ if wds.tariterators.valid_sample(current_sample):
374
+ yield current_sample
375
+
376
+ wds.tariterators.group_by_keys = group_by_keys
377
+
378
+ # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
379
+
380
+ def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
381
+ """
382
+ Returns a list of full S3 paths to files in a given S3 bucket and directory path.
383
+ """
384
+ # Ensure dataset_path ends with a trailing slash
385
+ if dataset_path != '' and not dataset_path.endswith('/'):
386
+ dataset_path += '/'
387
+ # Use posixpath to construct the S3 URL path
388
+ bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
389
+ # Construct the `aws s3 ls` command
390
+ cmd = ['aws', 's3', 'ls', bucket_path]
391
+
392
+ if profile is not None:
393
+ cmd.extend(['--profile', profile])
394
+
395
+ if recursive:
396
+ # Add the --recursive flag if requested
397
+ cmd.append('--recursive')
398
+
399
+ # Run the `aws s3 ls` command and capture the output
400
+ run_ls = subprocess.run(cmd, capture_output=True, check=True)
401
+ # Split the output into lines and strip whitespace from each line
402
+ contents = run_ls.stdout.decode('utf-8').split('\n')
403
+ contents = [x.strip() for x in contents if x]
404
+ # Remove the timestamp from lines that begin with a timestamp
405
+ contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
406
+ if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
407
+ # Construct a full S3 path for each file in the contents list
408
+ contents = [posixpath.join(s3_url_prefix or '', x)
409
+ for x in contents if not x.endswith('/')]
410
+ # Apply the filter, if specified
411
+ if filter:
412
+ contents = [x for x in contents if filter in x]
413
+ # Remove redundant directory names in the S3 URL
414
+ if recursive:
415
+ # Get the main directory name from the S3 URL
416
+ main_dir = "/".join(bucket_path.split('/')[3:])
417
+ # Remove the redundant directory names from each file path
418
+ contents = [x.replace(f'{main_dir}', '').replace(
419
+ '//', '/') for x in contents]
420
+ # Print debugging information, if requested
421
+ if debug:
422
+ print("contents = \n", contents)
423
+ # Return the list of S3 paths to files
424
+ return contents
425
+
426
+
427
+ def get_all_s3_urls(
428
+ names=[], # list of all valid [LAION AudioDataset] dataset names
429
+ # list of subsets you want from those datasets, e.g. ['train','valid']
430
+ subsets=[''],
431
+ s3_url_prefix=None, # prefix for those dataset names
432
+ recursive=True, # recursively list all tar files in all subdirs
433
+ filter_str='tar', # only grab files with this substring
434
+ # print debugging info -- note: info displayed likely to change at dev's whims
435
+ debug=False,
436
+ profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
437
+ ):
438
+ "get urls of shards (tar files) for multiple datasets in one s3 bucket"
439
+ urls = []
440
+ for name in names:
441
+ # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
442
+ if s3_url_prefix is None:
443
+ contents_str = name
444
+ else:
445
+ # Construct the S3 path using the s3_url_prefix and the current name value
446
+ contents_str = posixpath.join(s3_url_prefix, name)
447
+ if debug:
448
+ print(f"get_all_s3_urls: {contents_str}:")
449
+ for subset in subsets:
450
+ subset_str = posixpath.join(contents_str, subset)
451
+ if debug:
452
+ print(f"subset_str = {subset_str}")
453
+ # Get the list of tar files in the current subset directory
454
+ profile = profiles.get(name, None)
455
+ tar_list = get_s3_contents(
456
+ subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
457
+ for tar in tar_list:
458
+ # Escape spaces and parentheses in the tar filename for use in the shell command
459
+ tar = tar.replace(" ", "\ ").replace(
460
+ "(", "\(").replace(")", "\)")
461
+ # Construct the S3 path to the current tar file
462
+ s3_path = posixpath.join(name, subset, tar) + " -"
463
+ # Construct the AWS CLI command to download the current tar file
464
+ if s3_url_prefix is None:
465
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
466
+ else:
467
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
468
+ if profiles.get(name):
469
+ request_str += f" --profile {profiles.get(name)}"
470
+ if debug:
471
+ print("request_str = ", request_str)
472
+ # Add the constructed URL to the list of URLs
473
+ urls.append(request_str)
474
+ return urls
475
+
476
+
477
+ def log_and_continue(exn):
478
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
479
+ print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
480
+ return True
481
+
482
+
483
+ def is_valid_sample(sample):
484
+ has_json = "json" in sample
485
+ has_audio = "audio" in sample
486
+ is_silent = is_silence(sample["audio"])
487
+ is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
488
+
489
+ return has_json and has_audio and not is_silent and not is_rejected
490
+
491
+ class S3DatasetConfig:
492
+ def __init__(
493
+ self,
494
+ id: str,
495
+ s3_path: str,
496
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
497
+ profile: Optional[str] = None,
498
+ ):
499
+ self.id = id
500
+ self.path = s3_path
501
+ self.custom_metadata_fn = custom_metadata_fn
502
+ self.profile = profile
503
+ self.urls = []
504
+
505
+ def load_data_urls(self):
506
+ self.urls = get_all_s3_urls(
507
+ names=[self.path],
508
+ s3_url_prefix=None,
509
+ recursive=True,
510
+ profiles={self.path: self.profile} if self.profile else {},
511
+ )
512
+
513
+ return self.urls
514
+
515
+ class LocalWebDatasetConfig:
516
+ def __init__(
517
+ self,
518
+ id: str,
519
+ path: str,
520
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
521
+ profile: Optional[str] = None,
522
+ ):
523
+ self.id = id
524
+ self.path = path
525
+ self.custom_metadata_fn = custom_metadata_fn
526
+ self.urls = []
527
+
528
+ def load_data_urls(self):
529
+
530
+ self.urls = fast_scandir(self.path, ["tar"])[1]
531
+
532
+ return self.urls
533
+
534
+ def audio_decoder(key, value):
535
+ # Get file extension from key
536
+ ext = key.split(".")[-1]
537
+
538
+ if ext in AUDIO_KEYS:
539
+ return torchaudio.load(io.BytesIO(value))
540
+ else:
541
+ return None
542
+
543
+ def collation_fn(samples):
544
+ batched = list(zip(*samples))
545
+ result = []
546
+ for b in batched:
547
+ if isinstance(b[0], (int, float)):
548
+ b = np.array(b)
549
+ elif isinstance(b[0], torch.Tensor):
550
+ b = torch.stack(b)
551
+ elif isinstance(b[0], np.ndarray):
552
+ b = np.array(b)
553
+ else:
554
+ b = b
555
+ result.append(b)
556
+ return result
557
+
558
+ class WebDatasetDataLoader():
559
+ def __init__(
560
+ self,
561
+ datasets: List[S3DatasetConfig],
562
+ batch_size,
563
+ sample_size,
564
+ sample_rate=48000,
565
+ num_workers=8,
566
+ epoch_steps=1000,
567
+ random_crop=True,
568
+ force_channels="stereo",
569
+ augment_phase=True,
570
+ **data_loader_kwargs
571
+ ):
572
+
573
+ self.datasets = datasets
574
+
575
+ self.sample_size = sample_size
576
+ self.sample_rate = sample_rate
577
+ self.random_crop = random_crop
578
+ self.force_channels = force_channels
579
+ self.augment_phase = augment_phase
580
+
581
+ urls = [dataset.load_data_urls() for dataset in datasets]
582
+
583
+ # Flatten the list of lists of URLs
584
+ urls = [url for dataset_urls in urls for url in dataset_urls]
585
+
586
+ # Shuffle the urls
587
+ random.shuffle(urls)
588
+
589
+ self.dataset = wds.DataPipeline(
590
+ wds.ResampledShards(urls),
591
+ wds.tarfile_to_samples(handler=log_and_continue),
592
+ wds.decode(audio_decoder, handler=log_and_continue),
593
+ wds.map(self.wds_preprocess, handler=log_and_continue),
594
+ wds.select(is_valid_sample),
595
+ wds.to_tuple("audio", "json", handler=log_and_continue),
596
+ #wds.shuffle(bufsize=1000, initial=5000),
597
+ wds.batched(batch_size, partial=False, collation_fn=collation_fn),
598
+ ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
599
+
600
+ self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
601
+
602
+ def wds_preprocess(self, sample):
603
+
604
+ found_key, rewrite_key = '', ''
605
+ for k, v in sample.items(): # print the all entries in dict
606
+ for akey in AUDIO_KEYS:
607
+ if k.endswith(akey):
608
+ # to rename long/weird key with its simpler counterpart
609
+ found_key, rewrite_key = k, akey
610
+ break
611
+ if '' != found_key:
612
+ break
613
+ if '' == found_key: # got no audio!
614
+ return None # try returning None to tell WebDataset to skip this one
615
+
616
+ audio, in_sr = sample[found_key]
617
+ if in_sr != self.sample_rate:
618
+ resample_tf = T.Resample(in_sr, self.sample_rate)
619
+ audio = resample_tf(audio)
620
+
621
+ if self.sample_size is not None:
622
+ # Pad/crop and get the relative timestamp
623
+ pad_crop = PadCrop_Normalized_T(
624
+ self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
625
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
626
+ audio)
627
+ sample["json"]["seconds_start"] = seconds_start
628
+ sample["json"]["seconds_total"] = seconds_total
629
+ sample["json"]["padding_mask"] = padding_mask
630
+ else:
631
+ t_start, t_end = 0, 1
632
+
633
+ # Check if audio is length zero, initialize to a single zero if so
634
+ if audio.shape[-1] == 0:
635
+ audio = torch.zeros(1, 1)
636
+
637
+ # Make the audio stereo and augment by randomly inverting phase
638
+ augs = torch.nn.Sequential(
639
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
640
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
641
+ PhaseFlipper() if self.augment_phase else torch.nn.Identity()
642
+ )
643
+
644
+ audio = augs(audio)
645
+
646
+ sample["json"]["timestamps"] = (t_start, t_end)
647
+
648
+ if "text" in sample["json"]:
649
+ sample["json"]["prompt"] = sample["json"]["text"]
650
+
651
+ # Check for custom metadata functions
652
+ for dataset in self.datasets:
653
+ if dataset.custom_metadata_fn is None:
654
+ continue
655
+
656
+ if dataset.path in sample["__url__"]:
657
+ custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
658
+ sample["json"].update(custom_metadata)
659
+
660
+ if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
661
+ del sample[found_key]
662
+
663
+ sample["audio"] = audio
664
+
665
+ # Add audio to the metadata as well for conditioning
666
+ sample["json"]["audio"] = audio
667
+
668
+ return sample
669
+
670
+ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, video_fps=5):
671
+
672
+ dataset_type = dataset_config.get("dataset_type", None)
673
+
674
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
675
+
676
+ if audio_channels == 1:
677
+ force_channels = "mono"
678
+ else:
679
+ force_channels = "stereo"
680
+
681
+ if dataset_type == "audio_dir":
682
+
683
+ audio_dir_configs = dataset_config.get("datasets", None)
684
+
685
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
686
+
687
+ configs = []
688
+
689
+ for audio_dir_config in audio_dir_configs:
690
+ audio_dir_path = audio_dir_config.get("path", None)
691
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
692
+
693
+ custom_metadata_fn = None
694
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
695
+
696
+ if custom_metadata_module_path is not None:
697
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
698
+ metadata_module = importlib.util.module_from_spec(spec)
699
+ spec.loader.exec_module(metadata_module)
700
+
701
+ custom_metadata_fn = metadata_module.get_custom_metadata
702
+
703
+ configs.append(
704
+ LocalDatasetConfig(
705
+ id=audio_dir_config["id"],
706
+ path=audio_dir_path,
707
+ custom_metadata_fn=custom_metadata_fn,
708
+ video_fps=video_fps
709
+ )
710
+ )
711
+
712
+ train_set = SampleDataset(
713
+ configs,
714
+ sample_rate=sample_rate,
715
+ sample_size=sample_size,
716
+ random_crop=dataset_config.get("random_crop", True),
717
+ force_channels=force_channels,
718
+ video_fps=video_fps
719
+ )
720
+
721
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
722
+ num_workers=num_workers, persistent_workers=True, pin_memory=False, drop_last=True, collate_fn=collation_fn)
723
+
724
+ elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
725
+ wds_configs = []
726
+
727
+ for wds_config in dataset_config["datasets"]:
728
+
729
+ custom_metadata_fn = None
730
+ custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
731
+
732
+ if custom_metadata_module_path is not None:
733
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
734
+ metadata_module = importlib.util.module_from_spec(spec)
735
+ spec.loader.exec_module(metadata_module)
736
+
737
+ custom_metadata_fn = metadata_module.get_custom_metadata
738
+
739
+ if "s3_path" in wds_config:
740
+
741
+ wds_configs.append(
742
+ S3DatasetConfig(
743
+ id=wds_config["id"],
744
+ s3_path=wds_config["s3_path"],
745
+ custom_metadata_fn=custom_metadata_fn,
746
+ profile=wds_config.get("profile", None),
747
+ )
748
+ )
749
+
750
+ elif "path" in wds_config:
751
+
752
+ wds_configs.append(
753
+ LocalWebDatasetConfig(
754
+ id=wds_config["id"],
755
+ path=wds_config["path"],
756
+ custom_metadata_fn=custom_metadata_fn
757
+ )
758
+ )
759
+
760
+ return WebDatasetDataLoader(
761
+ wds_configs,
762
+ sample_rate=sample_rate,
763
+ sample_size=sample_size,
764
+ batch_size=batch_size,
765
+ random_crop=dataset_config.get("random_crop", True),
766
+ num_workers=num_workers,
767
+ persistent_workers=True,
768
+ force_channels=force_channels,
769
+ epoch_steps=dataset_config.get("epoch_steps", 2000)
770
+ ).data_loader
771
+
772
+
773
+
774
+
775
+ def create_dataloader_from_config_valid(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
776
+
777
+
778
+ dataset_type = dataset_config.get("dataset_type", None)
779
+
780
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
781
+
782
+ if audio_channels == 1:
783
+ force_channels = "mono"
784
+ else:
785
+ force_channels = "stereo"
786
+
787
+ if dataset_type == "audio_dir":
788
+
789
+ audio_dir_configs = dataset_config.get("datasets", None)
790
+
791
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
792
+
793
+ configs = []
794
+
795
+ for audio_dir_config in audio_dir_configs:
796
+ audio_dir_path = audio_dir_config.get("path", None)
797
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
798
+
799
+ custom_metadata_fn = None
800
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
801
+
802
+ if custom_metadata_module_path is not None:
803
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
804
+ metadata_module = importlib.util.module_from_spec(spec)
805
+ spec.loader.exec_module(metadata_module)
806
+
807
+ custom_metadata_fn = metadata_module.get_custom_metadata
808
+
809
+ configs.append(
810
+ LocalDatasetConfig(
811
+ id=audio_dir_config["id"],
812
+ path=audio_dir_path,
813
+ custom_metadata_fn=custom_metadata_fn
814
+ )
815
+ )
816
+
817
+ valid_set = SampleDataset(
818
+ configs,
819
+ sample_rate=sample_rate,
820
+ sample_size=sample_size,
821
+ random_crop=dataset_config.get("random_crop", True),
822
+ force_channels=force_channels
823
+ )
824
+
825
+
826
+ return torch.utils.data.DataLoader(valid_set, batch_size, shuffle=False,
827
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
828
+
829
+ elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
830
+ wds_configs = []
831
+
832
+ for wds_config in dataset_config["datasets"]:
833
+
834
+ custom_metadata_fn = None
835
+ custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
836
+
837
+ if custom_metadata_module_path is not None:
838
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
839
+ metadata_module = importlib.util.module_from_spec(spec)
840
+ spec.loader.exec_module(metadata_module)
841
+
842
+ custom_metadata_fn = metadata_module.get_custom_metadata
843
+
844
+ if "s3_path" in wds_config:
845
+
846
+ wds_configs.append(
847
+ S3DatasetConfig(
848
+ id=wds_config["id"],
849
+ s3_path=wds_config["s3_path"],
850
+ custom_metadata_fn=custom_metadata_fn,
851
+ profile=wds_config.get("profile", None),
852
+ )
853
+ )
854
+
855
+ elif "path" in wds_config:
856
+
857
+ wds_configs.append(
858
+ LocalWebDatasetConfig(
859
+ id=wds_config["id"],
860
+ path=wds_config["path"],
861
+ custom_metadata_fn=custom_metadata_fn
862
+ )
863
+ )
864
+
865
+ return WebDatasetDataLoader(
866
+ wds_configs,
867
+ sample_rate=sample_rate,
868
+ sample_size=sample_size,
869
+ batch_size=batch_size,
870
+ random_crop=dataset_config.get("random_crop", True),
871
+ num_workers=num_workers,
872
+ persistent_workers=True,
873
+ force_channels=force_channels,
874
+ epoch_steps=dataset_config.get("epoch_steps", 2000)
875
+ ).data_loader
876
+
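get_audio_filenames above accepts either a directory or a .jsonl manifest; in the manifest case, extract_audio_paths reads the "path", "video_path", "caption", and "type" keys from each line. Below is a hedged sketch of building such a manifest — the file names are hypothetical, while the keys and the "type" strings come from the categories checked in SampleDataset.__getitem__.

# Hypothetical sketch of a JSONL manifest consumable by extract_audio_paths()/SampleDataset.
import json

entries = [
    {
        "path": "data/clip_0001.wav",          # audio file; extension must be in the allowed exts
        "video_path": "data/clip_0001.mp4",    # paired video; may be empty for audio-only rows
        "caption": "Typing on a keyboard",     # becomes the text prompt
        "type": "text+video_condition-audio",  # one of the condition categories in __getitem__
    },
    {
        "path": "data/song_0001.flac",
        "video_path": "",
        "caption": "Upbeat electronic music",
        "type": "audio_condition-music",
    },
]

with open("train_manifest.jsonl", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")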
stable_audio_tools/data/utils.py ADDED
@@ -0,0 +1,204 @@
1
+ import math
2
+ import random
3
+ import torch
4
+
5
+ from torch import nn
6
+ from typing import Tuple
7
+ import os
8
+ import subprocess as sp
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from decord import VideoReader, cpu
12
+
13
+ class PadCrop(nn.Module):
14
+ def __init__(self, n_samples, randomize=True):
15
+ super().__init__()
16
+ self.n_samples = n_samples
17
+ self.randomize = randomize
18
+
19
+ def __call__(self, signal):
20
+ n, s = signal.shape
21
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
22
+ end = start + self.n_samples
23
+ output = signal.new_zeros([n, self.n_samples])
24
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
25
+ return output
26
+
27
+
28
+ class PadCrop_Normalized_T(nn.Module):
29
+
30
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
31
+ super().__init__()
32
+ self.n_samples = n_samples
33
+ self.sample_rate = sample_rate
34
+ self.randomize = randomize
35
+
36
+ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
37
+ n_channels, n_samples = source.shape
38
+
39
+ # Calculate the duration of the audio in seconds
40
+ total_duration = n_samples // self.sample_rate
41
+
42
+ # If the audio is shorter than the desired length, pad it
43
+ upper_bound = max(0, n_samples - self.n_samples)
44
+
45
+ # If randomize is False, always start at the beginning of the audio
46
+ offset = 0
47
+
48
+ if self.randomize and n_samples > self.n_samples:
49
+ valid_offsets = [
50
+ i * self.sample_rate for i in range(0, total_duration, 10)
51
+ if i * self.sample_rate + self.n_samples <= n_samples and
52
+ (total_duration <= 20 or total_duration - i >= 15)
53
+ ]
54
+ if valid_offsets:
55
+ offset = random.choice(valid_offsets)
56
+
57
+ # Calculate the start and end times of the chunk
58
+ t_start = offset / (upper_bound + self.n_samples)
59
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
60
+
61
+ # Create the chunk
62
+ chunk = source.new_zeros([n_channels, self.n_samples])
63
+
64
+ # Copy the audio into the chunk
65
+ chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
66
+
67
+ # Calculate the start and end times of the chunk in seconds
68
+ seconds_start = math.floor(offset / self.sample_rate)
69
+ seconds_total = math.ceil(n_samples / self.sample_rate)
70
+
71
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
72
+ padding_mask = torch.zeros([self.n_samples])
73
+ padding_mask[:min(n_samples, self.n_samples)] = 1
74
+
75
+ return (
76
+ chunk,
77
+ t_start,
78
+ t_end,
79
+ seconds_start,
80
+ seconds_total,
81
+ padding_mask
82
+ )
83
+
84
+
85
+ class PhaseFlipper(nn.Module):
86
+ "Randomly invert the phase of a signal"
87
+ def __init__(self, p=0.5):
88
+ super().__init__()
89
+ self.p = p
90
+ def __call__(self, signal):
91
+ return -signal if (random.random() < self.p) else signal
92
+
93
+ class Mono(nn.Module):
94
+ def __call__(self, signal):
95
+ return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
96
+
97
+ class Stereo(nn.Module):
98
+ def __call__(self, signal):
99
+ signal_shape = signal.shape
100
+ # Check if it's mono
101
+ if len(signal_shape) == 1: # s -> 2, s
102
+ signal = signal.unsqueeze(0).repeat(2, 1)
103
+ elif len(signal_shape) == 2:
104
+ if signal_shape[0] == 1: #1, s -> 2, s
105
+ signal = signal.repeat(2, 1)
106
+ elif signal_shape[0] > 2: #?, s -> 2,s
107
+ signal = signal[:2, :]
108
+
109
+ return signal
110
+
111
+
112
+ def adjust_video_duration(video_tensor, duration, target_fps):
113
+ current_duration = video_tensor.shape[0]
114
+ target_duration = duration * target_fps
115
+ if current_duration > target_duration:
116
+ video_tensor = video_tensor[:target_duration]
117
+ elif current_duration < target_duration:
118
+ last_frame = video_tensor[-1:]
119
+ repeat_times = target_duration - current_duration
120
+ video_tensor = torch.cat((video_tensor, last_frame.repeat(repeat_times, 1, 1, 1)), dim=0)
121
+ return video_tensor
122
+
123
+ def read_video(filepath, seek_time=0., duration=-1, target_fps=2):
124
+ if filepath is None:
125
+ return torch.zeros((int(duration * target_fps), 3, 224, 224))
126
+
127
+ ext = os.path.splitext(filepath)[1].lower()
128
+ if ext in ['.jpg', '.jpeg', '.png']:
129
+ resize_transform = transforms.Resize((224, 224))
130
+ image = Image.open(filepath).convert("RGB")
131
+ frame = transforms.ToTensor()(image).unsqueeze(0)
132
+ frame = resize_transform(frame)
133
+ target_frames = int(duration * target_fps)
134
+ frame = frame.repeat(int(math.ceil(target_frames / frame.shape[0])), 1, 1, 1)[:target_frames]
135
+ assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}"
136
+ return frame
137
+
138
+ vr = VideoReader(filepath, ctx=cpu(0))
139
+ fps = vr.get_avg_fps()
140
+ total_frames = len(vr)
141
+
142
+ seek_frame = int(seek_time * fps)
143
+ if duration > 0:
144
+ total_frames_to_read = int(target_fps * duration)
145
+ frame_interval = int(math.ceil(fps / target_fps))
146
+ end_frame = min(seek_frame + total_frames_to_read * frame_interval, total_frames)
147
+ frame_ids = list(range(seek_frame, end_frame, frame_interval))
148
+ else:
149
+ frame_interval = int(math.ceil(fps / target_fps))
150
+ frame_ids = list(range(0, total_frames, frame_interval))
151
+
152
+ frames = vr.get_batch(frame_ids).asnumpy()
153
+ frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
154
+
155
+ if frames.shape[2] != 224 or frames.shape[3] != 224:
156
+ resize_transform = transforms.Resize((224, 224))
157
+ frames = resize_transform(frames)
158
+
159
+ video_tensor = adjust_video_duration(frames, duration, target_fps)
160
+ assert video_tensor.shape[0] == duration * target_fps, f"The shape of video_tensor is {video_tensor.shape}"
161
+ return video_tensor
162
+
163
+ def merge_video_audio(video_path, audio_path, output_path, start_time, duration, target_width=None, target_height=None):
164
+ command = [
165
+ 'ffmpeg',
166
+ '-y',
167
+ '-ss', str(start_time),
168
+ '-t', str(duration),
169
+ '-i', video_path,
170
+ '-i', audio_path,
171
+ '-c:v', 'copy',
172
+ '-c:a', 'aac',
173
+ '-map', '0:v:0',
174
+ '-map', '1:a:0',
175
+ '-shortest',
176
+ '-strict', 'experimental',
177
+ ]
178
+
179
+ # If a target size is specified, add a scale filter (note: scaling requires re-encoding, so '-c:v copy' above would have to be replaced when this branch is used)
180
+ if target_width is not None and target_height is not None:
181
+ command.extend(['-vf', f'scale={target_width}:{target_height}'])
182
+
183
+ command.append(output_path)
184
+
185
+ try:
186
+ sp.run(command, check=True)
187
+ print(f"Successfully merged audio and video into {output_path}")
188
+ return output_path
189
+ except sp.CalledProcessError as e:
190
+ print(f"Error merging audio and video: {e}")
191
+ return None
192
+
193
+ def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
194
+ if audio_path is None:
195
+ return torch.zeros((2, int(sample_rate * seconds_total)))
196
+ audio_tensor, sr = torchaudio.load(audio_path)
197
+ start_index = int(sample_rate * seconds_start)
198
+ target_length = int(sample_rate * seconds_total)
199
+ end_index = start_index + target_length
200
+ audio_tensor = audio_tensor[:, start_index:end_index]
201
+ if audio_tensor.shape[1] < target_length:
202
+ pad_length = target_length - audio_tensor.shape[1]
203
+ audio_tensor = F.pad(audio_tensor, (pad_length, 0))
204
+ return audio_tensor
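For orientation, the helpers above compose roughly as follows. This is a hypothetical usage sketch, not code from the commit; the file paths, the 44100 Hz sample rate, and the 10-second clip length are assumptions, and it relies on the module-level imports defined earlier in data/utils.py.

# Hypothetical sketch: pair a 10 s video clip with an audio prompt (paths and rates are placeholders).
video = read_video("clip.mp4", seek_time=0.0, duration=10, target_fps=2)   # -> (20, 3, 224, 224) frames
audio = load_and_process_audio("prompt.wav", sample_rate=44100,
                               seconds_start=0, seconds_total=10)          # audio prompt cropped/padded to 10 s of samples
# After generation, mux the synthesized audio back onto the original video via ffmpeg.
merged_path = merge_video_audio("clip.mp4", "generated.wav", "clip_with_audio.mp4",
                                start_time=0, duration=10)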
stable_audio_tools/inference/__init__.py ADDED
File without changes
stable_audio_tools/inference/generation.py ADDED
@@ -0,0 +1,275 @@
1
+ import numpy as np
2
+ import torch
3
+ import typing as tp
4
+ import math
5
+ from torchaudio import transforms as T
6
+
7
+ from .utils import prepare_audio
8
+ from .sampling import sample, sample_k, sample_rf
9
+ from ..data.utils import PadCrop
10
+
11
+ def generate_diffusion_uncond(
12
+ model,
13
+ steps: int = 250,
14
+ batch_size: int = 1,
15
+ sample_size: int = 2097152,
16
+ seed: int = -1,
17
+ device: str = "cuda",
18
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
19
+ init_noise_level: float = 1.0,
20
+ return_latents = False,
21
+ **sampler_kwargs
22
+ ) -> torch.Tensor:
23
+
24
+ # The length of the output in audio samples
25
+ audio_sample_size = sample_size
26
+
27
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
28
+ if model.pretransform is not None:
29
+ sample_size = sample_size // model.pretransform.downsampling_ratio
30
+
31
+ # Seed
32
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
33
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
34
+ # seed = 777
35
+ print(seed)
36
+ torch.manual_seed(seed)
37
+ # Define the initial noise immediately after setting the seed
38
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
39
+
40
+ if init_audio is not None:
41
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
42
+ in_sr, init_audio = init_audio
43
+
44
+ io_channels = model.io_channels
45
+
46
+ # For latent models, set the io_channels to the autoencoder's io_channels
47
+ if model.pretransform is not None:
48
+ io_channels = model.pretransform.io_channels
49
+
50
+ # Prepare the initial audio for use by the model
51
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
52
+
53
+ # For latent models, encode the initial audio into latents
54
+ if model.pretransform is not None:
55
+ init_audio = model.pretransform.encode(init_audio)
56
+
57
+ init_audio = init_audio.repeat(batch_size, 1, 1)
58
+ else:
59
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
60
+ init_audio = None
61
+ init_noise_level = None
62
+
63
+ # Inpainting mask
64
+
65
+ if init_audio is not None:
66
+ # variations
67
+ sampler_kwargs["sigma_max"] = init_noise_level
68
+ mask = None
69
+ else:
70
+ mask = None
71
+
72
+ # Now the generative AI part:
73
+
74
+ diff_objective = model.diffusion_objective
75
+
76
+ if diff_objective == "v":
77
+ # k-diffusion denoising process go!
78
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
79
+ elif diff_objective == "rectified_flow":
80
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
81
+
82
+ # Denoising process done.
83
+ # If this is latent diffusion, decode latents back into audio
84
+ if model.pretransform is not None and not return_latents:
85
+ sampled = model.pretransform.decode(sampled)
86
+
87
+ # Return audio
88
+ return sampled
89
+
90
+
91
+ def generate_diffusion_cond(
92
+ model,
93
+ steps: int = 250,
94
+ cfg_scale=6,
95
+ conditioning: dict = None,
96
+ conditioning_tensors: tp.Optional[dict] = None,
97
+ negative_conditioning: dict = None,
98
+ negative_conditioning_tensors: tp.Optional[dict] = None,
99
+ batch_size: int = 1,
100
+ sample_size: int = 2097152,
101
+ sample_rate: int = 48000,
102
+ seed: int = -1,
103
+ device: str = "cuda",
104
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
105
+ init_noise_level: float = 1.0,
106
+ mask_args: dict = None,
107
+ return_latents = False,
108
+ **sampler_kwargs
109
+ ) -> torch.Tensor:
110
+ """
111
+ Generate audio from a prompt using a diffusion model.
112
+
113
+ Args:
114
+ model: The diffusion model to use for generation.
115
+ steps: The number of diffusion steps to use.
116
+ cfg_scale: Classifier-free guidance scale
117
+ conditioning: A dictionary of conditioning parameters to use for generation.
118
+ conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
119
+ batch_size: The batch size to use for generation.
120
+ sample_size: The length of the audio to generate, in samples.
121
+ sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
122
+ seed: The random seed to use for generation, or -1 to use a random seed.
123
+ device: The device to use for generation.
124
+ init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
125
+ init_noise_level: The noise level to use when generating from an initial audio sample.
126
+ return_latents: Whether to return the latents used for generation instead of the decoded audio.
127
+ **sampler_kwargs: Additional keyword arguments to pass to the sampler.
128
+ """
129
+
130
+ # The length of the output in audio samples
131
+ audio_sample_size = sample_size
132
+
133
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
134
+ if model.pretransform is not None:
135
+ sample_size = sample_size // model.pretransform.downsampling_ratio
136
+
137
+ # Seed
138
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
139
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
140
+ # seed = 777
141
+ # print(seed)
142
+ torch.manual_seed(seed)
143
+ # Define the initial noise immediately after setting the seed
144
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
145
+
146
+ torch.backends.cuda.matmul.allow_tf32 = False
147
+ torch.backends.cudnn.allow_tf32 = False
148
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
149
+ torch.backends.cudnn.benchmark = False
150
+
151
+ # Conditioning
152
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
153
+ if conditioning_tensors is None:
154
+ conditioning_tensors = model.conditioner(conditioning, device)
155
+ conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
156
+
157
+ if negative_conditioning is not None or negative_conditioning_tensors is not None:
158
+
159
+ if negative_conditioning_tensors is None:
160
+ negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
161
+
162
+ negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
163
+ else:
164
+ negative_conditioning_tensors = {}
165
+
166
+ if init_audio is not None:
167
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
168
+ in_sr, init_audio = init_audio
169
+
170
+ io_channels = model.io_channels
171
+
172
+ # For latent models, set the io_channels to the autoencoder's io_channels
173
+ if model.pretransform is not None:
174
+ io_channels = model.pretransform.io_channels
175
+
176
+ # Prepare the initial audio for use by the model
177
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
178
+
179
+ # For latent models, encode the initial audio into latents
180
+ if model.pretransform is not None:
181
+ init_audio = model.pretransform.encode(init_audio)
182
+
183
+ init_audio = init_audio.repeat(batch_size, 1, 1)
184
+ else:
185
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
186
+ init_audio = None
187
+ init_noise_level = None
188
+ mask_args = None
189
+
190
+ # Inpainting mask
191
+ if init_audio is not None and mask_args is not None:
192
+ # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
193
+ # This is helpful for forward and reverse outpainting
194
+ cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
195
+ pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
196
+ pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
197
+ assert pastefrom < pasteto, "Paste From should be less than Paste To"
198
+ croplen = pasteto - pastefrom
199
+ if cropfrom + croplen > sample_size:
200
+ croplen = sample_size - cropfrom
201
+ cropto = cropfrom + croplen
202
+ pasteto = pastefrom + croplen
203
+ cutpaste = init_audio.new_zeros(init_audio.shape)
204
+ cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
205
+ #print(cropfrom, cropto, pastefrom, pasteto)
206
+ init_audio = cutpaste
207
+ # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
208
+ mask = build_mask(sample_size, mask_args)
209
+ mask = mask.to(device)
210
+ elif init_audio is not None and mask_args is None:
211
+ # variations
212
+ sampler_kwargs["sigma_max"] = init_noise_level
213
+ mask = None
214
+ else:
215
+ mask = None
216
+
217
+ model_dtype = next(model.model.parameters()).dtype
218
+ noise = noise.type(model_dtype)
219
+ conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
220
+ # Now the generative AI part:
221
+ # k-diffusion denoising process go!
222
+
223
+ diff_objective = model.diffusion_objective
224
+
225
+ if diff_objective == "v":
226
+ # k-diffusion denoising process go!
227
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
228
+
229
+ elif diff_objective == "rectified_flow":
230
+
231
+ if "sigma_min" in sampler_kwargs:
232
+ del sampler_kwargs["sigma_min"]
233
+
234
+ if "sampler_type" in sampler_kwargs:
235
+ del sampler_kwargs["sampler_type"]
236
+
237
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
238
+
239
+ # v-diffusion:
240
+ del noise
241
+ del conditioning_tensors
242
+ del conditioning_inputs
243
+ torch.cuda.empty_cache()
244
+ # Denoising process done.
245
+ # If this is latent diffusion, decode latents back into audio
246
+
247
+ if model.pretransform is not None and not return_latents:
248
+ #cast sampled latents to pretransform dtype
249
+ sampled = sampled.to(next(model.pretransform.parameters()).dtype)
250
+ sampled = model.pretransform.decode(sampled)
251
+
252
+ return sampled
253
+
254
+ # builds a softmask given the parameters
255
+ # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
256
+ # and anything between is a mixture of old/new
257
+ # ideally 0.5 is half/half mixture but i haven't figured this out yet
258
+ def build_mask(sample_size, mask_args):
259
+ maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
260
+ maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
261
+ softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
262
+ softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
263
+ marination = mask_args["marination"]
264
+ # use hann windows for softening the transition (i don't know if this is correct)
265
+ hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
266
+ hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
267
+ # build the mask.
268
+ mask = torch.zeros((sample_size))
269
+ mask[maskstart:maskend] = 1
270
+ mask[maskstart:maskstart+softnessL] = hannL
271
+ mask[maskend-softnessR:maskend] = hannR
272
+ # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
273
+ if marination > 0:
274
+ mask = mask * (1-marination)
275
+ return mask
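A minimal call sketch for generate_diffusion_cond (hypothetical, not part of the commit). The loaded model object, the conditioning keys, and the concrete values are assumptions; which keys are required depends on the conditioner configuration, and the video/audio-conditioned demo in interface/gradio.py passes additional video_prompt/audio_prompt entries.

# Hypothetical sketch, assuming `model` is an already-loaded diffusion_cond model.
conditioning = [{
    "text_prompt": "Ocean waves crashing",   # assumed key names
    "seconds_start": 0,
    "seconds_total": 10,
}]
audio = generate_diffusion_cond(
    model,
    conditioning=conditioning,
    steps=100,
    cfg_scale=7.0,
    sample_size=model.sample_rate * 10,      # output length in audio samples
    seed=-1,                                 # -1 -> random seed
    device="cuda",
    sampler_type="dpmpp-3m-sde",             # forwarded to sample_k via **sampler_kwargs
    sigma_min=0.03,
    sigma_max=500,
)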
stable_audio_tools/inference/sampling.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ import math
3
+ from tqdm import trange, tqdm
4
+
5
+ import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+
24
+ @torch.no_grad()
25
+ def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
26
+ """Draws samples from a model given starting noise. Euler method"""
27
+
28
+ # Make tensor of ones to broadcast the single t values
29
+ ts = x.new_ones([x.shape[0]])
30
+
31
+ # Create the noise schedule
32
+ t = torch.linspace(sigma_max, 0, steps + 1)
33
+
34
+ #alphas, sigmas = 1-t, t
35
+
36
+ for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
37
+ # Broadcast the current timestep to the correct shape
38
+ t_curr_tensor = t_curr * torch.ones(
39
+ (x.shape[0],), dtype=x.dtype, device=x.device
40
+ )
41
+ dt = t_prev - t_curr # we solve backwards in our formulation
42
+ x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
43
+
44
+ # If we are on the last timestep, output the denoised image
45
+ return x
46
+
47
+ @torch.no_grad()
48
+ def sample(model, x, steps, eta, **extra_args):
49
+ """Draws samples from a model given starting noise. v-diffusion"""
50
+ ts = x.new_ones([x.shape[0]])
51
+
52
+ # Create the noise schedule
53
+ t = torch.linspace(1, 0, steps + 1)[:-1]
54
+
55
+ alphas, sigmas = get_alphas_sigmas(t)
56
+
57
+ # The sampling loop
58
+ for i in trange(steps):
59
+
60
+ # Get the model output (v, the predicted velocity)
61
+ with torch.cuda.amp.autocast():
62
+ v = model(x, ts * t[i], **extra_args).float()
63
+
64
+ # Predict the noise and the denoised image
65
+ pred = x * alphas[i] - v * sigmas[i]
66
+ eps = x * sigmas[i] + v * alphas[i]
67
+
68
+ # If we are not on the last timestep, compute the noisy image for the
69
+ # next timestep.
70
+ if i < steps - 1:
71
+ # If eta > 0, adjust the scaling factor for the predicted noise
72
+ # downward according to the amount of additional noise to add
73
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
74
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
75
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
76
+
77
+ # Recombine the predicted noise and predicted denoised image in the
78
+ # correct proportions for the next step
79
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
80
+
81
+ # Add the correct amount of fresh noise
82
+ if eta:
83
+ x += torch.randn_like(x) * ddim_sigma
84
+
85
+ # If we are on the last timestep, output the denoised image
86
+ return pred
87
+
88
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
89
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
90
+ def get_bmask(i, steps, mask):
91
+ strength = (i+1)/(steps)
92
+ # convert to binary mask
93
+ bmask = torch.where(mask<=strength,1,0)
94
+ return bmask
95
+
96
+ def make_cond_model_fn(model, cond_fn):
97
+ def cond_model_fn(x, sigma, **kwargs):
98
+ with torch.enable_grad():
99
+ x = x.detach().requires_grad_()
100
+ denoised = model(x, sigma, **kwargs)
101
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
102
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
103
+ return cond_denoised
104
+ return cond_model_fn
105
+
106
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
107
+ # init_data is init_audio as latents (if this is latent diffusion)
108
+ # For sampling, set both init_data and mask to None
109
+ # For variations, set init_data
110
+ # For inpainting, set both init_data & mask
111
+ def sample_k(
112
+ model_fn,
113
+ noise,
114
+ init_data=None,
115
+ mask=None,
116
+ steps=100,
117
+ sampler_type="dpmpp-2m-sde",
118
+ sigma_min=0.5,
119
+ sigma_max=50,
120
+ rho=1.0, device="cuda",
121
+ callback=None,
122
+ cond_fn=None,
123
+ **extra_args
124
+ ):
125
+
126
+ denoiser = K.external.VDenoiser(model_fn)
127
+
128
+ if cond_fn is not None:
129
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
130
+
131
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
132
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
133
+ # Scale the initial noise by sigma
134
+ noise = noise * sigmas[0]
135
+
136
+ wrapped_callback = callback
137
+
138
+
139
+ if mask is None and init_data is not None:
140
+ # VARIATION (no inpainting)
141
+ # set the initial latent to the init_data, and noise it with initial sigma
142
+
143
+ x = init_data + noise
144
+
145
+ elif mask is not None and init_data is not None:
146
+ # INPAINTING
147
+ bmask = get_bmask(0, steps, mask)
148
+ # initial noising
149
+ input_noised = init_data + noise
150
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
151
+ x = input_noised * bmask + noise * (1-bmask)
152
+ # define the inpainting callback function (Note: side effects, it mutates x)
153
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
154
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
155
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
156
+ def inpainting_callback(args):
157
+ i = args["i"]
158
+ x = args["x"]
159
+ sigma = args["sigma"]
160
+ #denoised = args["denoised"]
161
+ # noise the init_data input with this step's appropriate amount of noise
162
+ input_noised = init_data + torch.randn_like(init_data) * sigma
163
+ # shrinking hard mask
164
+ bmask = get_bmask(i, steps, mask)
165
+ # mix input_noise with x, using binary mask
166
+ new_x = input_noised * bmask + x * (1-bmask)
167
+ # mutate x
168
+ x[:,:,:] = new_x[:,:,:]
169
+ # wrap together the inpainting callback and the user-submitted callback.
170
+ if callback is None:
171
+ wrapped_callback = inpainting_callback
172
+ else:
173
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
174
+ else:
175
+ # SAMPLING
176
+ # set the initial latent to noise
177
+ x = noise
178
+ # x = noise
179
+
180
+ with torch.cuda.amp.autocast():
181
+ if sampler_type == "k-heun":
182
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
183
+ elif sampler_type == "k-lms":
184
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
185
+ elif sampler_type == "k-dpmpp-2s-ancestral":
186
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
187
+ elif sampler_type == "k-dpm-2":
188
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
189
+ elif sampler_type == "k-dpm-fast":
190
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
191
+ elif sampler_type == "k-dpm-adaptive":
192
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
193
+ elif sampler_type == "dpmpp-2m-sde":
194
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
195
+ elif sampler_type == "dpmpp-3m-sde":
196
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
197
+
198
+ # Uses discrete Euler sampling for rectified flow models
199
+ # init_data is init_audio as latents (if this is latent diffusion)
200
+ # For sampling, set both init_data and mask to None
201
+ # For variations, set init_data
202
+ # For inpainting, set both init_data & mask
203
+ def sample_rf(
204
+ model_fn,
205
+ noise,
206
+ init_data=None,
207
+ steps=100,
208
+ sigma_max=1,
209
+ device="cuda",
210
+ callback=None,
211
+ cond_fn=None,
212
+ **extra_args
213
+ ):
214
+
215
+ if sigma_max > 1:
216
+ sigma_max = 1
217
+
218
+ if cond_fn is not None:
219
+ model_fn = make_cond_model_fn(model_fn, cond_fn)
220
+
221
+ wrapped_callback = callback
222
+
223
+ if init_data is not None:
224
+ # VARIATION (no inpainting)
225
+ # Interpolate the init data and the noise for init audio
226
+ x = init_data * (1 - sigma_max) + noise * sigma_max
227
+ else:
228
+ # SAMPLING
229
+ # set the initial latent to noise
230
+ x = noise
231
+
232
+ with torch.cuda.amp.autocast():
233
+ # TODO: Add callback support
234
+ #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
235
+ return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
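As a sanity check on the trigonometric schedule defined at the top of this file (illustrative sketch only): alpha = cos(t·π/2) and sigma = sin(t·π/2), so alpha² + sigma² = 1 for every t, and alpha_sigma_to_t inverts t_to_alpha_sigma on [0, 1].

import torch
t = torch.linspace(0, 1, 5)
alpha, sigma = t_to_alpha_sigma(t)
assert torch.allclose(alpha**2 + sigma**2, torch.ones_like(t))        # points on the unit circle
assert torch.allclose(alpha_sigma_to_t(alpha, sigma), t, atol=1e-6)   # round trip back to t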
stable_audio_tools/inference/utils.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..data.utils import PadCrop
2
+
3
+ from torchaudio import transforms as T
4
+
5
+ def set_audio_channels(audio, target_channels):
6
+ if target_channels == 1:
7
+ # Convert to mono
8
+ audio = audio.mean(1, keepdim=True)
9
+ elif target_channels == 2:
10
+ # Convert to stereo
11
+ if audio.shape[1] == 1:
12
+ audio = audio.repeat(1, 2, 1)
13
+ elif audio.shape[1] > 2:
14
+ audio = audio[:, :2, :]
15
+ return audio
16
+
17
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
18
+
19
+ audio = audio.to(device)
20
+
21
+ if in_sr != target_sr:
22
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
23
+ audio = resample_tf(audio)
24
+
25
+ audio = PadCrop(target_length, randomize=False)(audio)
26
+
27
+ # Add batch dimension
28
+ if audio.dim() == 1:
29
+ audio = audio.unsqueeze(0).unsqueeze(0)
30
+ elif audio.dim() == 2:
31
+ audio = audio.unsqueeze(0)
32
+
33
+ audio = set_audio_channels(audio, target_channels)
34
+
35
+ return audio
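A hypothetical example of what prepare_audio does to a mono clip destined for a stereo model (sketch only; the 22050→44100 Hz rates and 10-second target are assumptions, and it presumes PadCrop pads or crops the last dimension to target_length).

import torch
mono = torch.randn(1, 22050 * 4)   # 4 s of mono audio at 22.05 kHz, shape (channels, samples)
prepped = prepare_audio(
    mono, in_sr=22050, target_sr=44100,
    target_length=44100 * 10, target_channels=2, device="cpu",
)
# prepped: shape (1, 2, 441000) -- batched, duplicated to stereo, padded/cropped to 10 s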
stable_audio_tools/interface/__init__.py ADDED
File without changes
stable_audio_tools/interface/gradio.py ADDED
@@ -0,0 +1,495 @@
1
+ import gc
2
+ import platform
3
+ import os
4
+ import subprocess as sp
5
+ import gradio as gr
6
+ import json
7
+ import torch
8
+ import torchaudio
9
+
10
+ from aeiou.viz import audio_spectrogram_image
11
+ from einops import rearrange
12
+ from safetensors.torch import load_file
13
+ from torch.nn import functional as F
14
+ from torchaudio import transforms as T
15
+
16
+ from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
17
+ from ..models.factory import create_model_from_config
18
+ from ..models.pretrained import get_pretrained_model
19
+ from ..models.utils import load_ckpt_state_dict
20
+ from ..inference.utils import prepare_audio
21
+ from ..training.utils import copy_state_dict
22
+ from ..data.utils import read_video, merge_video_audio
23
+
24
+
25
+ import os
26
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
+
28
+ import warnings
29
+ warnings.filterwarnings("ignore", category=UserWarning)
30
+
31
+
32
+ device = torch.device("cpu")
33
+
34
+ os.environ['TMPDIR'] = './tmp'
35
+
36
+ current_model_name = None
37
+ current_model = None
38
+ current_sample_rate = None
39
+ current_sample_size = None
40
+
41
+
42
+
43
+ def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
44
+ global model_configurations
45
+
46
+ if pretrained_name is not None:
47
+ print(f"Loading pretrained model {pretrained_name}")
48
+ model, model_config = get_pretrained_model(pretrained_name)
49
+ elif model_config is not None and model_ckpt_path is not None:
50
+ print(f"Creating model from config")
51
+ model = create_model_from_config(model_config)
52
+ print(f"Loading model checkpoint from {model_ckpt_path}")
53
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
54
+ sample_rate = model_config["sample_rate"]
55
+ sample_size = model_config["sample_size"]
56
+ if pretransform_ckpt_path is not None:
57
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
58
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
59
+ print(f"Done loading pretransform")
60
+ model.to(device).eval().requires_grad_(False)
61
+ if model_half:
62
+ model.to(torch.float16)
63
+ print(f"Done loading model")
64
+ return model, model_config, sample_rate, sample_size
65
+
66
+ def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
67
+ if audio_path is None:
68
+ return torch.zeros((2, int(sample_rate * seconds_total)))
69
+ audio_tensor, sr = torchaudio.load(audio_path)
70
+ start_index = int(sample_rate * seconds_start)
71
+ target_length = int(sample_rate * seconds_total)
72
+ end_index = start_index + target_length
73
+ audio_tensor = audio_tensor[:, start_index:end_index]
74
+ if audio_tensor.shape[1] < target_length:
75
+ pad_length = target_length - audio_tensor.shape[1]
76
+ audio_tensor = F.pad(audio_tensor, (pad_length, 0))
77
+ return audio_tensor
78
+
79
+ def generate_cond(
80
+ prompt,
81
+ negative_prompt=None,
82
+ video_file=None,
83
+ video_path=None,
84
+ audio_prompt_file=None,
85
+ audio_prompt_path=None,
86
+ seconds_start=0,
87
+ seconds_total=10,
88
+ cfg_scale=6.0,
89
+ steps=250,
90
+ preview_every=None,
91
+ seed=-1,
92
+ sampler_type="dpmpp-3m-sde",
93
+ sigma_min=0.03,
94
+ sigma_max=1000,
95
+ cfg_rescale=0.0,
96
+ use_init=False,
97
+ init_audio=None,
98
+ init_noise_level=1.0,
99
+ mask_cropfrom=None,
100
+ mask_pastefrom=None,
101
+ mask_pasteto=None,
102
+ mask_maskstart=None,
103
+ mask_maskend=None,
104
+ mask_softnessL=None,
105
+ mask_softnessR=None,
106
+ mask_marination=None,
107
+ batch_size=1
108
+ ):
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
111
+ gc.collect()
112
+ print(f"Prompt: {prompt}")
113
+ preview_images = []
114
+ if preview_every == 0:
115
+ preview_every = None
116
+
117
+ try:
118
+ has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
119
+ except Exception:
120
+ has_mps = False
121
+ if has_mps:
122
+ device = torch.device("mps")
123
+ elif torch.cuda.is_available():
124
+ device = torch.device("cuda")
125
+ else:
126
+ device = torch.device("cpu")
127
+ model_name = 'default'
128
+ cfg = model_configurations[model_name]
129
+ model_config_path = cfg.get("model_config")
130
+ ckpt_path = cfg.get("ckpt_path")
131
+ pretrained_name = cfg.get("pretrained_name")
132
+ pretransform_ckpt_path = cfg.get("pretransform_ckpt_path")
133
+ model_type = cfg.get("model_type", "diffusion_cond")
134
+ if model_config_path:
135
+ with open(model_config_path) as f:
136
+ model_config = json.load(f)
137
+ else:
138
+ model_config = None
139
+ target_fps = model_config.get("video_fps", 5) if model_config else 5
140
+ global current_model_name, current_model, current_sample_rate, current_sample_size
141
+ if current_model is None or model_name != current_model_name:
142
+ current_model, model_config, sample_rate, sample_size = load_model(
143
+ model_name=model_name,
144
+ model_config=model_config,
145
+ model_ckpt_path=ckpt_path,
146
+ pretrained_name=pretrained_name,
147
+ pretransform_ckpt_path=pretransform_ckpt_path,
148
+ device=device,
149
+ model_half=False
150
+ )
151
+ current_model_name = model_name
152
+ model = current_model
153
+ current_sample_rate = sample_rate
154
+ current_sample_size = sample_size
155
+ else:
156
+ model = current_model
157
+ sample_rate = current_sample_rate
158
+ sample_size = current_sample_size
159
+ if video_file is not None:
160
+ video_path = video_file.name
161
+ elif video_path:
162
+ video_path = video_path.strip()
163
+ else:
164
+ video_path = None
165
+
166
+ if audio_prompt_file is not None:
167
+ print(f'audio_prompt_file: {audio_prompt_file}')
168
+ audio_path = audio_prompt_file.name
169
+ elif audio_prompt_path:
170
+ audio_path = audio_prompt_path.strip()
171
+ else:
172
+ audio_path = None
173
+
174
+ Video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
175
+ audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
176
+
177
+ audio_tensor = audio_tensor.to(device)
178
+ seconds_input = sample_size / sample_rate
179
+ print(f'video_path: {video_path}')
180
+
181
+ if not prompt:
182
+ prompt = ""
183
+
184
+ conditioning = [{
185
+ "video_prompt": [Video_tensors.unsqueeze(0)],
186
+ "text_prompt": prompt,
187
+ "audio_prompt": audio_tensor.unsqueeze(0),
188
+ "seconds_start": seconds_start,
189
+ "seconds_total": seconds_input
190
+ }] * batch_size
191
+ if negative_prompt:
192
+ negative_conditioning = [{
193
+ "video_prompt": [Video_tensors.unsqueeze(0)],
194
+ "text_prompt": negative_prompt,
195
+ "audio_prompt": audio_tensor.unsqueeze(0),
196
+ "seconds_start": seconds_start,
197
+ "seconds_total": seconds_total
198
+ }] * batch_size
199
+ else:
200
+ negative_conditioning = None
201
+ try:
202
+ device = next(model.parameters()).device
203
+ except Exception as e:
204
+ device = next(current_model.parameters()).device
205
+ seed = int(seed)
206
+ if not use_init:
207
+ init_audio = None
208
+ input_sample_size = sample_size
209
+ if init_audio is not None:
210
+ in_sr, init_audio = init_audio
211
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
212
+ if init_audio.dim() == 1:
213
+ init_audio = init_audio.unsqueeze(0)
214
+ elif init_audio.dim() == 2:
215
+ init_audio = init_audio.transpose(0, 1)
216
+ if in_sr != sample_rate:
217
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
218
+ init_audio = resample_tf(init_audio)
219
+ audio_length = init_audio.shape[-1]
220
+ if audio_length > sample_size:
221
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
222
+ init_audio = (sample_rate, init_audio)
223
+ def progress_callback(callback_info):
224
+ nonlocal preview_images
225
+ denoised = callback_info["denoised"]
226
+ current_step = callback_info["i"]
227
+ sigma = callback_info["sigma"]
228
+ if (current_step - 1) % preview_every == 0:
229
+ if model.pretransform is not None:
230
+ denoised = model.pretransform.decode(denoised)
231
+ denoised = rearrange(denoised, "b d n -> d (b n)")
232
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
233
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
234
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
235
+ if mask_cropfrom is not None:
236
+ mask_args = {
237
+ "cropfrom": mask_cropfrom,
238
+ "pastefrom": mask_pastefrom,
239
+ "pasteto": mask_pasteto,
240
+ "maskstart": mask_maskstart,
241
+ "maskend": mask_maskend,
242
+ "softnessL": mask_softnessL,
243
+ "softnessR": mask_softnessR,
244
+ "marination": mask_marination,
245
+ }
246
+ else:
247
+ mask_args = None
248
+ if model_type == "diffusion_cond":
249
+ audio = generate_diffusion_cond(
250
+ model,
251
+ conditioning=conditioning,
252
+ negative_conditioning=negative_conditioning,
253
+ steps=steps,
254
+ cfg_scale=cfg_scale,
255
+ batch_size=batch_size,
256
+ sample_size=input_sample_size,
257
+ sample_rate=sample_rate,
258
+ seed=seed,
259
+ device=device,
260
+ sampler_type=sampler_type,
261
+ sigma_min=sigma_min,
262
+ sigma_max=sigma_max,
263
+ init_audio=init_audio,
264
+ init_noise_level=init_noise_level,
265
+ mask_args=mask_args,
266
+ callback=progress_callback if preview_every is not None else None,
267
+ scale_phi=cfg_rescale
268
+ )
269
+ elif model_type == "diffusion_uncond":
270
+ audio = generate_diffusion_uncond(
271
+ model,
272
+ steps=steps,
273
+ batch_size=batch_size,
274
+ sample_size=input_sample_size,
275
+ seed=seed,
276
+ device=device,
277
+ sampler_type=sampler_type,
278
+ sigma_min=sigma_min,
279
+ sigma_max=sigma_max,
280
+ init_audio=init_audio,
281
+ init_noise_level=init_noise_level,
282
+ callback=progress_callback if preview_every is not None else None
283
+ )
284
+ else:
285
+ raise ValueError(f"Unsupported model type: {model_type}")
286
+ audio = rearrange(audio, "b d n -> d (b n)")
287
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
288
+ file_name = os.path.basename(video_path) if video_path else "output"
289
+ output_dir = f"demo_result"
290
+ if not os.path.exists(output_dir):
291
+ os.makedirs(output_dir)
292
+ output_video_path = f"{output_dir}/{file_name}"
293
+ torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate)
294
+ if not os.path.exists(output_dir):
295
+ os.makedirs(output_dir)
296
+ if video_path:
297
+ merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total)
298
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
299
+ del video_path
300
+ torch.cuda.empty_cache()
301
+ gc.collect()
302
+ return (output_video_path, f"{output_dir}/output.wav")
303
+
304
+ def toggle_custom_model(selected_model):
305
+ return gr.Row.update(visible=(selected_model == "Custom Model"))
306
+
307
+ def create_sampling_ui(model_config_map, inpainting=False):
308
+ with gr.Blocks() as demo:
309
+ gr.Markdown(
310
+ """
311
+ # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
312
+ **[Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/Zeyue7/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
313
+ """
314
+ )
315
+
316
+ with gr.Tab("Generation"):
317
+
318
+ with gr.Row():
319
+ with gr.Column():
320
+ prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt")
321
+ negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False)
322
+ video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path")
323
+ video_file = gr.File(label="Upload Video File")
324
+ audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False)
325
+ audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False)
326
+ with gr.Row():
327
+ with gr.Column(scale=6):
328
+ with gr.Accordion("Video Params", open=False):
329
+ seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start")
330
+ seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False)
331
+ with gr.Row():
332
+ with gr.Column(scale=4):
333
+ with gr.Accordion("Sampler Params", open=False):
334
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
335
+ preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
336
+ cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale")
337
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
338
+ sampler_type_dropdown = gr.Dropdown(
339
+ ["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"],
340
+ label="Sampler Type",
341
+ value="dpmpp-3m-sde"
342
+ )
343
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min")
344
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max")
345
+ cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount")
346
+ with gr.Row():
347
+ with gr.Column(scale=4):
348
+ with gr.Accordion("Init Audio", open=False, visible=False):
349
+ init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
350
+ init_audio_input = gr.Audio(label="Init Audio")
351
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level")
352
+ gr.Markdown("## Examples")
353
+ with gr.Accordion("Click to show examples", open=False):
354
+ with gr.Row():
355
+ gr.Markdown("**📝 Task: Text-to-Audio**")
356
+ with gr.Column(scale=1.2):
357
+ gr.Markdown("Prompt: *Typing on a keyboard*")
358
+ ex1 = gr.Button("Load Example")
359
+ with gr.Column(scale=1.2):
360
+ gr.Markdown("Prompt: *Ocean waves crashing*")
361
+ ex2 = gr.Button("Load Example")
362
+ with gr.Column(scale=1.2):
363
+ gr.Markdown("Prompt: *Footsteps in snow*")
364
+ ex3 = gr.Button("Load Example")
365
+ with gr.Row():
366
+ gr.Markdown("**🎶 Task: Text-to-Music**")
367
+ with gr.Column(scale=1.2):
368
+ gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
369
+ ex4 = gr.Button("Load Example")
370
+ with gr.Column(scale=1.2):
371
+ gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
372
+ ex5 = gr.Button("Load Example")
373
+ with gr.Column(scale=1.2):
374
+ gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
375
+ ex6 = gr.Button("Load Example")
376
+ with gr.Row():
377
+ gr.Markdown("**🎬 Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*")
378
+ with gr.Column(scale=1.2):
379
+ gr.Video("example/V2A_sample-1.mp4")
380
+ ex7 = gr.Button("Load Example")
381
+ with gr.Column(scale=1.2):
382
+ gr.Video("example/V2A_sample-2.mp4")
383
+ ex8 = gr.Button("Load Example")
384
+ with gr.Column(scale=1.2):
385
+ gr.Video("example/V2A_sample-3.mp4")
386
+ ex9 = gr.Button("Load Example")
387
+ with gr.Row():
388
+ gr.Markdown("**🎵 Task: Video-to-Music**\nPrompt: *Generate music for the video*")
389
+ with gr.Column(scale=1.2):
390
+ gr.Video("example/V2M_sample-1.mp4")
391
+ ex10 = gr.Button("Load Example")
392
+ with gr.Column(scale=1.2):
393
+ gr.Video("example/V2M_sample-2.mp4")
394
+ ex11 = gr.Button("Load Example")
395
+ with gr.Column(scale=1.2):
396
+ gr.Video("example/V2M_sample-3.mp4")
397
+ ex12 = gr.Button("Load Example")
398
+ with gr.Row():
399
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
400
+ with gr.Row():
401
+ with gr.Column(scale=6):
402
+ video_output = gr.Video(label="Output Video", interactive=False)
403
+ audio_output = gr.Audio(label="Output Audio", interactive=False)
404
+ send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False)
405
+ send_to_init_button.click(
406
+ fn=lambda audio: audio,
407
+ inputs=[audio_output],
408
+ outputs=[init_audio_input]
409
+ )
410
+ inputs = [
411
+ prompt,
412
+ negative_prompt,
413
+ video_file,
414
+ video_path,
415
+ audio_prompt_file,
416
+ audio_prompt_path,
417
+ seconds_start_slider,
418
+ seconds_total_slider,
419
+ cfg_scale_slider,
420
+ steps_slider,
421
+ preview_every_slider,
422
+ seed_textbox,
423
+ sampler_type_dropdown,
424
+ sigma_min_slider,
425
+ sigma_max_slider,
426
+ cfg_rescale_slider,
427
+ init_audio_checkbox,
428
+ init_audio_input,
429
+ init_noise_level_slider
430
+ ]
431
+ generate_button.click(
432
+ fn=generate_cond,
433
+ inputs=inputs,
434
+ outputs=[
435
+ video_output,
436
+ audio_output
437
+ ],
438
+ api_name="generate"
439
+ )
440
+ ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
441
+ ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
442
+ ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
443
+ ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
444
+ ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
445
+ ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
446
+ ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
447
+ ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
448
+ ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
449
+ ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
450
+ ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
451
+ ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
452
+ return demo
453
+
454
+ def create_txt2audio_ui(model_config_map):
455
+ with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui:
456
+ with gr.Tab("Generation"):
457
+ create_sampling_ui(model_config_map)
458
+ return ui
459
+
460
+ def toggle_custom_model(selected_model):
461
+ return gr.Row.update(visible=(selected_model == "Custom Model"))
462
+
463
+ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
464
+ global model_configurations
465
+ global device
466
+
467
+ try:
468
+ has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
469
+ except Exception:
470
+ has_mps = False
471
+
472
+ if has_mps:
473
+ device = torch.device("mps")
474
+ elif torch.cuda.is_available():
475
+ device = torch.device("cuda")
476
+ else:
477
+ device = torch.device("cpu")
478
+
479
+ print("Using device:", device)
480
+
481
+ model_configurations = {
482
+ "default": {
483
+ "model_config": "./model/config.json",
484
+ "ckpt_path": "./model/model.ckpt"
485
+ }
486
+ }
487
+ ui = create_txt2audio_ui(model_configurations)
488
+ return ui
489
+
490
+ if __name__ == "__main__":
491
+ ui = create_ui(
492
+ model_config_path='./model/config.json',
493
+ # share is a launch() option, not a create_ui() parameter; passed below
494
+ )
495
+ ui.launch(share=True)
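Note that in the version shown here create_ui() never reads its keyword arguments: the "default" entry is hard-coded to ./model/config.json and ./model/model.ckpt. A minimal launch sketch therefore looks like this (hypothetical; the host and port are assumptions):

from stable_audio_tools.interface.gradio import create_ui

ui = create_ui()                                     # keyword arguments are currently ignored
ui.launch(server_name="0.0.0.0", server_port=7860)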
stable_audio_tools/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_model_from_config, create_model_from_config_path
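A hypothetical sketch of the factory entry point re-exported here, mirroring how load_model in interface/gradio.py builds a model from a JSON config (the config path is a placeholder):

import json
from stable_audio_tools.models import create_model_from_config

with open("model/config.json") as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)   # weights are loaded separately from a checkpoint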
stable_audio_tools/models/adp.py ADDED
@@ -0,0 +1,1588 @@
1
+ # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
+ # License can be found in LICENSES/LICENSE_ADP.txt
3
+
4
+ import math
5
+ from inspect import isfunction
6
+ from math import ceil, floor, log, pi, log2
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
+ from packaging import version
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import rearrange_many
15
+ from torch import Tensor, einsum
16
+ from torch.backends.cuda import sdp_kernel
17
+ from torch.nn import functional as F
18
+ from dac.nn.layers import Snake1d
19
+
20
+ """
21
+ Utils
22
+ """
23
+
24
+
25
+ class ConditionedSequential(nn.Module):
26
+ def __init__(self, *modules):
27
+ super().__init__()
28
+ self.module_list = nn.ModuleList(*modules)
29
+
30
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
+ for module in self.module_list:
32
+ x = module(x, mapping)
33
+ return x
34
+
35
+ T = TypeVar("T")
36
+
37
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val: Optional[T]) -> T:
43
+ return val is not None
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
52
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
53
+ for key in d.keys():
54
+ no_prefix = int(not key.startswith(prefix))
55
+ return_dicts[no_prefix][key] = d[key]
56
+ return return_dicts
57
+
58
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
59
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
60
+ if keep_prefix:
61
+ return kwargs_with_prefix, kwargs
62
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
63
+ return kwargs_no_prefix, kwargs
64
+
65
+ """
66
+ Convolutional Blocks
67
+ """
68
+ import typing as tp
69
+
70
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
71
+ # License available in LICENSES/LICENSE_META.txt
72
+
73
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
74
+ padding_total: int = 0) -> int:
75
+ """See `pad_for_conv1d`."""
76
+ length = x.shape[-1]
77
+ n_frames = (length - kernel_size + padding_total) / stride + 1
78
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
79
+ return ideal_length - length
80
+
81
+
82
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
83
+ """Pad for a convolution to make sure that the last window is full.
84
+ Extra padding is added at the end. This is required to ensure that we can rebuild
85
+ an output of the same length, as otherwise, even with padding, some time steps
86
+ might get removed.
87
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
88
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
89
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
90
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
91
+ 1 2 3 4 # once you removed padding, we are missing one time step !
92
+ """
93
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
94
+ return F.pad(x, (0, extra_padding))
95
+
96
+
97
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
98
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
99
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
100
+ """
101
+ length = x.shape[-1]
102
+ padding_left, padding_right = paddings
103
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
104
+ if mode == 'reflect':
105
+ max_pad = max(padding_left, padding_right)
106
+ extra_pad = 0
107
+ if length <= max_pad:
108
+ extra_pad = max_pad - length + 1
109
+ x = F.pad(x, (0, extra_pad))
110
+ padded = F.pad(x, paddings, mode, value)
111
+ end = padded.shape[-1] - extra_pad
112
+ return padded[..., :end]
113
+ else:
114
+ return F.pad(x, paddings, mode, value)
115
+
116
+
117
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
118
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
119
+ padding_left, padding_right = paddings
120
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
121
+ assert (padding_left + padding_right) <= x.shape[-1]
122
+ end = x.shape[-1] - padding_right
123
+ return x[..., padding_left: end]
124
+
125
+
126
+ class Conv1d(nn.Conv1d):
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+
130
+ def forward(self, x: Tensor, causal=False) -> Tensor:
131
+ kernel_size = self.kernel_size[0]
132
+ stride = self.stride[0]
133
+ dilation = self.dilation[0]
134
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
135
+ padding_total = kernel_size - stride
136
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
137
+ if causal:
138
+ # Left padding for causal
139
+ x = pad1d(x, (padding_total, extra_padding))
140
+ else:
141
+ # Asymmetric padding required for odd strides
142
+ padding_right = padding_total // 2
143
+ padding_left = padding_total - padding_right
144
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
145
+ return super().forward(x)
146
+
147
+ class ConvTranspose1d(nn.ConvTranspose1d):
148
+ def __init__(self, *args, **kwargs):
149
+ super().__init__(*args, **kwargs)
150
+
151
+ def forward(self, x: Tensor, causal=False) -> Tensor:
152
+ kernel_size = self.kernel_size[0]
153
+ stride = self.stride[0]
154
+ padding_total = kernel_size - stride
155
+
156
+ y = super().forward(x)
157
+
158
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
159
+ # removed at the very end, when keeping only the right length for the output,
160
+ # as removing it here would require also passing the length at the matching layer
161
+ # in the encoder.
162
+ if causal:
163
+ padding_right = ceil(padding_total)
164
+ padding_left = padding_total - padding_right
165
+ y = unpad1d(y, (padding_left, padding_right))
166
+ else:
167
+ # Asymmetric padding required for odd strides
168
+ padding_right = padding_total // 2
169
+ padding_left = padding_total - padding_right
170
+ y = unpad1d(y, (padding_left, padding_right))
171
+ return y
172
+
173
+
174
+ def Downsample1d(
175
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
176
+ ) -> nn.Module:
177
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
178
+
179
+ return Conv1d(
180
+ in_channels=in_channels,
181
+ out_channels=out_channels,
182
+ kernel_size=factor * kernel_multiplier + 1,
183
+ stride=factor
184
+ )
185
+
186
+
187
+ def Upsample1d(
188
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
189
+ ) -> nn.Module:
190
+
191
+ if factor == 1:
192
+ return Conv1d(
193
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3
194
+ )
195
+
196
+ if use_nearest:
197
+ return nn.Sequential(
198
+ nn.Upsample(scale_factor=factor, mode="nearest"),
199
+ Conv1d(
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ kernel_size=3
203
+ ),
204
+ )
205
+ else:
206
+ return ConvTranspose1d(
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ kernel_size=factor * 2,
210
+ stride=factor
211
+ )
212
+
213
+
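# Illustrative shape sketch, not part of the commit: downsampling by a factor of 4
# and then upsampling by 4 restores the original temporal resolution.
import torch
from stable_audio_tools.models.adp import Downsample1d, Upsample1d

down = Downsample1d(in_channels=8, out_channels=16, factor=4)
up = Upsample1d(in_channels=16, out_channels=8, factor=4)
x = torch.randn(2, 8, 64)
z = down(x)                        # (2, 16, 16): conv with kernel 4*2+1, stride 4
assert up(z).shape == (2, 8, 64)   # transposed conv with kernel 8, stride 4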
214
+ class ConvBlock1d(nn.Module):
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ out_channels: int,
219
+ *,
220
+ kernel_size: int = 3,
221
+ stride: int = 1,
222
+ dilation: int = 1,
223
+ num_groups: int = 8,
224
+ use_norm: bool = True,
225
+ use_snake: bool = False
226
+ ) -> None:
227
+ super().__init__()
228
+
229
+ self.groupnorm = (
230
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
231
+ if use_norm
232
+ else nn.Identity()
233
+ )
234
+
235
+ if use_snake:
236
+ self.activation = Snake1d(in_channels)
237
+ else:
238
+ self.activation = nn.SiLU()
239
+
240
+ self.project = Conv1d(
241
+ in_channels=in_channels,
242
+ out_channels=out_channels,
243
+ kernel_size=kernel_size,
244
+ stride=stride,
245
+ dilation=dilation,
246
+ )
247
+
248
+ def forward(
249
+ self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
250
+ ) -> Tensor:
251
+ x = self.groupnorm(x)
252
+ if exists(scale_shift):
253
+ scale, shift = scale_shift
254
+ x = x * (scale + 1) + shift
255
+ x = self.activation(x)
256
+ return self.project(x, causal=causal)
257
+
258
+
259
+ class MappingToScaleShift(nn.Module):
260
+ def __init__(
261
+ self,
262
+ features: int,
263
+ channels: int,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.to_scale_shift = nn.Sequential(
268
+ nn.SiLU(),
269
+ nn.Linear(in_features=features, out_features=channels * 2),
270
+ )
271
+
272
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
273
+ scale_shift = self.to_scale_shift(mapping)
274
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
275
+ scale, shift = scale_shift.chunk(2, dim=1)
276
+ return scale, shift
277
+
278
+
279
+ class ResnetBlock1d(nn.Module):
280
+ def __init__(
281
+ self,
282
+ in_channels: int,
283
+ out_channels: int,
284
+ *,
285
+ kernel_size: int = 3,
286
+ stride: int = 1,
287
+ dilation: int = 1,
288
+ use_norm: bool = True,
289
+ use_snake: bool = False,
290
+ num_groups: int = 8,
291
+ context_mapping_features: Optional[int] = None,
292
+ ) -> None:
293
+ super().__init__()
294
+
295
+ self.use_mapping = exists(context_mapping_features)
296
+
297
+ self.block1 = ConvBlock1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=kernel_size,
301
+ stride=stride,
302
+ dilation=dilation,
303
+ use_norm=use_norm,
304
+ num_groups=num_groups,
305
+ use_snake=use_snake
306
+ )
307
+
308
+ if self.use_mapping:
309
+ assert exists(context_mapping_features)
310
+ self.to_scale_shift = MappingToScaleShift(
311
+ features=context_mapping_features, channels=out_channels
312
+ )
313
+
314
+ self.block2 = ConvBlock1d(
315
+ in_channels=out_channels,
316
+ out_channels=out_channels,
317
+ use_norm=use_norm,
318
+ num_groups=num_groups,
319
+ use_snake=use_snake
320
+ )
321
+
322
+ self.to_out = (
323
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
324
+ if in_channels != out_channels
325
+ else nn.Identity()
326
+ )
327
+
328
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
329
+ assert_message = "context mapping required if context_mapping_features > 0"
330
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
331
+
332
+ h = self.block1(x, causal=causal)
333
+
334
+ scale_shift = None
335
+ if self.use_mapping:
336
+ scale_shift = self.to_scale_shift(mapping)
337
+
338
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
339
+
340
+ return h + self.to_out(x)
341
+
342
+
343
+ class Patcher(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ patch_size: int,
349
+ context_mapping_features: Optional[int] = None,
350
+ use_snake: bool = False,
351
+ ):
352
+ super().__init__()
353
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354
+ assert out_channels % patch_size == 0, assert_message
355
+ self.patch_size = patch_size
356
+
357
+ self.block = ResnetBlock1d(
358
+ in_channels=in_channels,
359
+ out_channels=out_channels // patch_size,
360
+ num_groups=1,
361
+ context_mapping_features=context_mapping_features,
362
+ use_snake=use_snake
363
+ )
364
+
365
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
366
+ x = self.block(x, mapping, causal=causal)
367
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
368
+ return x
369
+
370
+
371
+ class Unpatcher(nn.Module):
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ patch_size: int,
377
+ context_mapping_features: Optional[int] = None,
378
+ use_snake: bool = False
379
+ ):
380
+ super().__init__()
381
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382
+ assert in_channels % patch_size == 0, assert_message
383
+ self.patch_size = patch_size
384
+
385
+ self.block = ResnetBlock1d(
386
+ in_channels=in_channels // patch_size,
387
+ out_channels=out_channels,
388
+ num_groups=1,
389
+ context_mapping_features=context_mapping_features,
390
+ use_snake=use_snake
391
+ )
392
+
393
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
394
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
395
+ x = self.block(x, mapping, causal=causal)
396
+ return x
397
+
398
+
399
+ """
400
+ Attention Components
401
+ """
402
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
403
+ mid_features = features * multiplier
404
+ return nn.Sequential(
405
+ nn.Linear(in_features=features, out_features=mid_features),
406
+ nn.GELU(),
407
+ nn.Linear(in_features=mid_features, out_features=features),
408
+ )
409
+
410
+ def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
411
+ b, ndim = sim.shape[0], mask.ndim
412
+ if ndim == 3:
413
+ mask = rearrange(mask, "b n m -> b 1 n m")
414
+ if ndim == 2:
415
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
416
+ max_neg_value = -torch.finfo(sim.dtype).max
417
+ sim = sim.masked_fill(~mask, max_neg_value)
418
+ return sim
419
+
420
+ def causal_mask(q: Tensor, k: Tensor) -> Tensor:
421
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
422
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
423
+ mask = repeat(mask, "n m -> b n m", b=b)
424
+ return mask
425
+
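# Illustrative note: for equal query/key lengths the helper above yields a
# lower-triangular (diagonal included) boolean mask per batch item.
import torch
from stable_audio_tools.models.adp import causal_mask

q = torch.zeros(1, 4, 8)
k = torch.zeros(1, 4, 8)
mask = causal_mask(q, k)
assert mask.shape == (1, 4, 4)
assert torch.equal(mask[0], torch.ones(4, 4).tril().bool())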
426
+ class AttentionBase(nn.Module):
427
+ def __init__(
428
+ self,
429
+ features: int,
430
+ *,
431
+ head_features: int,
432
+ num_heads: int,
433
+ out_features: Optional[int] = None,
434
+ ):
435
+ super().__init__()
436
+ self.scale = head_features**-0.5
437
+ self.num_heads = num_heads
438
+ mid_features = head_features * num_heads
439
+ out_features = default(out_features, features)
440
+
441
+ self.to_out = nn.Linear(
442
+ in_features=mid_features, out_features=out_features
443
+ )
444
+
445
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
446
+
447
+ if not self.use_flash:
448
+ return
449
+
450
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
451
+
452
+ if device_properties.major == 8 and device_properties.minor == 0:
453
+ # Use flash attention for A100 GPUs
454
+ self.sdp_kernel_config = (True, False, False)
455
+ else:
456
+ # Don't use flash attention for other GPUs
457
+ self.sdp_kernel_config = (False, True, True)
458
+
459
+ def forward(
460
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
461
+ ) -> Tensor:
462
+ # Split heads
463
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
464
+
465
+ if not self.use_flash:
466
+ if is_causal and not exists(mask): # avoid ambiguous truthiness when a mask tensor is passed together with is_causal
467
+ # Mask out future tokens for causal attention
468
+ mask = causal_mask(q, k)
469
+
470
+ # Compute similarity matrix and apply the mask if one is provided
471
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
472
+ sim = add_mask(sim, mask) if exists(mask) else sim
473
+
474
+ # Get attention matrix with softmax
475
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
476
+
477
+ # Compute values
478
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
479
+ else:
480
+ with sdp_kernel(*self.sdp_kernel_config):
481
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
482
+
483
+ out = rearrange(out, "b h n d -> b n (h d)")
484
+ return self.to_out(out)
485
+
486
+ class Attention(nn.Module):
487
+ def __init__(
488
+ self,
489
+ features: int,
490
+ *,
491
+ head_features: int,
492
+ num_heads: int,
493
+ out_features: Optional[int] = None,
494
+ context_features: Optional[int] = None,
495
+ causal: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.context_features = context_features
499
+ self.causal = causal
500
+ mid_features = head_features * num_heads
501
+ context_features = default(context_features, features)
502
+
503
+ self.norm = nn.LayerNorm(features)
504
+ self.norm_context = nn.LayerNorm(context_features)
505
+ self.to_q = nn.Linear(
506
+ in_features=features, out_features=mid_features, bias=False
507
+ )
508
+ self.to_kv = nn.Linear(
509
+ in_features=context_features, out_features=mid_features * 2, bias=False
510
+ )
511
+ self.attention = AttentionBase(
512
+ features,
513
+ num_heads=num_heads,
514
+ head_features=head_features,
515
+ out_features=out_features,
516
+ )
517
+
518
+ def forward(
519
+ self,
520
+ x: Tensor, # [b, n, c]
521
+ context: Optional[Tensor] = None, # [b, m, d]
522
+ context_mask: Optional[Tensor] = None, # [b, m], False means masked out
523
+ causal: Optional[bool] = False,
524
+ ) -> Tensor:
525
+ assert_message = "You must provide a context when using context_features"
526
+ assert not self.context_features or exists(context), assert_message
527
+ # Use context if provided
528
+ context = default(context, x)
529
+ # Normalize then compute q from input and k,v from context
530
+ x, context = self.norm(x), self.norm_context(context)
531
+
532
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
533
+
534
+ if exists(context_mask):
535
+ # Mask out cross-attention for padding tokens
536
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
537
+ k, v = k * mask, v * mask
538
+
539
+ # Compute and return attention
540
+ return self.attention(q, k, v, is_causal=self.causal or causal)
541
+
542
+
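# Usage sketch with illustrative hyperparameters: the module runs self-attention
# when no context is passed and cross-attention against the context otherwise.
import torch
from stable_audio_tools.models.adp import Attention

attn = Attention(features=64, head_features=16, num_heads=4, context_features=32)
x = torch.randn(2, 10, 64)     # queries come from x
ctx = torch.randn(2, 7, 32)    # keys/values come from the context sequence
assert attn(x, context=ctx).shape == (2, 10, 64)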
543
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
544
+ mid_features = features * multiplier
545
+ return nn.Sequential(
546
+ nn.Linear(in_features=features, out_features=mid_features),
547
+ nn.GELU(),
548
+ nn.Linear(in_features=mid_features, out_features=features),
549
+ )
550
+
551
+ """
552
+ Transformer Blocks
553
+ """
554
+
555
+
556
+ class TransformerBlock(nn.Module):
557
+ def __init__(
558
+ self,
559
+ features: int,
560
+ num_heads: int,
561
+ head_features: int,
562
+ multiplier: int,
563
+ context_features: Optional[int] = None,
564
+ ):
565
+ super().__init__()
566
+
567
+ self.use_cross_attention = exists(context_features) and context_features > 0
568
+
569
+ self.attention = Attention(
570
+ features=features,
571
+ num_heads=num_heads,
572
+ head_features=head_features
573
+ )
574
+
575
+ if self.use_cross_attention:
576
+ self.cross_attention = Attention(
577
+ features=features,
578
+ num_heads=num_heads,
579
+ head_features=head_features,
580
+ context_features=context_features
581
+ )
582
+
583
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
584
+
585
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
586
+ x = self.attention(x, causal=causal) + x
587
+ if self.use_cross_attention:
588
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
589
+ x = self.feed_forward(x) + x
590
+ return x
591
+
592
+
593
+ """
594
+ Transformers
595
+ """
596
+
597
+
598
+ class Transformer1d(nn.Module):
599
+ def __init__(
600
+ self,
601
+ num_layers: int,
602
+ channels: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ context_features: Optional[int] = None,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.to_in = nn.Sequential(
611
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
612
+ Conv1d(
613
+ in_channels=channels,
614
+ out_channels=channels,
615
+ kernel_size=1,
616
+ ),
617
+ Rearrange("b c t -> b t c"),
618
+ )
619
+
620
+ self.blocks = nn.ModuleList(
621
+ [
622
+ TransformerBlock(
623
+ features=channels,
624
+ head_features=head_features,
625
+ num_heads=num_heads,
626
+ multiplier=multiplier,
627
+ context_features=context_features,
628
+ )
629
+ for i in range(num_layers)
630
+ ]
631
+ )
632
+
633
+ self.to_out = nn.Sequential(
634
+ Rearrange("b t c -> b c t"),
635
+ Conv1d(
636
+ in_channels=channels,
637
+ out_channels=channels,
638
+ kernel_size=1,
639
+ ),
640
+ )
641
+
642
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
643
+ x = self.to_in(x)
644
+ for block in self.blocks:
645
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
646
+ x = self.to_out(x)
647
+ return x
648
+
649
+
650
+ """
651
+ Time Embeddings
652
+ """
653
+
654
+
655
+ class SinusoidalEmbedding(nn.Module):
656
+ def __init__(self, dim: int):
657
+ super().__init__()
658
+ self.dim = dim
659
+
660
+ def forward(self, x: Tensor) -> Tensor:
661
+ device, half_dim = x.device, self.dim // 2
662
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
663
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
664
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
665
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
666
+
667
+
668
+ class LearnedPositionalEmbedding(nn.Module):
669
+ """Used for continuous time"""
670
+
671
+ def __init__(self, dim: int):
672
+ super().__init__()
673
+ assert (dim % 2) == 0
674
+ half_dim = dim // 2
675
+ self.weights = nn.Parameter(torch.randn(half_dim))
676
+
677
+ def forward(self, x: Tensor) -> Tensor:
678
+ x = rearrange(x, "b -> b 1")
679
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
680
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
681
+ fouriered = torch.cat((x, fouriered), dim=-1)
682
+ return fouriered
683
+
684
+
685
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
686
+ return nn.Sequential(
687
+ LearnedPositionalEmbedding(dim),
688
+ nn.Linear(in_features=dim + 1, out_features=out_features),
689
+ )
690
+
691
+
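# Illustrative note: the learned Fourier features map a batch of scalar diffusion
# timesteps to the mapping width consumed by the UNet blocks below. The sizes here
# are examples only.
import torch
from stable_audio_tools.models.adp import TimePositionalEmbedding

t = torch.rand(4)                                   # one timestep per batch item
to_time = TimePositionalEmbedding(dim=128, out_features=512)
assert to_time(t).shape == (4, 512)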
692
+ """
693
+ Encoder/Decoder Components
694
+ """
695
+
696
+
697
+ class DownsampleBlock1d(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ *,
703
+ factor: int,
704
+ num_groups: int,
705
+ num_layers: int,
706
+ kernel_multiplier: int = 2,
707
+ use_pre_downsample: bool = True,
708
+ use_skip: bool = False,
709
+ use_snake: bool = False,
710
+ extract_channels: int = 0,
711
+ context_channels: int = 0,
712
+ num_transformer_blocks: int = 0,
713
+ attention_heads: Optional[int] = None,
714
+ attention_features: Optional[int] = None,
715
+ attention_multiplier: Optional[int] = None,
716
+ context_mapping_features: Optional[int] = None,
717
+ context_embedding_features: Optional[int] = None,
718
+ ):
719
+ super().__init__()
720
+ self.use_pre_downsample = use_pre_downsample
721
+ self.use_skip = use_skip
722
+ self.use_transformer = num_transformer_blocks > 0
723
+ self.use_extract = extract_channels > 0
724
+ self.use_context = context_channels > 0
725
+
726
+ channels = out_channels if use_pre_downsample else in_channels
727
+
728
+ self.downsample = Downsample1d(
729
+ in_channels=in_channels,
730
+ out_channels=out_channels,
731
+ factor=factor,
732
+ kernel_multiplier=kernel_multiplier,
733
+ )
734
+
735
+ self.blocks = nn.ModuleList(
736
+ [
737
+ ResnetBlock1d(
738
+ in_channels=channels + context_channels if i == 0 else channels,
739
+ out_channels=channels,
740
+ num_groups=num_groups,
741
+ context_mapping_features=context_mapping_features,
742
+ use_snake=use_snake
743
+ )
744
+ for i in range(num_layers)
745
+ ]
746
+ )
747
+
748
+ if self.use_transformer:
749
+ assert (
750
+ (exists(attention_heads) or exists(attention_features))
751
+ and exists(attention_multiplier)
752
+ )
753
+
754
+ if attention_features is None and attention_heads is not None:
755
+ attention_features = channels // attention_heads
756
+
757
+ if attention_heads is None and attention_features is not None:
758
+ attention_heads = channels // attention_features
759
+
760
+ self.transformer = Transformer1d(
761
+ num_layers=num_transformer_blocks,
762
+ channels=channels,
763
+ num_heads=attention_heads,
764
+ head_features=attention_features,
765
+ multiplier=attention_multiplier,
766
+ context_features=context_embedding_features
767
+ )
768
+
769
+ if self.use_extract:
770
+ num_extract_groups = min(num_groups, extract_channels)
771
+ self.to_extracted = ResnetBlock1d(
772
+ in_channels=out_channels,
773
+ out_channels=extract_channels,
774
+ num_groups=num_extract_groups,
775
+ use_snake=use_snake
776
+ )
777
+
778
+ def forward(
779
+ self,
780
+ x: Tensor,
781
+ *,
782
+ mapping: Optional[Tensor] = None,
783
+ channels: Optional[Tensor] = None,
784
+ embedding: Optional[Tensor] = None,
785
+ embedding_mask: Optional[Tensor] = None,
786
+ causal: Optional[bool] = False
787
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
788
+
789
+ if self.use_pre_downsample:
790
+ x = self.downsample(x)
791
+
792
+ if self.use_context and exists(channels):
793
+ x = torch.cat([x, channels], dim=1)
794
+
795
+ skips = []
796
+ for block in self.blocks:
797
+ x = block(x, mapping=mapping, causal=causal)
798
+ skips += [x] if self.use_skip else []
799
+
800
+ if self.use_transformer:
801
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
802
+ skips += [x] if self.use_skip else []
803
+
804
+ if not self.use_pre_downsample:
805
+ x = self.downsample(x)
806
+
807
+ if self.use_extract:
808
+ extracted = self.to_extracted(x)
809
+ return x, extracted
810
+
811
+ return (x, skips) if self.use_skip else x
812
+
813
+
814
+ class UpsampleBlock1d(nn.Module):
815
+ def __init__(
816
+ self,
817
+ in_channels: int,
818
+ out_channels: int,
819
+ *,
820
+ factor: int,
821
+ num_layers: int,
822
+ num_groups: int,
823
+ use_nearest: bool = False,
824
+ use_pre_upsample: bool = False,
825
+ use_skip: bool = False,
826
+ use_snake: bool = False,
827
+ skip_channels: int = 0,
828
+ use_skip_scale: bool = False,
829
+ extract_channels: int = 0,
830
+ num_transformer_blocks: int = 0,
831
+ attention_heads: Optional[int] = None,
832
+ attention_features: Optional[int] = None,
833
+ attention_multiplier: Optional[int] = None,
834
+ context_mapping_features: Optional[int] = None,
835
+ context_embedding_features: Optional[int] = None,
836
+ ):
837
+ super().__init__()
838
+
839
+ self.use_extract = extract_channels > 0
840
+ self.use_pre_upsample = use_pre_upsample
841
+ self.use_transformer = num_transformer_blocks > 0
842
+ self.use_skip = use_skip
843
+ self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
844
+
845
+ channels = out_channels if use_pre_upsample else in_channels
846
+
847
+ self.blocks = nn.ModuleList(
848
+ [
849
+ ResnetBlock1d(
850
+ in_channels=channels + skip_channels,
851
+ out_channels=channels,
852
+ num_groups=num_groups,
853
+ context_mapping_features=context_mapping_features,
854
+ use_snake=use_snake
855
+ )
856
+ for _ in range(num_layers)
857
+ ]
858
+ )
859
+
860
+ if self.use_transformer:
861
+ assert (
862
+ (exists(attention_heads) or exists(attention_features))
863
+ and exists(attention_multiplier)
864
+ )
865
+
866
+ if attention_features is None and attention_heads is not None:
867
+ attention_features = channels // attention_heads
868
+
869
+ if attention_heads is None and attention_features is not None:
870
+ attention_heads = channels // attention_features
871
+
872
+ self.transformer = Transformer1d(
873
+ num_layers=num_transformer_blocks,
874
+ channels=channels,
875
+ num_heads=attention_heads,
876
+ head_features=attention_features,
877
+ multiplier=attention_multiplier,
878
+ context_features=context_embedding_features,
879
+ )
880
+
881
+ self.upsample = Upsample1d(
882
+ in_channels=in_channels,
883
+ out_channels=out_channels,
884
+ factor=factor,
885
+ use_nearest=use_nearest,
886
+ )
887
+
888
+ if self.use_extract:
889
+ num_extract_groups = min(num_groups, extract_channels)
890
+ self.to_extracted = ResnetBlock1d(
891
+ in_channels=out_channels,
892
+ out_channels=extract_channels,
893
+ num_groups=num_extract_groups,
894
+ use_snake=use_snake
895
+ )
896
+
897
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
898
+ return torch.cat([x, skip * self.skip_scale], dim=1)
899
+
900
+ def forward(
901
+ self,
902
+ x: Tensor,
903
+ *,
904
+ skips: Optional[List[Tensor]] = None,
905
+ mapping: Optional[Tensor] = None,
906
+ embedding: Optional[Tensor] = None,
907
+ embedding_mask: Optional[Tensor] = None,
908
+ causal: Optional[bool] = False
909
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
910
+
911
+ if self.use_pre_upsample:
912
+ x = self.upsample(x)
913
+
914
+ for block in self.blocks:
915
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
916
+ x = block(x, mapping=mapping, causal=causal)
917
+
918
+ if self.use_transformer:
919
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
920
+
921
+ if not self.use_pre_upsample:
922
+ x = self.upsample(x)
923
+
924
+ if self.use_extract:
925
+ extracted = self.to_extracted(x)
926
+ return x, extracted
927
+
928
+ return x
929
+
930
+
931
+ class BottleneckBlock1d(nn.Module):
932
+ def __init__(
933
+ self,
934
+ channels: int,
935
+ *,
936
+ num_groups: int,
937
+ num_transformer_blocks: int = 0,
938
+ attention_heads: Optional[int] = None,
939
+ attention_features: Optional[int] = None,
940
+ attention_multiplier: Optional[int] = None,
941
+ context_mapping_features: Optional[int] = None,
942
+ context_embedding_features: Optional[int] = None,
943
+ use_snake: bool = False,
944
+ ):
945
+ super().__init__()
946
+ self.use_transformer = num_transformer_blocks > 0
947
+
948
+ self.pre_block = ResnetBlock1d(
949
+ in_channels=channels,
950
+ out_channels=channels,
951
+ num_groups=num_groups,
952
+ context_mapping_features=context_mapping_features,
953
+ use_snake=use_snake
954
+ )
955
+
956
+ if self.use_transformer:
957
+ assert (
958
+ (exists(attention_heads) or exists(attention_features))
959
+ and exists(attention_multiplier)
960
+ )
961
+
962
+ if attention_features is None and attention_heads is not None:
963
+ attention_features = channels // attention_heads
964
+
965
+ if attention_heads is None and attention_features is not None:
966
+ attention_heads = channels // attention_features
967
+
968
+ self.transformer = Transformer1d(
969
+ num_layers=num_transformer_blocks,
970
+ channels=channels,
971
+ num_heads=attention_heads,
972
+ head_features=attention_features,
973
+ multiplier=attention_multiplier,
974
+ context_features=context_embedding_features,
975
+ )
976
+
977
+ self.post_block = ResnetBlock1d(
978
+ in_channels=channels,
979
+ out_channels=channels,
980
+ num_groups=num_groups,
981
+ context_mapping_features=context_mapping_features,
982
+ use_snake=use_snake
983
+ )
984
+
985
+ def forward(
986
+ self,
987
+ x: Tensor,
988
+ *,
989
+ mapping: Optional[Tensor] = None,
990
+ embedding: Optional[Tensor] = None,
991
+ embedding_mask: Optional[Tensor] = None,
992
+ causal: Optional[bool] = False
993
+ ) -> Tensor:
994
+ x = self.pre_block(x, mapping=mapping, causal=causal)
995
+ if self.use_transformer:
996
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
997
+ x = self.post_block(x, mapping=mapping, causal=causal)
998
+ return x
999
+
1000
+
1001
+ """
1002
+ UNet
1003
+ """
1004
+
1005
+
1006
+ class UNet1d(nn.Module):
1007
+ def __init__(
1008
+ self,
1009
+ in_channels: int,
1010
+ channels: int,
1011
+ multipliers: Sequence[int],
1012
+ factors: Sequence[int],
1013
+ num_blocks: Sequence[int],
1014
+ attentions: Sequence[int],
1015
+ patch_size: int = 1,
1016
+ resnet_groups: int = 8,
1017
+ use_context_time: bool = True,
1018
+ kernel_multiplier_downsample: int = 2,
1019
+ use_nearest_upsample: bool = False,
1020
+ use_skip_scale: bool = True,
1021
+ use_snake: bool = False,
1022
+ use_stft: bool = False,
1023
+ use_stft_context: bool = False,
1024
+ out_channels: Optional[int] = None,
1025
+ context_features: Optional[int] = None,
1026
+ context_features_multiplier: int = 4,
1027
+ context_channels: Optional[Sequence[int]] = None,
1028
+ context_embedding_features: Optional[int] = None,
1029
+ **kwargs,
1030
+ ):
1031
+ super().__init__()
1032
+ out_channels = default(out_channels, in_channels)
1033
+ context_channels = list(default(context_channels, []))
1034
+ num_layers = len(multipliers) - 1
1035
+ use_context_features = exists(context_features)
1036
+ use_context_channels = len(context_channels) > 0
1037
+ context_mapping_features = None
1038
+
1039
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1040
+
1041
+ self.num_layers = num_layers
1042
+ self.use_context_time = use_context_time
1043
+ self.use_context_features = use_context_features
1044
+ self.use_context_channels = use_context_channels
1045
+ self.use_stft = use_stft
1046
+ self.use_stft_context = use_stft_context
1047
+
1048
+ self.context_features = context_features
1049
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
1050
+ context_channels = context_channels + [0] * context_channels_pad_length
1051
+ self.context_channels = context_channels
1052
+ self.context_embedding_features = context_embedding_features
1053
+
1054
+ if use_context_channels:
1055
+ has_context = [c > 0 for c in context_channels]
1056
+ self.has_context = has_context
1057
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1058
+
1059
+ assert (
1060
+ len(factors) == num_layers
1061
+ and len(attentions) >= num_layers
1062
+ and len(num_blocks) == num_layers
1063
+ )
1064
+
1065
+ if use_context_time or use_context_features:
1066
+ context_mapping_features = channels * context_features_multiplier
1067
+
1068
+ self.to_mapping = nn.Sequential(
1069
+ nn.Linear(context_mapping_features, context_mapping_features),
1070
+ nn.GELU(),
1071
+ nn.Linear(context_mapping_features, context_mapping_features),
1072
+ nn.GELU(),
1073
+ )
1074
+
1075
+ if use_context_time:
1076
+ assert exists(context_mapping_features)
1077
+ self.to_time = nn.Sequential(
1078
+ TimePositionalEmbedding(
1079
+ dim=channels, out_features=context_mapping_features
1080
+ ),
1081
+ nn.GELU(),
1082
+ )
1083
+
1084
+ if use_context_features:
1085
+ assert exists(context_features) and exists(context_mapping_features)
1086
+ self.to_features = nn.Sequential(
1087
+ nn.Linear(
1088
+ in_features=context_features, out_features=context_mapping_features
1089
+ ),
1090
+ nn.GELU(),
1091
+ )
1092
+
1093
+ if use_stft:
1094
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
1095
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097
+ in_channels *= stft_channels
1098
+ out_channels *= stft_channels
1099
+ context_channels[0] *= stft_channels if use_stft_context else 1
1100
+ assert exists(in_channels) and exists(out_channels)
1101
+ self.stft = STFT(**stft_kwargs)
1102
+
1103
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
+
1105
+ self.to_in = Patcher(
1106
+ in_channels=in_channels + context_channels[0],
1107
+ out_channels=channels * multipliers[0],
1108
+ patch_size=patch_size,
1109
+ context_mapping_features=context_mapping_features,
1110
+ use_snake=use_snake
1111
+ )
1112
+
1113
+ self.downsamples = nn.ModuleList(
1114
+ [
1115
+ DownsampleBlock1d(
1116
+ in_channels=channels * multipliers[i],
1117
+ out_channels=channels * multipliers[i + 1],
1118
+ context_mapping_features=context_mapping_features,
1119
+ context_channels=context_channels[i + 1],
1120
+ context_embedding_features=context_embedding_features,
1121
+ num_layers=num_blocks[i],
1122
+ factor=factors[i],
1123
+ kernel_multiplier=kernel_multiplier_downsample,
1124
+ num_groups=resnet_groups,
1125
+ use_pre_downsample=True,
1126
+ use_skip=True,
1127
+ use_snake=use_snake,
1128
+ num_transformer_blocks=attentions[i],
1129
+ **attention_kwargs,
1130
+ )
1131
+ for i in range(num_layers)
1132
+ ]
1133
+ )
1134
+
1135
+ self.bottleneck = BottleneckBlock1d(
1136
+ channels=channels * multipliers[-1],
1137
+ context_mapping_features=context_mapping_features,
1138
+ context_embedding_features=context_embedding_features,
1139
+ num_groups=resnet_groups,
1140
+ num_transformer_blocks=attentions[-1],
1141
+ use_snake=use_snake,
1142
+ **attention_kwargs,
1143
+ )
1144
+
1145
+ self.upsamples = nn.ModuleList(
1146
+ [
1147
+ UpsampleBlock1d(
1148
+ in_channels=channels * multipliers[i + 1],
1149
+ out_channels=channels * multipliers[i],
1150
+ context_mapping_features=context_mapping_features,
1151
+ context_embedding_features=context_embedding_features,
1152
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1153
+ factor=factors[i],
1154
+ use_nearest=use_nearest_upsample,
1155
+ num_groups=resnet_groups,
1156
+ use_skip_scale=use_skip_scale,
1157
+ use_pre_upsample=False,
1158
+ use_skip=True,
1159
+ use_snake=use_snake,
1160
+ skip_channels=channels * multipliers[i + 1],
1161
+ num_transformer_blocks=attentions[i],
1162
+ **attention_kwargs,
1163
+ )
1164
+ for i in reversed(range(num_layers))
1165
+ ]
1166
+ )
1167
+
1168
+ self.to_out = Unpatcher(
1169
+ in_channels=channels * multipliers[0],
1170
+ out_channels=out_channels,
1171
+ patch_size=patch_size,
1172
+ context_mapping_features=context_mapping_features,
1173
+ use_snake=use_snake
1174
+ )
1175
+
1176
+ def get_channels(
1177
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1178
+ ) -> Optional[Tensor]:
1179
+ """Gets context channels at `layer` and checks that shape is correct"""
1180
+ use_context_channels = self.use_context_channels and self.has_context[layer]
1181
+ if not use_context_channels:
1182
+ return None
1183
+ assert exists(channels_list), "Missing context"
1184
+ # Get channels index (skipping zero channel contexts)
1185
+ channels_id = self.channels_ids[layer]
1186
+ # Get channels
1187
+ channels = channels_list[channels_id]
1188
+ message = f"Missing context for layer {layer} at index {channels_id}"
1189
+ assert exists(channels), message
1190
+ # Check channels
1191
+ num_channels = self.context_channels[layer]
1192
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193
+ assert channels.shape[1] == num_channels, message
1194
+ # STFT channels if requested
1195
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196
+ return channels
1197
+
1198
+ def get_mapping(
1199
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1200
+ ) -> Optional[Tensor]:
1201
+ """Combines context time features and features into mapping"""
1202
+ items, mapping = [], None
1203
+ # Compute time features
1204
+ if self.use_context_time:
1205
+ assert_message = "use_context_time=True but no time features provided"
1206
+ assert exists(time), assert_message
1207
+ items += [self.to_time(time)]
1208
+ # Compute features
1209
+ if self.use_context_features:
1210
+ assert_message = "context_features exists but no features provided"
1211
+ assert exists(features), assert_message
1212
+ items += [self.to_features(features)]
1213
+ # Compute joint mapping
1214
+ if self.use_context_time or self.use_context_features:
1215
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1216
+ mapping = self.to_mapping(mapping)
1217
+ return mapping
1218
+
1219
+ def forward(
1220
+ self,
1221
+ x: Tensor,
1222
+ time: Optional[Tensor] = None,
1223
+ *,
1224
+ features: Optional[Tensor] = None,
1225
+ channels_list: Optional[Sequence[Tensor]] = None,
1226
+ embedding: Optional[Tensor] = None,
1227
+ embedding_mask: Optional[Tensor] = None,
1228
+ causal: Optional[bool] = False,
1229
+ ) -> Tensor:
1230
+ channels = self.get_channels(channels_list, layer=0)
1231
+ # Apply stft if required
1232
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1233
+ # Concat context channels at layer 0 if provided
1234
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
1235
+ # Compute mapping from time and features
1236
+ mapping = self.get_mapping(time, features)
1237
+ x = self.to_in(x, mapping, causal=causal)
1238
+ skips_list = [x]
1239
+
1240
+ for i, downsample in enumerate(self.downsamples):
1241
+ channels = self.get_channels(channels_list, layer=i + 1)
1242
+ x, skips = downsample(
1243
+ x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
1244
+ )
1245
+ skips_list += [skips]
1246
+
1247
+ x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1248
+
1249
+ for i, upsample in enumerate(self.upsamples):
1250
+ skips = skips_list.pop()
1251
+ x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1252
+
1253
+ x += skips_list.pop()
1254
+ x = self.to_out(x, mapping, causal=causal)
1255
+ x = self.stft.decode1d(x) if self.use_stft else x
1256
+
1257
+ return x
1258
+
1259
+
1260
+ """ Conditioning Modules """
1261
+
1262
+
1263
+ class FixedEmbedding(nn.Module):
1264
+ def __init__(self, max_length: int, features: int):
1265
+ super().__init__()
1266
+ self.max_length = max_length
1267
+ self.embedding = nn.Embedding(max_length, features)
1268
+
1269
+ def forward(self, x: Tensor) -> Tensor:
1270
+ batch_size, length, device = *x.shape[0:2], x.device
1271
+ assert_message = "Input sequence length must be <= max_length"
1272
+ assert length <= self.max_length, assert_message
1273
+ position = torch.arange(length, device=device)
1274
+ fixed_embedding = self.embedding(position)
1275
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1276
+ return fixed_embedding
1277
+
1278
+
1279
+ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1280
+ if proba == 1:
1281
+ return torch.ones(shape, device=device, dtype=torch.bool)
1282
+ elif proba == 0:
1283
+ return torch.zeros(shape, device=device, dtype=torch.bool)
1284
+ else:
1285
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1286
+
1287
+
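# Illustrative sketch of how the two helpers above combine for classifier-free
# guidance training: a random subset of the batch swaps its conditioning for the
# learned "unconditional" fixed embedding. Shapes are examples only.
import torch
from stable_audio_tools.models.adp import FixedEmbedding, rand_bool

emb = torch.randn(8, 12, 64)                          # (batch, tokens, features)
fixed = FixedEmbedding(max_length=64, features=64)(emb)
drop = rand_bool(shape=(8, 1, 1), proba=0.1, device=emb.device)
emb = torch.where(drop, fixed, emb)                   # per-item conditioning dropout
assert emb.shape == (8, 12, 64)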
1288
+ class UNetCFG1d(UNet1d):
1289
+
1290
+ """UNet1d with Classifier-Free Guidance"""
1291
+
1292
+ def __init__(
1293
+ self,
1294
+ context_embedding_max_length: int,
1295
+ context_embedding_features: int,
1296
+ use_xattn_time: bool = False,
1297
+ **kwargs,
1298
+ ):
1299
+ super().__init__(
1300
+ context_embedding_features=context_embedding_features, **kwargs
1301
+ )
1302
+
1303
+ self.use_xattn_time = use_xattn_time
1304
+
1305
+ if use_xattn_time:
1306
+ assert exists(context_embedding_features)
1307
+ self.to_time_embedding = nn.Sequential(
1308
+ TimePositionalEmbedding(
1309
+ dim=kwargs["channels"], out_features=context_embedding_features
1310
+ ),
1311
+ nn.GELU(),
1312
+ )
1313
+
1314
+ context_embedding_max_length += 1 # Add one for time embedding
1315
+
1316
+ self.fixed_embedding = FixedEmbedding(
1317
+ max_length=context_embedding_max_length, features=context_embedding_features
1318
+ )
1319
+
1320
+ def forward( # type: ignore
1321
+ self,
1322
+ x: Tensor,
1323
+ time: Tensor,
1324
+ *,
1325
+ embedding: Tensor,
1326
+ embedding_mask: Optional[Tensor] = None,
1327
+ embedding_scale: float = 1.0,
1328
+ embedding_mask_proba: float = 0.0,
1329
+ batch_cfg: bool = False,
1330
+ rescale_cfg: bool = False,
1331
+ scale_phi: float = 0.4,
1332
+ negative_embedding: Optional[Tensor] = None,
1333
+ negative_embedding_mask: Optional[Tensor] = None,
1334
+ **kwargs,
1335
+ ) -> Tensor:
1336
+ b, device = embedding.shape[0], embedding.device
1337
+
1338
+ if self.use_xattn_time:
1339
+ embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
1340
+
1341
+ if embedding_mask is not None:
1342
+ embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
1343
+
1344
+ fixed_embedding = self.fixed_embedding(embedding)
1345
+
1346
+ if embedding_mask_proba > 0.0:
1347
+ # Randomly mask embedding
1348
+ batch_mask = rand_bool(
1349
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1350
+ )
1351
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
1352
+
1353
+ if embedding_scale != 1.0:
1354
+ if batch_cfg:
1355
+ batch_x = torch.cat([x, x], dim=0)
1356
+ batch_time = torch.cat([time, time], dim=0)
1357
+
1358
+ if negative_embedding is not None:
1359
+ if negative_embedding_mask is not None:
1360
+ negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
1361
+
1362
+ negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
1363
+
1364
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1365
+
1366
+ else:
1367
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1368
+
1369
+ batch_mask = None
1370
+ if embedding_mask is not None:
1371
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1372
+
1373
+ batch_features = None
1374
+ features = kwargs.pop("features", None)
1375
+ if self.use_context_features:
1376
+ batch_features = torch.cat([features, features], dim=0)
1377
+
1378
+ batch_channels = None
1379
+ channels_list = kwargs.pop("channels_list", None)
1380
+ if self.use_context_channels:
1381
+ batch_channels = []
1382
+ for channels in channels_list:
1383
+ batch_channels += [torch.cat([channels, channels], dim=0)]
1384
+
1385
+ # Compute both normal and fixed embedding outputs
1386
+ batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
1387
+ out, out_masked = batch_out.chunk(2, dim=0)
1388
+
1389
+ else:
1390
+ # Compute both normal and fixed embedding outputs
1391
+ out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1392
+ out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
1393
+
1394
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
1395
+
1396
+ if rescale_cfg:
1397
+
1398
+ out_std = out.std(dim=1, keepdim=True)
1399
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1400
+
1401
+ return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
1402
+
1403
+ else:
1404
+
1405
+ return out_cfg
1406
+
1407
+ else:
1408
+ return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1409
+
1410
+
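# Standalone sketch of the guidance blend implemented in UNetCFG1d.forward above:
# out = uncond + scale * (cond - uncond), optionally rescaled towards the
# conditional output's per-channel std via scale_phi. Not tied to the model itself.
import torch

def cfg_blend(cond, uncond, scale, rescale=False, phi=0.4):
    out = uncond + (cond - uncond) * scale
    if rescale:
        ratio = cond.std(dim=1, keepdim=True) / out.std(dim=1, keepdim=True)
        out = phi * (out * ratio) + (1 - phi) * out
    return out

cond, uncond = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
assert cfg_blend(cond, uncond, scale=3.0).shape == (2, 8, 16)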
1411
+ class UNetNCCA1d(UNet1d):
1412
+
1413
+ """UNet1d with Noise Channel Conditioning Augmentation"""
1414
+
1415
+ def __init__(self, context_features: int, **kwargs):
1416
+ super().__init__(context_features=context_features, **kwargs)
1417
+ self.embedder = NumberEmbedder(features=context_features)
1418
+
1419
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1420
+ x = x if torch.is_tensor(x) else torch.tensor(x)
1421
+ return x.expand(shape)
1422
+
1423
+ def forward( # type: ignore
1424
+ self,
1425
+ x: Tensor,
1426
+ time: Tensor,
1427
+ *,
1428
+ channels_list: Sequence[Tensor],
1429
+ channels_augmentation: Union[
1430
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1431
+ ] = False,
1432
+ channels_scale: Union[
1433
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
1434
+ ] = 0,
1435
+ **kwargs,
1436
+ ) -> Tensor:
1437
+ b, n = x.shape[0], len(channels_list)
1438
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1439
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1440
+
1441
+ # Augmentation (for each channel list item)
1442
+ for i in range(n):
1443
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
1444
+ scale = rearrange(scale, "b -> b 1 1")
1445
+ item = channels_list[i]
1446
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1447
+
1448
+ # Scale embedding (sum reduction if more than one channel list item)
1449
+ channels_scale_emb = self.embedder(channels_scale)
1450
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1451
+
1452
+ return super().forward(
1453
+ x=x,
1454
+ time=time,
1455
+ channels_list=channels_list,
1456
+ features=channels_scale_emb,
1457
+ **kwargs,
1458
+ )
1459
+
1460
+
1461
+ class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1462
+ def __init__(self, *args, **kwargs):
1463
+ super().__init__(*args, **kwargs)
1464
+
1465
+ def forward(self, *args, **kwargs): # type: ignore
1466
+ return UNetCFG1d.forward(self, *args, **kwargs)
1467
+
1468
+
1469
+ def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1470
+ if type == "base":
1471
+ return UNet1d(**kwargs)
1472
+ elif type == "all":
1473
+ return UNetAll1d(**kwargs)
1474
+ elif type == "cfg":
1475
+ return UNetCFG1d(**kwargs)
1476
+ elif type == "ncca":
1477
+ return UNetNCCA1d(**kwargs)
1478
+ else:
1479
+ raise ValueError(f"Unknown XUNet1d type: {type}")
1480
+
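# Hedged config sketch: the factory dispatches on `type`, with "cfg" building the
# classifier-free-guidance variant. Every hyperparameter below is illustrative and
# not taken from the commit's model configs.
from stable_audio_tools.models.adp import XUNet1d

unet = XUNet1d(
    type="cfg",
    in_channels=64,
    channels=128,
    multipliers=[1, 2, 4],
    factors=[2, 2],
    num_blocks=[2, 2],
    attentions=[0, 1, 1],
    attention_heads=8,
    attention_multiplier=2,
    context_embedding_features=768,
    context_embedding_max_length=128,
)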
1481
+ class NumberEmbedder(nn.Module):
1482
+ def __init__(
1483
+ self,
1484
+ features: int,
1485
+ dim: int = 256,
1486
+ ):
1487
+ super().__init__()
1488
+ self.features = features
1489
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1490
+
1491
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1492
+ if not torch.is_tensor(x):
1493
+ device = next(self.embedding.parameters()).device
1494
+ x = torch.tensor(x, device=device)
1495
+ assert isinstance(x, Tensor)
1496
+ shape = x.shape
1497
+ x = rearrange(x, "... -> (...)")
1498
+ embedding = self.embedding(x)
1499
+ x = embedding.view(*shape, self.features)
1500
+ return x # type: ignore
1501
+
1502
+
1503
+ """
1504
+ Audio Transforms
1505
+ """
1506
+
1507
+
1508
+ class STFT(nn.Module):
1509
+ """Helper for torch stft and istft"""
1510
+
1511
+ def __init__(
1512
+ self,
1513
+ num_fft: int = 1023,
1514
+ hop_length: int = 256,
1515
+ window_length: Optional[int] = None,
1516
+ length: Optional[int] = None,
1517
+ use_complex: bool = False,
1518
+ ):
1519
+ super().__init__()
1520
+ self.num_fft = num_fft
1521
+ self.hop_length = default(hop_length, floor(num_fft // 4))
1522
+ self.window_length = default(window_length, num_fft)
1523
+ self.length = length
1524
+ self.register_buffer("window", torch.hann_window(self.window_length))
1525
+ self.use_complex = use_complex
1526
+
1527
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1528
+ b = wave.shape[0]
1529
+ wave = rearrange(wave, "b c t -> (b c) t")
1530
+
1531
+ stft = torch.stft(
1532
+ wave,
1533
+ n_fft=self.num_fft,
1534
+ hop_length=self.hop_length,
1535
+ win_length=self.window_length,
1536
+ window=self.window, # type: ignore
1537
+ return_complex=True,
1538
+ normalized=True,
1539
+ )
1540
+
1541
+ if self.use_complex:
1542
+ # Returns real and imaginary
1543
+ stft_a, stft_b = stft.real, stft.imag
1544
+ else:
1545
+ # Returns magnitude and phase matrices
1546
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
1547
+ stft_a, stft_b = magnitude, phase
1548
+
1549
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1550
+
1551
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1552
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1553
+ length = closest_power_2(l * self.hop_length)
1554
+
1555
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1556
+
1557
+ if self.use_complex:
1558
+ real, imag = stft_a, stft_b
1559
+ else:
1560
+ magnitude, phase = stft_a, stft_b
1561
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1562
+
1563
+ stft = torch.stack([real, imag], dim=-1)
1564
+
1565
+ wave = torch.istft(
1566
+ stft,
1567
+ n_fft=self.num_fft,
1568
+ hop_length=self.hop_length,
1569
+ win_length=self.window_length,
1570
+ window=self.window, # type: ignore
1571
+ length=default(self.length, length),
1572
+ normalized=True,
1573
+ )
1574
+
1575
+ return rearrange(wave, "(b c) t -> b c t", b=b)
1576
+
1577
+ def encode1d(
1578
+ self, wave: Tensor, stacked: bool = True
1579
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1580
+ stft_a, stft_b = self.encode(wave)
1581
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1582
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1583
+
1584
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
1585
+ f = self.num_fft // 2 + 1
1586
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1587
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1588
+ return self.decode(stft_a, stft_b)
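# Illustrative shape note: encode1d stacks the two STFT planes (magnitude and phase
# in the default non-complex mode) along the channel axis, giving
# 2 * C * (num_fft // 2 + 1) channels for a C-channel waveform; decode1d inverts it.
import torch
from stable_audio_tools.models.adp import STFT

stft = STFT(num_fft=1023, hop_length=256)
wave = torch.randn(1, 2, 4096)                   # (batch, channels, samples)
spec = stft.encode1d(wave)
assert spec.shape[1] == 2 * 2 * (1023 // 2 + 1)  # 2048 stacked channels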
stable_audio_tools/models/autoencoders.py ADDED
@@ -0,0 +1,794 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torchaudio import transforms as T
8
+ from alias_free_torch import Activation1d
9
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
10
+ from typing import Literal, Dict, Any
11
+
12
+ from ..inference.sampling import sample
13
+ from ..inference.utils import prepare_audio
14
+ from .blocks import SnakeBeta
15
+ from .bottleneck import Bottleneck, DiscreteBottleneck
16
+ from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
17
+ from .factory import create_pretransform_from_config, create_bottleneck_from_config
18
+ from .pretransforms import Pretransform
19
+
20
+ def checkpoint(function, *args, **kwargs):
21
+ kwargs.setdefault("use_reentrant", False)
22
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
23
+
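# Illustrative usage, not part of the committed file: the wrapper defaults to
# non-reentrant torch.utils.checkpoint, trading recompute in backward for
# activation memory during forward.
import torch
from torch import nn
from stable_audio_tools.models.autoencoders import checkpoint

layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(layer, x)   # same values as layer(x); activations recomputed in backward
y.sum().backward()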
24
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
25
+ if activation == "elu":
26
+ act = nn.ELU()
27
+ elif activation == "snake":
28
+ act = SnakeBeta(channels)
29
+ elif activation == "none":
30
+ act = nn.Identity()
31
+ else:
32
+ raise ValueError(f"Unknown activation {activation}")
33
+
34
+ if antialias:
35
+ act = Activation1d(act)
36
+
37
+ return act
38
+
39
+ class ResidualUnit(nn.Module):
40
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
41
+ super().__init__()
42
+
43
+ self.dilation = dilation
44
+
45
+ padding = (dilation * (7-1)) // 2
46
+
47
+ self.layers = nn.Sequential(
48
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
49
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
50
+ kernel_size=7, dilation=dilation, padding=padding),
51
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
52
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
53
+ kernel_size=1)
54
+ )
55
+
56
+ def forward(self, x):
57
+ res = x
58
+
59
+ #x = checkpoint(self.layers, x)
60
+ x = self.layers(x)
61
+
62
+ return x + res
63
+
64
+ class EncoderBlock(nn.Module):
65
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
66
+ super().__init__()
67
+
68
+ self.layers = nn.Sequential(
69
+ ResidualUnit(in_channels=in_channels,
70
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
71
+ ResidualUnit(in_channels=in_channels,
72
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
73
+ ResidualUnit(in_channels=in_channels,
74
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
75
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
76
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
77
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
78
+ )
79
+
80
+ def forward(self, x):
81
+ return self.layers(x)
82
+
83
+ class DecoderBlock(nn.Module):
84
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
85
+ super().__init__()
86
+
87
+ if use_nearest_upsample:
88
+ upsample_layer = nn.Sequential(
89
+ nn.Upsample(scale_factor=stride, mode="nearest"),
90
+ WNConv1d(in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ kernel_size=2*stride,
93
+ stride=1,
94
+ bias=False,
95
+ padding='same')
96
+ )
97
+ else:
98
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
99
+ out_channels=out_channels,
100
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
101
+
102
+ self.layers = nn.Sequential(
103
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
104
+ upsample_layer,
105
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
106
+ dilation=1, use_snake=use_snake),
107
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
108
+ dilation=3, use_snake=use_snake),
109
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
110
+ dilation=9, use_snake=use_snake),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.layers(x)
115
+
116
+ class OobleckEncoder(nn.Module):
117
+ def __init__(self,
118
+ in_channels=2,
119
+ channels=128,
120
+ latent_dim=32,
121
+ c_mults = [1, 2, 4, 8],
122
+ strides = [2, 4, 8, 8],
123
+ use_snake=False,
124
+ antialias_activation=False
125
+ ):
126
+ super().__init__()
127
+
128
+ c_mults = [1] + c_mults
129
+
130
+ self.depth = len(c_mults)
131
+
132
+ layers = [
133
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
134
+ ]
135
+
136
+ for i in range(self.depth-1):
137
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
138
+
139
+ layers += [
140
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
141
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
142
+ ]
143
+
144
+ self.layers = nn.Sequential(*layers)
145
+
146
+ def forward(self, x):
147
+ return self.layers(x)
148
+
149
+
150
+ class OobleckDecoder(nn.Module):
151
+ def __init__(self,
152
+ out_channels=2,
153
+ channels=128,
154
+ latent_dim=32,
155
+ c_mults = [1, 2, 4, 8],
156
+ strides = [2, 4, 8, 8],
157
+ use_snake=False,
158
+ antialias_activation=False,
159
+ use_nearest_upsample=False,
160
+ final_tanh=True):
161
+ super().__init__()
162
+
163
+ c_mults = [1] + c_mults
164
+
165
+ self.depth = len(c_mults)
166
+
167
+ layers = [
168
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
169
+ ]
170
+
171
+ for i in range(self.depth-1, 0, -1):
172
+ layers += [DecoderBlock(
173
+ in_channels=c_mults[i]*channels,
174
+ out_channels=c_mults[i-1]*channels,
175
+ stride=strides[i-1],
176
+ use_snake=use_snake,
177
+ antialias_activation=antialias_activation,
178
+ use_nearest_upsample=use_nearest_upsample
179
+ )
180
+ ]
181
+
182
+ layers += [
183
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
184
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
185
+ nn.Tanh() if final_tanh else nn.Identity()
186
+ ]
187
+
188
+ self.layers = nn.Sequential(*layers)
189
+
190
+ def forward(self, x):
191
+ return self.layers(x)
192
+
193
+
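# Hedged shape check using the default c_mults/strides above: the encoder
# downsamples by prod(strides) = 2*4*8*8 = 512 into a latent_dim-channel sequence,
# and the decoder mirrors it back to audio length.
import torch
from stable_audio_tools.models.autoencoders import OobleckEncoder, OobleckDecoder

enc = OobleckEncoder(in_channels=2, latent_dim=32)
dec = OobleckDecoder(out_channels=2, latent_dim=32)
x = torch.randn(1, 2, 512 * 8)            # eight latent frames worth of audio
z = enc(x)
assert z.shape == (1, 32, 8)
assert dec(z).shape == (1, 2, 512 * 8)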
194
+ class DACEncoderWrapper(nn.Module):
195
+ def __init__(self, in_channels=1, **kwargs):
196
+ super().__init__()
197
+
198
+ from dac.model.dac import Encoder as DACEncoder
199
+
200
+ latent_dim = kwargs.pop("latent_dim", None)
201
+
202
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
203
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
204
+ self.latent_dim = latent_dim
205
+
206
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
207
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
208
+
209
+ if in_channels != 1:
210
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
211
+
212
+ def forward(self, x):
213
+ x = self.encoder(x)
214
+ x = self.proj_out(x)
215
+ return x
216
+
217
+ class DACDecoderWrapper(nn.Module):
218
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
219
+ super().__init__()
220
+
221
+ from dac.model.dac import Decoder as DACDecoder
222
+
223
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
224
+
225
+ self.latent_dim = latent_dim
226
+
227
+ def forward(self, x):
228
+ return self.decoder(x)
229
+
230
+ class AudioAutoencoder(nn.Module):
231
+ def __init__(
232
+ self,
233
+ encoder,
234
+ decoder,
235
+ latent_dim,
236
+ downsampling_ratio,
237
+ sample_rate,
238
+ io_channels=2,
239
+ bottleneck: Bottleneck = None,
240
+ pretransform: Pretransform = None,
241
+ in_channels = None,
242
+ out_channels = None,
243
+ soft_clip = False
244
+ ):
245
+ super().__init__()
246
+
247
+ self.downsampling_ratio = downsampling_ratio
248
+ self.sample_rate = sample_rate
249
+
250
+ self.latent_dim = latent_dim
251
+ self.io_channels = io_channels
252
+ self.in_channels = io_channels
253
+ self.out_channels = io_channels
254
+
255
+ self.min_length = self.downsampling_ratio
256
+
257
+ if in_channels is not None:
258
+ self.in_channels = in_channels
259
+
260
+ if out_channels is not None:
261
+ self.out_channels = out_channels
262
+
263
+ self.bottleneck = bottleneck
264
+
265
+ self.encoder = encoder
266
+
267
+ self.decoder = decoder
268
+
269
+ self.pretransform = pretransform
270
+
271
+ self.soft_clip = soft_clip
272
+
273
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
274
+
275
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
276
+
277
+ info = {}
278
+
279
+ if self.pretransform is not None and not skip_pretransform:
280
+ if self.pretransform.enable_grad:
281
+ if iterate_batch:
282
+ audios = []
283
+ for i in range(audio.shape[0]):
284
+ audios.append(self.pretransform.encode(audio[i:i+1]))
285
+ audio = torch.cat(audios, dim=0)
286
+ else:
287
+ audio = self.pretransform.encode(audio)
288
+ else:
289
+ with torch.no_grad():
290
+ if iterate_batch:
291
+ audios = []
292
+ for i in range(audio.shape[0]):
293
+ audios.append(self.pretransform.encode(audio[i:i+1]))
294
+ audio = torch.cat(audios, dim=0)
295
+ else:
296
+ audio = self.pretransform.encode(audio)
297
+
298
+ if self.encoder is not None:
299
+ if iterate_batch:
300
+ latents = []
301
+ for i in range(audio.shape[0]):
302
+ latents.append(self.encoder(audio[i:i+1]))
303
+ latents = torch.cat(latents, dim=0)
304
+ else:
305
+ latents = self.encoder(audio)
306
+ else:
307
+ latents = audio
308
+
309
+ if self.bottleneck is not None:
310
+ # TODO: Add iterate batch logic, needs to merge the info dicts
311
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
312
+
313
+ info.update(bottleneck_info)
314
+
315
+ if return_info:
316
+ return latents, info
317
+
318
+ return latents
319
+
320
+ def decode(self, latents, iterate_batch=False, **kwargs):
321
+
322
+ if self.bottleneck is not None:
323
+ if iterate_batch:
324
+ decoded = []
325
+ for i in range(latents.shape[0]):
326
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
327
+ latents = torch.cat(decoded, dim=0)
328
+ else:
329
+ latents = self.bottleneck.decode(latents)
330
+
331
+ if iterate_batch:
332
+ decoded = []
333
+ for i in range(latents.shape[0]):
334
+ decoded.append(self.decoder(latents[i:i+1]))
335
+ decoded = torch.cat(decoded, dim=0)
336
+ else:
337
+ decoded = self.decoder(latents, **kwargs)
338
+
339
+ if self.pretransform is not None:
340
+ if self.pretransform.enable_grad:
341
+ if iterate_batch:
342
+ decodeds = []
343
+ for i in range(decoded.shape[0]):
344
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
345
+ decoded = torch.cat(decodeds, dim=0)
346
+ else:
347
+ decoded = self.pretransform.decode(decoded)
348
+ else:
349
+ with torch.no_grad():
350
+ if iterate_batch:
351
+ decodeds = []
352
+ for i in range(latents.shape[0]):
353
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
354
+ decoded = torch.cat(decodeds, dim=0)
355
+ else:
356
+ decoded = self.pretransform.decode(decoded)
357
+
358
+ if self.soft_clip:
359
+ decoded = torch.tanh(decoded)
360
+
361
+ return decoded
362
+
363
+ def decode_tokens(self, tokens, **kwargs):
364
+ '''
365
+ Decode discrete tokens to audio
366
+ Only works with discrete autoencoders
367
+ '''
368
+
369
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
370
+
371
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
372
+
373
+ return self.decode(latents, **kwargs)
374
+
375
+
376
+ def preprocess_audio_for_encoder(self, audio, in_sr):
377
+ '''
378
+ Preprocess a single audio tensor (Channels x Length) to be compatible with the encoder.
379
+ If the model is mono, stereo audio will be converted to mono.
380
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
381
+ Audio will be resampled to the model's sample rate.
382
+ The output will have batch size 1 and be shape (1 x Channels x Length)
383
+ '''
384
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
385
+
386
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
387
+ '''
388
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
389
+ The audio in that list can be of different lengths and channels.
390
+ in_sr can be an integer or a list. If it's an integer, it is assumed to be the input sample rate for every audio.
391
+ All audio will be resampled to the model's sample rate.
392
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
393
+ If the model is mono, all audio will be converted to mono.
394
+ The output will be a tensor of shape (Batch x Channels x Length)
395
+ '''
396
+ batch_size = len(audio_list)
397
+ if isinstance(in_sr_list, int):
398
+ in_sr_list = [in_sr_list]*batch_size
399
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
400
+ new_audio = []
401
+ max_length = 0
402
+ # resample & find the max length
403
+ for i in range(batch_size):
404
+ audio = audio_list[i]
405
+ in_sr = in_sr_list[i]
406
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
407
+ # batch size 1 was given by accident. Just squeeze it.
408
+ audio = audio.squeeze(0)
409
+ elif len(audio.shape) == 1:
410
+ # Mono signal, channel dimension is missing, unsqueeze it in
411
+ audio = audio.unsqueeze(0)
412
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
413
+ # Resample audio
414
+ if in_sr != self.sample_rate:
415
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
416
+ audio = resample_tf(audio)
417
+ new_audio.append(audio)
418
+ if audio.shape[-1] > max_length:
419
+ max_length = audio.shape[-1]
420
+ # Pad every audio to the same length, multiple of model's downsampling ratio
421
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
422
+ for i in range(batch_size):
423
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
424
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
425
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
426
+ # convert to tensor
427
+ return torch.stack(new_audio)
428
+
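+ # Illustrative usage sketch (hypothetical tensors and sample rates, assuming `model` is an
+ # AudioAutoencoder and no pretransform is configured):
+ #   audio_a = torch.randn(2, 88200)   # stereo clip at 44.1 kHz
+ #   audio_b = torch.randn(1, 48000)   # mono clip at 48 kHz
+ #   batch = model.preprocess_audio_list_for_encoder([audio_a, audio_b], [44100, 48000])
+ #   latents = model.encode_audio(batch)  # roughly (Batch x latent_dim x padded_length / downsampling_ratio)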
429
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
430
+ '''
431
+ Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
432
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
433
+ Overlap and chunk_size params are both measured in number of latents (not audio samples),
434
+ and therefore you likely could use the same values with decode_audio.
435
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
436
+ Every autoencoder will have a different receptive field size, and thus a different ideal overlap.
437
+ You can determine it empirically by diffing unchunked vs. chunked output and looking at the maximum difference.
438
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
439
+ Smaller chunk_size uses less memory, but more compute.
440
+ The chunk_size vs. memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
441
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it uses more chunks.
442
+ '''
443
+ if not chunked:
444
+ # default behavior. Encode the entire audio in parallel
445
+ return self.encode(audio, **kwargs)
446
+ else:
447
+ # CHUNKED ENCODING
448
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
449
+ samples_per_latent = self.downsampling_ratio
450
+ total_size = audio.shape[2] # in samples
451
+ batch_size = audio.shape[0]
452
+ chunk_size *= samples_per_latent # converting metric in latents to samples
453
+ overlap *= samples_per_latent # converting metric in latents to samples
454
+ hop_size = chunk_size - overlap
455
+ chunks = []
456
+ for i in range(0, total_size - chunk_size + 1, hop_size):
457
+ chunk = audio[:,:,i:i+chunk_size]
458
+ chunks.append(chunk)
459
+ if i+chunk_size != total_size:
460
+ # Final chunk
461
+ chunk = audio[:,:,-chunk_size:]
462
+ chunks.append(chunk)
463
+ chunks = torch.stack(chunks)
464
+ num_chunks = chunks.shape[0]
465
+ # Note: y_size might be a different value from the latent length used in diffusion training
466
+ # because we can encode audio of varying lengths
467
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
468
+ y_size = total_size // samples_per_latent
469
+ # Create an empty latent, we will populate it with chunks as we encode them
470
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
471
+ for i in range(num_chunks):
472
+ x_chunk = chunks[i,:]
473
+ # encode the chunk
474
+ y_chunk = self.encode(x_chunk)
475
+ # figure out where to put the audio along the time domain
476
+ if i == num_chunks-1:
477
+ # final chunk always goes at the end
478
+ t_end = y_size
479
+ t_start = t_end - y_chunk.shape[2]
480
+ else:
481
+ t_start = i * hop_size // samples_per_latent
482
+ t_end = t_start + chunk_size // samples_per_latent
483
+ # remove the edges of the overlaps
484
+ ol = overlap//samples_per_latent//2
485
+ chunk_start = 0
486
+ chunk_end = y_chunk.shape[2]
487
+ if i > 0:
488
+ # no overlap for the start of the first chunk
489
+ t_start += ol
490
+ chunk_start += ol
491
+ if i < num_chunks-1:
492
+ # no overlap for the end of the last chunk
493
+ t_end -= ol
494
+ chunk_end -= ol
495
+ # paste the chunked audio into our y_final output audio
496
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
497
+ return y_final
498
+
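+ # Sketch of the chunked path above with hypothetical numbers: for downsampling_ratio 2048,
+ # chunk_size=128 and overlap=32 mean 128*2048 samples per chunk with a 32*2048-sample overlap;
+ # hop_size is (128-32)*2048 samples, and each interior chunk contributes its centre region
+ # (overlap/2 latents trimmed from each inner edge) to y_final.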
499
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
500
+ '''
501
+ Decode latents to audio.
502
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
503
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
504
+ Every autoencoder will have a different receptive field size, and thus a different ideal overlap.
505
+ You can determine it empirically by diffing unchunked vs. chunked audio and looking at the maximum difference.
506
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
507
+ Smaller chunk_size uses less memory, but more compute.
508
+ The chunk_size vs. memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
509
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it uses more chunks.
510
+ '''
511
+ if not chunked:
512
+ # default behavior. Decode the entire latent in parallel
513
+ return self.decode(latents, **kwargs)
514
+ else:
515
+ # chunked decoding
516
+ hop_size = chunk_size - overlap
517
+ total_size = latents.shape[2]
518
+ batch_size = latents.shape[0]
519
+ chunks = []
520
+ for i in range(0, total_size - chunk_size + 1, hop_size):
521
+ chunk = latents[:,:,i:i+chunk_size]
522
+ chunks.append(chunk)
523
+ if i+chunk_size != total_size:
524
+ # Final chunk
525
+ chunk = latents[:,:,-chunk_size:]
526
+ chunks.append(chunk)
527
+ chunks = torch.stack(chunks)
528
+ num_chunks = chunks.shape[0]
529
+ # samples_per_latent is just the downsampling ratio
530
+ samples_per_latent = self.downsampling_ratio
531
+ # Create an empty waveform, we will populate it with chunks as we decode them
532
+ y_size = total_size * samples_per_latent
533
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
534
+ for i in range(num_chunks):
535
+ x_chunk = chunks[i,:]
536
+ # decode the chunk
537
+ y_chunk = self.decode(x_chunk)
538
+ # figure out where to put the audio along the time domain
539
+ if i == num_chunks-1:
540
+ # final chunk always goes at the end
541
+ t_end = y_size
542
+ t_start = t_end - y_chunk.shape[2]
543
+ else:
544
+ t_start = i * hop_size * samples_per_latent
545
+ t_end = t_start + chunk_size * samples_per_latent
546
+ # remove the edges of the overlaps
547
+ ol = (overlap//2) * samples_per_latent
548
+ chunk_start = 0
549
+ chunk_end = y_chunk.shape[2]
550
+ if i > 0:
551
+ # no overlap for the start of the first chunk
552
+ t_start += ol
553
+ chunk_start += ol
554
+ if i < num_chunks-1:
555
+ # no overlap for the end of the last chunk
556
+ t_end -= ol
557
+ chunk_end -= ol
558
+ # paste the chunked audio into our y_final output audio
559
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
560
+ return y_final
561
+
562
+
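+ # Illustrative round trip (hypothetical chunking settings; a larger overlap reduces seam artefacts):
+ #   latents = model.encode_audio(batch, chunked=True, chunk_size=128, overlap=32)
+ #   recon = model.decode_audio(latents, chunked=True, chunk_size=128, overlap=32)
+ # recon then has shape (Batch x out_channels x num_latents * downsampling_ratio).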
563
+ class DiffusionAutoencoder(AudioAutoencoder):
564
+ def __init__(
565
+ self,
566
+ diffusion: ConditionedDiffusionModel,
567
+ diffusion_downsampling_ratio,
568
+ *args,
569
+ **kwargs
570
+ ):
571
+ super().__init__(*args, **kwargs)
572
+
573
+ self.diffusion = diffusion
574
+
575
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
576
+
577
+ if self.encoder is not None:
578
+ # Shrink the initial encoder parameters to avoid saturated latents
579
+ with torch.no_grad():
580
+ for param in self.encoder.parameters():
581
+ param *= 0.5
582
+
583
+ def decode(self, latents, steps=100):
584
+
585
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
586
+
587
+ if self.bottleneck is not None:
588
+ latents = self.bottleneck.decode(latents)
589
+
590
+ if self.decoder is not None:
591
+ latents = self.decoder(latents)
592
+
593
+ # Upsample latents to match diffusion length
594
+ if latents.shape[2] != upsampled_length:
595
+ latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
596
+
597
+ noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
598
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
599
+
600
+ if self.pretransform is not None:
601
+ if self.pretransform.enable_grad:
602
+ decoded = self.pretransform.decode(decoded)
603
+ else:
604
+ with torch.no_grad():
605
+ decoded = self.pretransform.decode(decoded)
606
+
607
+ return decoded
608
+
609
+ # AE factories
610
+
611
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
612
+ encoder_type = encoder_config.get("type", None)
613
+ assert encoder_type is not None, "Encoder type must be specified"
614
+
615
+ if encoder_type == "oobleck":
616
+ encoder = OobleckEncoder(
617
+ **encoder_config["config"]
618
+ )
619
+
620
+ elif encoder_type == "seanet":
621
+ from encodec.modules import SEANetEncoder
622
+ seanet_encoder_config = encoder_config["config"]
623
+
624
+ #SEANet encoder expects strides in reverse order
625
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
626
+ encoder = SEANetEncoder(
627
+ **seanet_encoder_config
628
+ )
629
+ elif encoder_type == "dac":
630
+ dac_config = encoder_config["config"]
631
+
632
+ encoder = DACEncoderWrapper(**dac_config)
633
+ elif encoder_type == "local_attn":
634
+ from .local_attention import TransformerEncoder1D
635
+
636
+ local_attn_config = encoder_config["config"]
637
+
638
+ encoder = TransformerEncoder1D(
639
+ **local_attn_config
640
+ )
641
+ else:
642
+ raise ValueError(f"Unknown encoder type {encoder_type}")
643
+
644
+ requires_grad = encoder_config.get("requires_grad", True)
645
+ if not requires_grad:
646
+ for param in encoder.parameters():
647
+ param.requires_grad = False
648
+
649
+ return encoder
650
+
651
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
652
+ decoder_type = decoder_config.get("type", None)
653
+ assert decoder_type is not None, "Decoder type must be specified"
654
+
655
+ if decoder_type == "oobleck":
656
+ decoder = OobleckDecoder(
657
+ **decoder_config["config"]
658
+ )
659
+ elif decoder_type == "seanet":
660
+ from encodec.modules import SEANetDecoder
661
+
662
+ decoder = SEANetDecoder(
663
+ **decoder_config["config"]
664
+ )
665
+ elif decoder_type == "dac":
666
+ dac_config = decoder_config["config"]
667
+
668
+ decoder = DACDecoderWrapper(**dac_config)
669
+ elif decoder_type == "local_attn":
670
+ from .local_attention import TransformerDecoder1D
671
+
672
+ local_attn_config = decoder_config["config"]
673
+
674
+ decoder = TransformerDecoder1D(
675
+ **local_attn_config
676
+ )
677
+ else:
678
+ raise ValueError(f"Unknown decoder type {decoder_type}")
679
+
680
+ requires_grad = decoder_config.get("requires_grad", True)
681
+ if not requires_grad:
682
+ for param in decoder.parameters():
683
+ param.requires_grad = False
684
+
685
+ return decoder
686
+
687
+ def create_autoencoder_from_config(config: Dict[str, Any]):
688
+
689
+ ae_config = config["model"]
690
+
691
+ encoder = create_encoder_from_config(ae_config["encoder"])
692
+ decoder = create_decoder_from_config(ae_config["decoder"])
693
+
694
+ bottleneck = ae_config.get("bottleneck", None)
695
+
696
+ latent_dim = ae_config.get("latent_dim", None)
697
+ assert latent_dim is not None, "latent_dim must be specified in model config"
698
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
699
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700
+ io_channels = ae_config.get("io_channels", None)
701
+ assert io_channels is not None, "io_channels must be specified in model config"
702
+ sample_rate = config.get("sample_rate", None)
703
+ assert sample_rate is not None, "sample_rate must be specified in model config"
704
+
705
+ in_channels = ae_config.get("in_channels", None)
706
+ out_channels = ae_config.get("out_channels", None)
707
+
708
+ pretransform = ae_config.get("pretransform", None)
709
+
710
+ if pretransform is not None:
711
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
712
+
713
+ if bottleneck is not None:
714
+ bottleneck = create_bottleneck_from_config(bottleneck)
715
+
716
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
717
+
718
+ return AudioAutoencoder(
719
+ encoder,
720
+ decoder,
721
+ io_channels=io_channels,
722
+ latent_dim=latent_dim,
723
+ downsampling_ratio=downsampling_ratio,
724
+ sample_rate=sample_rate,
725
+ bottleneck=bottleneck,
726
+ pretransform=pretransform,
727
+ in_channels=in_channels,
728
+ out_channels=out_channels,
729
+ soft_clip=soft_clip
730
+ )
731
+
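+ # Minimal config skeleton accepted by create_autoencoder_from_config (values are placeholders;
+ # the encoder/decoder "config" dicts depend on the chosen type, and "bottleneck"/"pretransform"
+ # are optional and follow their own factory schemas):
+ #   {
+ #     "sample_rate": 44100,
+ #     "model": {
+ #       "encoder": {"type": "oobleck", "config": {...}},
+ #       "decoder": {"type": "oobleck", "config": {...}},
+ #       "latent_dim": 64,
+ #       "downsampling_ratio": 2048,
+ #       "io_channels": 2
+ #     }
+ #   }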
732
+ def create_diffAE_from_config(config: Dict[str, Any]):
733
+
734
+ diffae_config = config["model"]
735
+
736
+ if "encoder" in diffae_config:
737
+ encoder = create_encoder_from_config(diffae_config["encoder"])
738
+ else:
739
+ encoder = None
740
+
741
+ if "decoder" in diffae_config:
742
+ decoder = create_decoder_from_config(diffae_config["decoder"])
743
+ else:
744
+ decoder = None
745
+
746
+ diffusion_model_type = diffae_config["diffusion"]["type"]
747
+
748
+ if diffusion_model_type == "DAU1d":
749
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
750
+ elif diffusion_model_type == "adp_1d":
751
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
752
+ elif diffusion_model_type == "dit":
753
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
754
+
755
+ latent_dim = diffae_config.get("latent_dim", None)
756
+ assert latent_dim is not None, "latent_dim must be specified in model config"
757
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
758
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759
+ io_channels = diffae_config.get("io_channels", None)
760
+ assert io_channels is not None, "io_channels must be specified in model config"
761
+ sample_rate = config.get("sample_rate", None)
762
+ assert sample_rate is not None, "sample_rate must be specified in model config"
763
+
764
+ bottleneck = diffae_config.get("bottleneck", None)
765
+
766
+ pretransform = diffae_config.get("pretransform", None)
767
+
768
+ if pretransform is not None:
769
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
770
+
771
+ if bottleneck is not None:
772
+ bottleneck = create_bottleneck_from_config(bottleneck)
773
+
774
+ diffusion_downsampling_ratio = None
775
+
776
+ if diffusion_model_type == "DAU1d":
777
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
778
+ elif diffusion_model_type == "adp_1d":
779
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
780
+ elif diffusion_model_type == "dit":
781
+ diffusion_downsampling_ratio = 1
782
+
783
+ return DiffusionAutoencoder(
784
+ encoder=encoder,
785
+ decoder=decoder,
786
+ diffusion=diffusion,
787
+ io_channels=io_channels,
788
+ sample_rate=sample_rate,
789
+ latent_dim=latent_dim,
790
+ downsampling_ratio=downsampling_ratio,
791
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
792
+ bottleneck=bottleneck,
793
+ pretransform=pretransform
794
+ )
stable_audio_tools/models/blocks.py ADDED
@@ -0,0 +1,321 @@
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from dac.nn.layers import Snake1d
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+ class ResConvBlock(ResidualBlock):
23
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
+ super().__init__([
26
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
+ nn.GroupNorm(1, c_mid),
28
+ Snake1d(c_mid) if use_snake else nn.GELU(),
29
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
+ ], skip)
33
+
34
+ class SelfAttention1d(nn.Module):
35
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
+ super().__init__()
37
+ assert c_in % n_head == 0
38
+ self.norm = nn.GroupNorm(1, c_in)
39
+ self.n_head = n_head
40
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
+
44
+ def forward(self, input):
45
+ n, c, s = input.shape
46
+ qkv = self.qkv_proj(self.norm(input))
47
+ qkv = qkv.view(
48
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
49
+ q, k, v = qkv.chunk(3, dim=1)
50
+ scale = k.shape[3]**-0.25
51
+
52
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
53
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
54
+
55
+ return input + self.dropout(self.out_proj(y))
56
+
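+ # Note on the scaling above: multiplying both q and k by head_dim**-0.25 is equivalent to the
+ # usual 1/sqrt(head_dim) attention scaling, but keeps the intermediate products better
+ # conditioned in reduced precision.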
57
+ class SkipBlock(nn.Module):
58
+ def __init__(self, *main):
59
+ super().__init__()
60
+ self.main = nn.Sequential(*main)
61
+
62
+ def forward(self, input):
63
+ return torch.cat([self.main(input), input], dim=1)
64
+
65
+ class FourierFeatures(nn.Module):
66
+ def __init__(self, in_features, out_features, std=1.):
67
+ super().__init__()
68
+ assert out_features % 2 == 0
69
+ self.weight = nn.Parameter(torch.randn(
70
+ [out_features // 2, in_features]) * std)
71
+
72
+ def forward(self, input):
73
+ f = 2 * math.pi * input @ self.weight.T
74
+ return torch.cat([f.cos(), f.sin()], dim=-1)
75
+
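+ # FourierFeatures projects its input through a randomly initialised matrix (std sets the
+ # frequency spread) and returns concatenated cos/sin of 2*pi*x @ W^T, so out_features must be
+ # even; commonly used to embed scalar conditioning such as diffusion timesteps.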
76
+ def expand_to_planes(input, shape):
77
+ return input[..., None].repeat([1, 1, shape[2]])
78
+
79
+ _kernels = {
80
+ 'linear':
81
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
82
+ 'cubic':
83
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
84
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
85
+ 'lanczos3':
86
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
87
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
88
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
89
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
90
+ }
91
+
92
+ class Downsample1d(nn.Module):
93
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
94
+ super().__init__()
95
+ self.pad_mode = pad_mode
96
+ kernel_1d = torch.tensor(_kernels[kernel])
97
+ self.pad = kernel_1d.shape[0] // 2 - 1
98
+ self.register_buffer('kernel', kernel_1d)
99
+ self.channels_last = channels_last
100
+
101
+ def forward(self, x):
102
+ if self.channels_last:
103
+ x = x.permute(0, 2, 1)
104
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
105
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
106
+ indices = torch.arange(x.shape[1], device=x.device)
107
+ weight[indices, indices] = self.kernel.to(weight)
108
+ x = F.conv1d(x, weight, stride=2)
109
+ if self.channels_last:
110
+ x = x.permute(0, 2, 1)
111
+ return x
112
+
113
+
114
+ class Upsample1d(nn.Module):
115
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
116
+ super().__init__()
117
+ self.pad_mode = pad_mode
118
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
119
+ self.pad = kernel_1d.shape[0] // 2 - 1
120
+ self.register_buffer('kernel', kernel_1d)
121
+ self.channels_last = channels_last
122
+
123
+ def forward(self, x):
124
+ if self.channels_last:
125
+ x = x.permute(0, 2, 1)
126
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
127
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
128
+ indices = torch.arange(x.shape[1], device=x.device)
129
+ weight[indices, indices] = self.kernel.to(weight)
130
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
131
+ if self.channels_last:
132
+ x = x.permute(0, 2, 1)
133
+ return x
134
+
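+ # The taps in _kernels are symmetric FIR low-pass filters (linear, cubic, Lanczos-3);
+ # Downsample1d applies them per channel with stride 2, and Upsample1d reuses them (scaled by 2)
+ # in a transposed convolution, so the pair implements anti-aliased 2x down/upsampling.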
135
+ def Downsample1d_2(
136
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
137
+ ) -> nn.Module:
138
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
139
+
140
+ return nn.Conv1d(
141
+ in_channels=in_channels,
142
+ out_channels=out_channels,
143
+ kernel_size=factor * kernel_multiplier + 1,
144
+ stride=factor,
145
+ padding=factor * (kernel_multiplier // 2),
146
+ )
147
+
148
+
149
+ def Upsample1d_2(
150
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
151
+ ) -> nn.Module:
152
+
153
+ if factor == 1:
154
+ return nn.Conv1d(
155
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
156
+ )
157
+
158
+ if use_nearest:
159
+ return nn.Sequential(
160
+ nn.Upsample(scale_factor=factor, mode="nearest"),
161
+ nn.Conv1d(
162
+ in_channels=in_channels,
163
+ out_channels=out_channels,
164
+ kernel_size=3,
165
+ padding=1,
166
+ ),
167
+ )
168
+ else:
169
+ return nn.ConvTranspose1d(
170
+ in_channels=in_channels,
171
+ out_channels=out_channels,
172
+ kernel_size=factor * 2,
173
+ stride=factor,
174
+ padding=factor // 2 + factor % 2,
175
+ output_padding=factor % 2,
176
+ )
177
+
178
+ def zero_init(layer):
179
+ nn.init.zeros_(layer.weight)
180
+ if layer.bias is not None:
181
+ nn.init.zeros_(layer.bias)
182
+ return layer
183
+
184
+ def rms_norm(x, scale, eps):
185
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
186
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
187
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
188
+ return x * scale.to(x.dtype)
189
+
190
+ #rms_norm = torch.compile(rms_norm)
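+ # i.e. rms_norm(x, scale, eps) computes x * scale / sqrt(mean(x**2, dim=-1) + eps), with the
+ # reduction carried out in at least float32 before casting back to x.dtype.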
191
+
192
+ class AdaRMSNorm(nn.Module):
193
+ def __init__(self, features, cond_features, eps=1e-6):
194
+ super().__init__()
195
+ self.eps = eps
196
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
197
+
198
+ def extra_repr(self):
199
+ return f"eps={self.eps},"
200
+
201
+ def forward(self, x, cond):
202
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
203
+
204
+ def normalize(x, eps=1e-4):
205
+ dim = list(range(1, x.ndim))
206
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
207
+ alpha = np.sqrt(n.numel() / x.numel())
208
+ return x / torch.add(eps, n, alpha=alpha)
209
+
210
+ class ForcedWNConv1d(nn.Module):
211
+ def __init__(self, in_channels, out_channels, kernel_size=1):
212
+ super().__init__()
213
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
214
+
215
+ def forward(self, x):
216
+ if self.training:
217
+ with torch.no_grad():
218
+ self.weight.copy_(normalize(self.weight))
219
+
220
+ fan_in = self.weight[0].numel()
221
+
222
+ w = normalize(self.weight) / math.sqrt(fan_in)
223
+
224
+ return F.conv1d(x, w, padding='same')
225
+
226
+ # Kernels
227
+
228
+ # use_compile = True
229
+ use_compile = False
230
+
231
+ def compile(function, *args, **kwargs):
232
+ if not use_compile:
233
+ return function
234
+ try:
235
+ return torch.compile(function, *args, **kwargs)
236
+ except RuntimeError:
237
+ return function
238
+
239
+
240
+ @compile
241
+ def linear_geglu(x, weight, bias=None):
242
+ x = x @ weight.mT
243
+ if bias is not None:
244
+ x = x + bias
245
+ x, gate = x.chunk(2, dim=-1)
246
+ return x * F.gelu(gate)
247
+
248
+
249
+ @compile
250
+ def rms_norm(x, scale, eps):
251
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
252
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
253
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
254
+ return x * scale.to(x.dtype)
255
+
256
+ # Layers
257
+
258
+ class LinearGEGLU(nn.Linear):
259
+ def __init__(self, in_features, out_features, bias=True):
260
+ super().__init__(in_features, out_features * 2, bias=bias)
261
+ self.out_features = out_features
262
+
263
+ def forward(self, x):
264
+ return linear_geglu(x, self.weight, self.bias)
265
+
266
+
267
+ class RMSNorm(nn.Module):
268
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
269
+ super().__init__()
270
+ self.eps = eps
271
+
272
+ if fix_scale:
273
+ self.register_buffer("scale", torch.ones(shape))
274
+ else:
275
+ self.scale = nn.Parameter(torch.ones(shape))
276
+
277
+ def extra_repr(self):
278
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
279
+
280
+ def forward(self, x):
281
+ return rms_norm(x, self.scale, self.eps)
282
+
283
+ def snake_beta(x, alpha, beta):
284
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
285
+
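+ # snake_beta implements the SnakeBeta activation x + sin^2(alpha * x) / beta (with a small
+ # epsilon guarding against division by zero); alpha sets the frequency of the periodic term
+ # and beta its magnitude.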
286
+ # try:
287
+ # snake_beta = torch.compile(snake_beta)
288
+ # except RuntimeError:
289
+ # pass
290
+
291
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
292
+ # License available in LICENSES/LICENSE_NVIDIA.txt
293
+ class SnakeBeta(nn.Module):
294
+
295
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
296
+ super(SnakeBeta, self).__init__()
297
+ self.in_features = in_features
298
+
299
+ # initialize alpha
300
+ self.alpha_logscale = alpha_logscale
301
+ if self.alpha_logscale: # log scale alphas initialized to zeros
302
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
303
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
304
+ else: # linear scale alphas initialized to ones
305
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
306
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
307
+
308
+ self.alpha.requires_grad = alpha_trainable
309
+ self.beta.requires_grad = alpha_trainable
310
+
311
+ self.no_div_by_zero = 0.000000001
312
+
313
+ def forward(self, x):
314
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
315
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
316
+ if self.alpha_logscale:
317
+ alpha = torch.exp(alpha)
318
+ beta = torch.exp(beta)
319
+ x = snake_beta(x, alpha, beta)
320
+
321
+ return x
stable_audio_tools/models/bottleneck.py ADDED
@@ -0,0 +1,355 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import ResidualVQ, FSQ
8
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
9
+
10
+ class Bottleneck(nn.Module):
11
+ def __init__(self, is_discrete: bool = False):
12
+ super().__init__()
13
+
14
+ self.is_discrete = is_discrete
15
+
16
+ def encode(self, x, return_info=False, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, x):
20
+ raise NotImplementedError
21
+
22
+ class DiscreteBottleneck(Bottleneck):
23
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
24
+ super().__init__(is_discrete=True)
25
+
26
+ self.num_quantizers = num_quantizers
27
+ self.codebook_size = codebook_size
28
+ self.tokens_id = tokens_id
29
+
30
+ def decode_tokens(self, codes, **kwargs):
31
+ raise NotImplementedError
32
+
33
+ class TanhBottleneck(Bottleneck):
34
+ def __init__(self):
35
+ super().__init__(is_discrete=False)
36
+ self.tanh = nn.Tanh()
37
+
38
+ def encode(self, x, return_info=False):
39
+ info = {}
40
+
41
+ x = torch.tanh(x)
42
+
43
+ if return_info:
44
+ return x, info
45
+ else:
46
+ return x
47
+
48
+ def decode(self, x):
49
+ return x
50
+
51
+ def vae_sample(mean, scale):
52
+ stdev = nn.functional.softplus(scale) + 1e-4
53
+ var = stdev * stdev
54
+ logvar = torch.log(var)
55
+ latents = torch.randn_like(mean) * stdev + mean
56
+
57
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
58
+
59
+ return latents, kl
60
+
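+ # vae_sample draws latents with the reparameterisation trick (mean + stdev * noise) and returns,
+ # up to a factor of 1/2, the closed-form KL divergence to a standard normal:
+ # the per-element terms mean^2 + var - log(var) - 1, summed over channels and averaged elsewhere.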
61
+ class VAEBottleneck(Bottleneck):
62
+ def __init__(self):
63
+ super().__init__(is_discrete=False)
64
+
65
+ def encode(self, x, return_info=False, **kwargs):
66
+ info = {}
67
+
68
+ mean, scale = x.chunk(2, dim=1)
69
+
70
+ x, kl = vae_sample(mean, scale)
71
+
72
+ info["kl"] = kl
73
+
74
+ if return_info:
75
+ return x, info
76
+ else:
77
+ return x
78
+
79
+ def decode(self, x):
80
+ return x
81
+
82
+ def compute_mean_kernel(x, y):
83
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
84
+ return torch.exp(-kernel_input).mean()
85
+
86
+ def compute_mmd(latents):
87
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
88
+ noise = torch.randn_like(latents_reshaped)
89
+
90
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
91
+ noise_kernel = compute_mean_kernel(noise, noise)
92
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
93
+
94
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
95
+ return mmd.mean()
96
+
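+ # compute_mmd estimates the maximum mean discrepancy between the latent distribution and a
+ # standard normal using an RBF kernel whose bandwidth scales with the latent dimension;
+ # WassersteinBottleneck below reports it in the info dict, presumably for the training loop to
+ # weight as a regulariser pushing latents towards N(0, I).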
97
+ class WassersteinBottleneck(Bottleneck):
98
+ def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
99
+ super().__init__(is_discrete=False)
100
+
101
+ self.noise_augment_dim = noise_augment_dim
102
+ self.bypass_mmd = bypass_mmd
103
+
104
+ def encode(self, x, return_info=False):
105
+ info = {}
106
+
107
+ if self.training and return_info:
108
+ if self.bypass_mmd:
109
+ mmd = torch.tensor(0.0)
110
+ else:
111
+ mmd = compute_mmd(x)
112
+
113
+ info["mmd"] = mmd
114
+
115
+ if return_info:
116
+ return x, info
117
+
118
+ return x
119
+
120
+ def decode(self, x):
121
+
122
+ if self.noise_augment_dim > 0:
123
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
124
+ x.shape[-1]).type_as(x)
125
+ x = torch.cat([x, noise], dim=1)
126
+
127
+ return x
128
+
129
+ class L2Bottleneck(Bottleneck):
130
+ def __init__(self):
131
+ super().__init__(is_discrete=False)
132
+
133
+ def encode(self, x, return_info=False):
134
+ info = {}
135
+
136
+ x = F.normalize(x, dim=1)
137
+
138
+ if return_info:
139
+ return x, info
140
+ else:
141
+ return x
142
+
143
+ def decode(self, x):
144
+ return F.normalize(x, dim=1)
145
+
146
+ class RVQBottleneck(DiscreteBottleneck):
147
+ def __init__(self, **quantizer_kwargs):
148
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
149
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
150
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
151
+
152
+ def encode(self, x, return_info=False, **kwargs):
153
+ info = {}
154
+
155
+ x = rearrange(x, "b c n -> b n c")
156
+ x, indices, loss = self.quantizer(x)
157
+ x = rearrange(x, "b n c -> b c n")
158
+
159
+ info["quantizer_indices"] = indices
160
+ info["quantizer_loss"] = loss.mean()
161
+
162
+ if return_info:
163
+ return x, info
164
+ else:
165
+ return x
166
+
167
+ def decode(self, x):
168
+ return x
169
+
170
+ def decode_tokens(self, codes, **kwargs):
171
+ latents = self.quantizer.get_outputs_from_indices(codes)
172
+
173
+ return self.decode(latents, **kwargs)
174
+
175
+ class RVQVAEBottleneck(DiscreteBottleneck):
176
+ def __init__(self, **quantizer_kwargs):
177
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
178
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
179
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
180
+
181
+ def encode(self, x, return_info=False):
182
+ info = {}
183
+
184
+ x, kl = vae_sample(*x.chunk(2, dim=1))
185
+
186
+ info["kl"] = kl
187
+
188
+ x = rearrange(x, "b c n -> b n c")
189
+ x, indices, loss = self.quantizer(x)
190
+ x = rearrange(x, "b n c -> b c n")
191
+
192
+ info["quantizer_indices"] = indices
193
+ info["quantizer_loss"] = loss.mean()
194
+
195
+ if return_info:
196
+ return x, info
197
+ else:
198
+ return x
199
+
200
+ def decode(self, x):
201
+ return x
202
+
203
+ def decode_tokens(self, codes, **kwargs):
204
+ latents = self.quantizer.get_outputs_from_indices(codes)
205
+
206
+ return self.decode(latents, **kwargs)
207
+
208
+ class DACRVQBottleneck(DiscreteBottleneck):
209
+ def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
210
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
211
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
212
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
213
+ self.quantize_on_decode = quantize_on_decode
214
+ self.noise_augment_dim = noise_augment_dim
215
+
216
+ def encode(self, x, return_info=False, **kwargs):
217
+ info = {}
218
+
219
+ info["pre_quantizer"] = x
220
+
221
+ if self.quantize_on_decode:
222
+ return (x, info) if return_info else x
223
+
224
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
225
+
226
+ output = {
227
+ "z": z,
228
+ "codes": codes,
229
+ "latents": latents,
230
+ "vq/commitment_loss": commitment_loss,
231
+ "vq/codebook_loss": codebook_loss,
232
+ }
233
+
234
+ output["vq/commitment_loss"] /= self.num_quantizers
235
+ output["vq/codebook_loss"] /= self.num_quantizers
236
+
237
+ info.update(output)
238
+
239
+ if return_info:
240
+ return output["z"], info
241
+
242
+ return output["z"]
243
+
244
+ def decode(self, x):
245
+
246
+ if self.quantize_on_decode:
247
+ x = self.quantizer(x)[0]
248
+
249
+ if self.noise_augment_dim > 0:
250
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
251
+ x.shape[-1]).type_as(x)
252
+ x = torch.cat([x, noise], dim=1)
253
+
254
+ return x
255
+
256
+ def decode_tokens(self, codes, **kwargs):
257
+ latents, _, _ = self.quantizer.from_codes(codes)
258
+
259
+ return self.decode(latents, **kwargs)
260
+
261
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
262
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
263
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
264
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
265
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
266
+ self.quantize_on_decode = quantize_on_decode
267
+
268
+ def encode(self, x, return_info=False, n_quantizers: int = None):
269
+ info = {}
270
+
271
+ mean, scale = x.chunk(2, dim=1)
272
+
273
+ x, kl = vae_sample(mean, scale)
274
+
275
+ info["pre_quantizer"] = x
276
+ info["kl"] = kl
277
+
278
+ if self.quantize_on_decode:
279
+ return (x, info) if return_info else x
280
+
281
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
282
+
283
+ output = {
284
+ "z": z,
285
+ "codes": codes,
286
+ "latents": latents,
287
+ "vq/commitment_loss": commitment_loss,
288
+ "vq/codebook_loss": codebook_loss,
289
+ }
290
+
291
+ output["vq/commitment_loss"] /= self.num_quantizers
292
+ output["vq/codebook_loss"] /= self.num_quantizers
293
+
294
+ info.update(output)
295
+
296
+ if return_info:
297
+ return output["z"], info
298
+
299
+ return output["z"]
300
+
301
+ def decode(self, x):
302
+
303
+ if self.quantize_on_decode:
304
+ x = self.quantizer(x)[0]
305
+
306
+ return x
307
+
308
+ def decode_tokens(self, codes, **kwargs):
309
+ latents, _, _ = self.quantizer.from_codes(codes)
310
+
311
+ return self.decode(latents, **kwargs)
312
+
313
+ class FSQBottleneck(DiscreteBottleneck):
314
+ def __init__(self, noise_augment_dim=0, **kwargs):
315
+ super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
316
+
317
+ self.noise_augment_dim = noise_augment_dim
318
+
319
+ self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
320
+
321
+ def encode(self, x, return_info=False):
322
+ info = {}
323
+
324
+ orig_dtype = x.dtype
325
+ x = x.float()
326
+
327
+ x = rearrange(x, "b c n -> b n c")
328
+ x, indices = self.quantizer(x)
329
+ x = rearrange(x, "b n c -> b c n")
330
+
331
+ x = x.to(orig_dtype)
332
+
333
+ # Reorder indices to match the expected format
334
+ indices = rearrange(indices, "b n q -> b q n")
335
+
336
+ info["quantizer_indices"] = indices
337
+
338
+ if return_info:
339
+ return x, info
340
+ else:
341
+ return x
342
+
343
+ def decode(self, x):
344
+
345
+ if self.noise_augment_dim > 0:
346
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
347
+ x.shape[-1]).type_as(x)
348
+ x = torch.cat([x, noise], dim=1)
349
+
350
+ return x
351
+
352
+ def decode_tokens(self, tokens, **kwargs):
353
+ latents = self.quantizer.indices_to_codes(tokens)
354
+
355
+ return self.decode(latents, **kwargs)
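+ # Common contract for the bottlenecks above: encode() can return an info dict with auxiliary
+ # terms ("kl", "mmd", "quantizer_loss", "vq/commitment_loss", ...) for the training wrappers to
+ # weight into the loss, while decode()/decode_tokens() map latents or token indices back into
+ # the decoder's input space.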
stable_audio_tools/models/codebook_patterns.py ADDED
@@ -0,0 +1,545 @@
1
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License
2
+ # License available in LICENSES/LICENSE_META.txt
3
+
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ import logging
8
+ import typing as tp
9
+
10
+ from abc import ABC, abstractmethod
11
+ import torch
12
+
13
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
14
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class Pattern:
20
+ """Base implementation of a pattern over a sequence with multiple codebooks.
21
+
22
+ The codebook pattern consists of a layout, defining for each sequence step
23
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
24
+ The first item of the pattern is always an empty list in order to properly insert a special token
25
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
26
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
27
+
28
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
29
+ ``build_pattern_sequence`` maps a dense input tensor of a multi-codebook sequence from [B, K, T]
30
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
31
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
32
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
33
+ is returned along with a mask indicating valid tokens.
34
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
35
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
36
+ to fill and specify invalid positions if needed.
37
+ See the dedicated methods for more details.
38
+ """
39
+ # Pattern layout, for each sequence step, we have a list of coordinates
40
+ # corresponding to the original codebook timestep and position.
41
+ # The first list is always an empty list in order to properly insert
42
+ # a special token to start with.
43
+ layout: PatternLayout
44
+ timesteps: int
45
+ n_q: int
46
+
47
+ def __post_init__(self):
48
+ assert len(self.layout) > 0
49
+ self._validate_layout()
50
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
51
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
52
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
53
+
54
+ def _validate_layout(self):
55
+ """Runs checks on the layout to ensure a valid pattern is defined.
56
+ A pattern is considered invalid if:
57
+ - Multiple timesteps for the same codebook are defined in the same sequence step
58
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
59
+ (this would mean that we have future timesteps before past timesteps).
60
+ """
61
+ q_timesteps = {q: 0 for q in range(self.n_q)}
62
+ for s, seq_coords in enumerate(self.layout):
63
+ if len(seq_coords) > 0:
64
+ qs = set()
65
+ for coord in seq_coords:
66
+ qs.add(coord.q)
67
+ last_q_timestep = q_timesteps[coord.q]
68
+ assert coord.t >= last_q_timestep, \
69
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
70
+ q_timesteps[coord.q] = coord.t
71
+ # each sequence step contains at max 1 coordinate per codebook
72
+ assert len(qs) == len(seq_coords), \
73
+ f"Multiple entries for a same codebook are found at step {s}"
74
+
75
+ @property
76
+ def num_sequence_steps(self):
77
+ return len(self.layout) - 1
78
+
79
+ @property
80
+ def max_delay(self):
81
+ max_t_in_seq_coords = 0
82
+ for seq_coords in self.layout[1:]:
83
+ for coords in seq_coords:
84
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
85
+ return max_t_in_seq_coords - self.timesteps
86
+
87
+ @property
88
+ def valid_layout(self):
89
+ valid_step = len(self.layout) - self.max_delay
90
+ return self.layout[:valid_step]
91
+
92
+ def starts_with_special_token(self):
93
+ return self.layout[0] == []
94
+
95
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
96
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
97
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
98
+ and the actual codebook coordinates.
99
+ """
100
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
101
+ if q is not None:
102
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
103
+ coords = []
104
+ for s, seq_codes in enumerate(self.layout):
105
+ for code in seq_codes:
106
+ if code.t == t and (q is None or code.q == q):
107
+ coords.append((s, code))
108
+ return coords
109
+
110
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
111
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
112
+
113
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
114
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
115
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
116
+
117
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
118
+ device: tp.Union[torch.device, str] = 'cpu'):
119
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
120
+
121
+ Args:
122
+ timesteps (int): Maximum number of timesteps steps to consider.
123
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
124
+ device (torch.device or str): Device for created tensors.
125
+ Returns:
126
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128
+ """
129
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
132
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
133
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
134
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
135
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
136
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
137
+ # fill indexes with last sequence step value that will correspond to our special token
138
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
139
+ # which will correspond to the index: n_q * timesteps
140
+ indexes[:] = n_q * timesteps
141
+ # iterate over the pattern and fill scattered indexes and mask
142
+ for s, sequence_coords in enumerate(ref_layout):
143
+ for coords in sequence_coords:
144
+ if coords.t < timesteps:
145
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
146
+ mask[coords.q, s] = 1
147
+ indexes = torch.from_numpy(indexes).to(device)
148
+ mask = torch.from_numpy(mask).to(device)
149
+ return indexes, mask
150
+
151
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
152
+ """Build sequence corresponding to the pattern from the input tensor z.
153
+ The sequence is built using up to sequence_steps if specified, and non-pattern
154
+ coordinates are filled with the special token.
155
+
156
+ Args:
157
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
158
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
159
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
160
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
161
+ Returns:
162
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
163
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
164
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
165
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
166
+ """
167
+ B, K, T = z.shape
168
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
169
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
170
+ )
171
+ z = z.view(B, -1)
172
+ # we append the special token as the last index of our flattened z tensor
173
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
174
+ values = z[:, indexes.view(-1)]
175
+ values = values.view(B, K, indexes.shape[-1])
176
+ return values, indexes, mask
177
+
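+ # Small worked example (illustrative 2-codebook delay layout): with K=2, T=3 and codebook 1
+ # delayed by one step, z = [[a0 a1 a2], [b0 b1 b2]] is laid out as
+ # [[S a0 a1 a2 S], [S S b0 b1 b2]] with S the special token, and mask flags the non-S entries.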
178
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
179
+ keep_only_valid_steps: bool = False,
180
+ is_model_output: bool = False,
181
+ device: tp.Union[torch.device, str] = 'cpu'):
182
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
183
+ from interleaving pattern.
184
+
185
+ Args:
186
+ sequence_steps (int): Sequence steps.
187
+ n_q (int): Number of codebooks.
188
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
189
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
190
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
191
+ device (torch.device or str): Device for created tensors.
192
+ Returns:
193
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
194
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
195
+ """
196
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198
+ timesteps = self.timesteps
199
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200
+ assert sequence_steps <= len(ref_layout), \
201
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
+
203
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
204
+ if is_model_output and self.starts_with_special_token():
205
+ ref_layout = ref_layout[1:]
206
+
207
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
208
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
209
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
210
+ # fill indexes with last sequence step value that will correspond to our special token
211
+ indexes[:] = n_q * sequence_steps
212
+ for s, sequence_codes in enumerate(ref_layout):
213
+ if s < sequence_steps:
214
+ for code in sequence_codes:
215
+ if code.t < timesteps:
216
+ indexes[code.q, code.t] = s + code.q * sequence_steps
217
+ mask[code.q, code.t] = 1
218
+ indexes = torch.from_numpy(indexes).to(device)
219
+ mask = torch.from_numpy(mask).to(device)
220
+ return indexes, mask
221
+
222
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
223
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
224
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
225
+ are filled with the special token.
226
+
227
+ Args:
228
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
229
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
230
+ Returns:
231
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
232
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
233
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
234
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
+ """
236
+ B, K, S = s.shape
237
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
238
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
239
+ )
240
+ s = s.view(B, -1)
241
+ # we append the special token as the last index of our flattened z tensor
242
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
243
+ values = s[:, indexes.view(-1)]
244
+ values = values.view(B, K, indexes.shape[-1])
245
+ return values, indexes, mask
246
+
247
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
248
+ """Revert model logits obtained on a sequence built from the pattern
249
+ back to a tensor matching the original sequence.
250
+
251
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
252
+ 1. It is designed to work with the extra cardinality dimension
253
+ 2. We return the logits for the first sequence item that matches the special_token and
254
+ whose matching target in the original sequence is the first item of the sequence,
255
+ while we skip the last logits as there is no matching target
256
+ """
257
+ B, card, K, S = logits.shape
258
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
259
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
260
+ )
261
+ logits = logits.reshape(B, card, -1)
262
+ # we append the special token as the last index of our flattened z tensor
263
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
264
+ values = logits[:, :, indexes.view(-1)]
265
+ values = values.view(B, card, K, indexes.shape[-1])
266
+ return values, indexes, mask
267
+
268
+
269
+ class CodebooksPatternProvider(ABC):
270
+ """Abstraction around providing pattern for interleaving codebooks.
271
+
272
+ The CodebooksPatternProvider abstraction allows implementing various strategies to
273
+ define the interleaving pattern of sequences composed of multiple codebooks. For a given
274
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
275
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
276
+ can be used to construct a new sequence from the original codes respecting the specified
277
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
278
+ being a tuple with the original timestep and codebook to build the new sequence.
279
+ Note that all patterns must start with an empty list that is then used to insert a first
280
+ sequence step of special tokens in the newly generated sequence.
281
+
282
+ Args:
283
+ n_q (int): number of codebooks.
284
+ cached (bool): if True, patterns for a given length are cached. In general
285
+ that should be true for efficiency reasons, to avoid synchronization points.
286
+ """
287
+ def __init__(self, n_q: int, cached: bool = True):
288
+ assert n_q > 0
289
+ self.n_q = n_q
290
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
291
+
292
+ @abstractmethod
293
+ def get_pattern(self, timesteps: int) -> Pattern:
294
+ """Builds pattern with specific interleaving between codebooks.
295
+
296
+ Args:
297
+ timesteps (int): Total number of timesteps.
298
+ """
299
+ raise NotImplementedError()
300
+
301
+
302
+ class DelayedPatternProvider(CodebooksPatternProvider):
303
+ """Provider for delayed pattern across delayed codebooks.
304
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
305
+ from different timesteps.
306
+
307
+ Example:
308
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
309
+ [[1, 2, 3, 4],
310
+ [1, 2, 3, 4],
311
+ [1, 2, 3, 4]]
312
+ The resulting sequence obtained from the returned pattern is:
313
+ [[S, 1, 2, 3, 4],
314
+ [S, S, 1, 2, 3],
315
+ [S, S, S, 1, 2]]
316
+ (with S being a special token)
317
+
318
+ Args:
319
+ n_q (int): Number of codebooks.
320
+ delays (list of int, optional): Delay for each of the codebooks.
321
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
322
+ flatten_first (int): Flatten the first N timesteps.
323
+ empty_initial (int): Prepend with N empty list of coordinates.
324
+ """
325
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
326
+ flatten_first: int = 0, empty_initial: int = 0):
327
+ super().__init__(n_q)
328
+ if delays is None:
329
+ delays = list(range(n_q))
330
+ self.delays = delays
331
+ self.flatten_first = flatten_first
332
+ self.empty_initial = empty_initial
333
+ assert len(self.delays) == self.n_q
334
+ assert sorted(self.delays) == self.delays
335
+
336
+ def get_pattern(self, timesteps: int) -> Pattern:
337
+ omit_special_token = self.empty_initial < 0
338
+ out: PatternLayout = [] if omit_special_token else [[]]
339
+ max_delay = max(self.delays)
340
+ if self.empty_initial:
341
+ out += [[] for _ in range(self.empty_initial)]
342
+ if self.flatten_first:
343
+ for t in range(min(timesteps, self.flatten_first)):
344
+ for q in range(self.n_q):
345
+ out.append([LayoutCoord(t, q)])
346
+ for t in range(self.flatten_first, timesteps + max_delay):
347
+ v = []
348
+ for q, delay in enumerate(self.delays):
349
+ t_for_q = t - delay
350
+ if t_for_q >= self.flatten_first:
351
+ v.append(LayoutCoord(t_for_q, q))
352
+ out.append(v)
353
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
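
A minimal illustrative sketch (not part of the diff) of how the delayed pattern above can be inspected. It assumes only the classes defined in this file and that the `Pattern` dataclass exposes its `layout` of `LayoutCoord` entries, as the methods earlier in the file suggest:

```python
# Illustrative sketch: inspect the layout produced by DelayedPatternProvider.
# With n_q=3 and the default delays [0, 1, 2], timesteps=4 expands to
# 1 initial special step + timesteps + max_delay = 7 sequence steps.
provider = DelayedPatternProvider(n_q=3)
pattern = provider.get_pattern(timesteps=4)

print(len(pattern.layout))  # 7 sequence steps
for s, coords in enumerate(pattern.layout):
    # each step holds (timestep, codebook) coordinates; step 0 is empty (special token)
    print(s, [(c.t, c.q) for c in coords])
```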
354
+
355
+
356
+ class ParallelPatternProvider(DelayedPatternProvider):
357
+ """Provider for parallel pattern across codebooks.
358
+ This pattern provider is a special case of the delayed pattern with actually no delay,
359
+ hence delays=repeat(0, n_q).
360
+
361
+ Args:
362
+ n_q (int): Number of codebooks.
363
+ empty_initial (int): Prepend with N empty list of coordinates.
364
+ """
365
+ def __init__(self, n_q: int, empty_initial: int = 0):
366
+ super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
367
+
368
+
369
+ class UnrolledPatternProvider(CodebooksPatternProvider):
370
+ """Provider for unrolling codebooks pattern.
371
+ This pattern provider can represent the codebooks fully flattened or only partially flattened,
372
+ while also specifying a given delay for the flattened codebook representation, allowing the
373
+ codebooks to be unrolled in the sequence.
374
+
375
+ Example:
376
+ 1. Flattening of the codebooks.
377
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
378
+ taking n_q = 3 and timesteps = 4:
379
+ [[1, 2, 3, 4],
380
+ [1, 2, 3, 4],
381
+ [1, 2, 3, 4]]
382
+ will result into:
383
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
384
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
385
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
386
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter specifies the inner step
387
+ for each codebook, defining which codebooks to flatten (or keep in parallel), for example
388
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
389
+ [[1, 2, 3, 4],
390
+ [1, 2, 3, 4],
391
+ [1, 2, 3, 4]]
392
+ will result into:
393
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
394
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
395
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
396
+ 3. Flattening with delay. The ``delays`` parameter allows further unrolling the sequence of codebooks
397
+ by specifying a delay per codebook. Note that the delay between codebooks flattened to the
398
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
399
+ and delays = [0, 3, 3]:
400
+ [[1, 2, 3, 4],
401
+ [1, 2, 3, 4],
402
+ [1, 2, 3, 4]]
403
+ will result into:
404
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
405
+ [S, S, S, 1, S, 2, S, 3, S, 4],
406
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
407
+
408
+ Args:
409
+ n_q (int): Number of codebooks.
410
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
411
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
412
+ have n_q extra steps for each timestep.
413
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
414
+ no delay is added and therefore will default to [0] * ``n_q``.
415
+ Note that two codebooks that will be flattened to the same inner step
416
+ should have the same delay, otherwise the pattern is considered as invalid.
417
+ """
418
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
419
+
420
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
421
+ delays: tp.Optional[tp.List[int]] = None):
422
+ super().__init__(n_q)
423
+ if flattening is None:
424
+ flattening = list(range(n_q))
425
+ if delays is None:
426
+ delays = [0] * n_q
427
+ assert len(flattening) == n_q
428
+ assert len(delays) == n_q
429
+ assert sorted(flattening) == flattening
430
+ assert sorted(delays) == delays
431
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
432
+ self.max_delay = max(delays)
433
+
434
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
435
+ """Build a flattened codebooks representation as a dictionary of inner step
436
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
437
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
438
+ """
439
+ flattened_codebooks: dict = {}
440
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
441
+ if inner_step not in flattened_codebooks:
442
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
443
+ else:
444
+ flat_codebook = flattened_codebooks[inner_step]
445
+ assert flat_codebook.delay == delay, (
446
+ "Delay and flattening between codebooks is inconsistent: ",
447
+ "two codebooks flattened to the same position should have the same delay."
448
+ )
449
+ flat_codebook.codebooks.append(q)
450
+ flattened_codebooks[inner_step] = flat_codebook
451
+ return flattened_codebooks
452
+
453
+ @property
454
+ def _num_inner_steps(self):
455
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
456
+ """
457
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
458
+
459
+ def num_virtual_steps(self, timesteps: int) -> int:
460
+ return timesteps * self._num_inner_steps + 1
461
+
462
+ def get_pattern(self, timesteps: int) -> Pattern:
463
+ """Builds pattern for delay across codebooks.
464
+
465
+ Args:
466
+ timesteps (int): Total number of timesteps.
467
+ """
468
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
469
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
470
+ indexed_out: list = [(-1, [])]
471
+ max_timesteps = timesteps + self.max_delay
472
+ for t in range(max_timesteps):
473
+ # for each timestep, we unroll the flattened codebooks,
474
+ # emitting the sequence step with the corresponding delay
475
+ for step in range(self._num_inner_steps):
476
+ if step in self._flattened_codebooks:
477
+ # we have codebooks at this virtual step to emit
478
+ step_codebooks = self._flattened_codebooks[step]
479
+ t_for_q = t + step_codebooks.delay
480
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
481
+ if t_for_q < max_timesteps and t < max_timesteps:
482
+ indexed_out.append((t_for_q, coords))
483
+ else:
484
+ # there is no codebook in this virtual step so we emit an empty list
485
+ indexed_out.append((t, []))
486
+ out = [coords for _, coords in sorted(indexed_out)]
487
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
488
+
489
+
490
+ class CoarseFirstPattern(CodebooksPatternProvider):
491
+ """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
492
+ potentially with delays.
493
+
494
+ ..Warning:: You must always generate the full training duration at test time, for instance,
495
+ 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
496
+ location. This is due to the non causality of the remaining codebooks with respect to
497
+ the first ones.
498
+
499
+ Args:
500
+ n_q (int): Number of codebooks.
501
+ delays (list of int, optional): Delay for each of the codebooks.
502
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
503
+ """
504
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
505
+ super().__init__(n_q)
506
+ if delays is None:
507
+ delays = [0] * (n_q - 1)
508
+ self.delays = delays
509
+ assert len(self.delays) == self.n_q - 1
510
+ assert sorted(self.delays) == self.delays
511
+
512
+ def get_pattern(self, timesteps: int) -> Pattern:
513
+ out: PatternLayout = [[]]
514
+ for t in range(timesteps):
515
+ out.append([LayoutCoord(t, 0)])
516
+ max_delay = max(self.delays)
517
+ for t in range(timesteps + max_delay):
518
+ v = []
519
+ for q, delay in enumerate(self.delays):
520
+ t_for_q = t - delay
521
+ if t_for_q >= 0:
522
+ v.append(LayoutCoord(t_for_q, q + 1))
523
+ out.append(v)
524
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
525
+
526
+
527
+ class MusicLMPattern(CodebooksPatternProvider):
528
+ """Almost MusicLM style pattern. This is equivalent to full flattening
529
+ but in a different order.
530
+
531
+ Args:
532
+ n_q (int): Number of codebooks.
533
+ group_by (int): Number of codebooks to group together.
534
+ """
535
+ def __init__(self, n_q: int, group_by: int = 2):
536
+ super().__init__(n_q)
537
+ self.group_by = group_by
538
+
539
+ def get_pattern(self, timesteps: int) -> Pattern:
540
+ out: PatternLayout = [[]]
541
+ for offset in range(0, self.n_q, self.group_by):
542
+ for t in range(timesteps):
543
+ for q in range(offset, offset + self.group_by):
544
+ out.append([LayoutCoord(t, q)])
545
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
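
To make the revert path above concrete, here is a hedged round-trip sketch. A dummy interleaved tensor stands in for real codes; only classes and methods shown in this file (`DelayedPatternProvider`, `Pattern.revert_pattern_sequence`) are used:

```python
# Illustrative sketch: reverting an interleaved sequence back to [B, K, T].
import torch

n_q, timesteps, special_token = 3, 4, 1024
provider = DelayedPatternProvider(n_q=n_q)
pattern = provider.get_pattern(timesteps)

S = len(pattern.layout)                       # number of interleaved sequence steps
codes = torch.randint(0, special_token, (1, n_q, S))
values, indexes, mask = pattern.revert_pattern_sequence(codes, special_token=special_token)

print(values.shape)   # torch.Size([1, 3, 4]) -> back to the [B, K, T] layout
print(mask.shape)     # torch.Size([3, 4]) boolean mask of positions recovered from the sequence
```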
stable_audio_tools/models/conditioners.py ADDED
@@ -0,0 +1,711 @@
1
+ import torch
2
+ import logging, warnings
3
+ import string
4
+ import typing as tp
5
+ import gc
6
+
7
+ from .adp import NumberEmbedder
8
+ from ..inference.utils import set_audio_channels
9
+ from .factory import create_pretransform_from_config
10
+ from .pretransforms import Pretransform
11
+ from .utils import load_ckpt_state_dict
12
+
13
+ from torch import nn
14
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
15
+ import einops
16
+ from .temptransformer import SA_Transformer
17
+ from torchvision import transforms
18
+ import torch
19
+ import einops
20
+ import torchvision.transforms as transforms
21
+
22
+
23
+ class Conditioner(nn.Module):
24
+ def __init__(
25
+ self,
26
+ dim: int,
27
+ output_dim: int,
28
+ project_out: bool = False
29
+ ):
30
+
31
+ super().__init__()
32
+
33
+ self.dim = dim
34
+ self.output_dim = output_dim
35
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
36
+
37
+ def forward(self, x: tp.Any) -> tp.Any:
38
+ raise NotImplementedError()
39
+
40
+ class IntConditioner(Conditioner):
41
+ def __init__(self,
42
+ output_dim: int,
43
+ min_val: int=0,
44
+ max_val: int=512
45
+ ):
46
+ super().__init__(output_dim, output_dim)
47
+
48
+ self.min_val = min_val
49
+ self.max_val = max_val
50
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
51
+
52
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
53
+
54
+ #self.int_embedder.to(device)
55
+
56
+ ints = torch.tensor(ints).to(device)
57
+ ints = ints.clamp(self.min_val, self.max_val)
58
+
59
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
60
+
61
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
62
+
63
+ class NumberConditioner(Conditioner):
64
+ '''
65
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
66
+ '''
67
+ def __init__(self,
68
+ output_dim: int,
69
+ min_val: float=0,
70
+ max_val: float=1
71
+ ):
72
+ super().__init__(output_dim, output_dim)
73
+
74
+ self.min_val = min_val
75
+ self.max_val = max_val
76
+
77
+ self.embedder = NumberEmbedder(features=output_dim)
78
+
79
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
80
+
81
+ # Cast the inputs to floats
82
+ floats = [float(x) for x in floats]
83
+
84
+ floats = torch.tensor(floats).to(device)
85
+
86
+ floats = floats.clamp(self.min_val, self.max_val)
87
+
88
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
89
+
90
+ # Cast floats to same type as embedder
91
+ embedder_dtype = next(self.embedder.parameters()).dtype
92
+ normalized_floats = normalized_floats.to(embedder_dtype)
93
+
94
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
95
+
96
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
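
All conditioners in this file follow the same contract: `forward` returns an (embeddings, mask) pair shaped roughly `[B, seq, dim]` and `[B, seq]`. A hedged sketch using the NumberConditioner above; it assumes `NumberEmbedder` (from `adp.py`) maps each scalar to an `output_dim`-sized vector:

```python
# Illustrative only: the conditioner interface returns (embeddings, mask).
cond = NumberConditioner(output_dim=768, min_val=0, max_val=512)
embeds, mask = cond([10.0, 30.0], device="cpu")
print(embeds.shape)  # expected torch.Size([2, 1, 768]): one "token" per number
print(mask.shape)    # expected torch.Size([2, 1])
```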
97
+
98
+ class CLAPTextConditioner(Conditioner):
99
+ def __init__(self,
100
+ output_dim: int,
101
+ clap_ckpt_path,
102
+ use_text_features = False,
103
+ feature_layer_ix: int = -1,
104
+ audio_model_type="HTSAT-base",
105
+ enable_fusion=True,
106
+ project_out: bool = False,
107
+ finetune: bool = False):
108
+ super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
109
+
110
+ self.use_text_features = use_text_features
111
+ self.feature_layer_ix = feature_layer_ix
112
+ self.finetune = finetune
113
+
114
+ # Suppress logging from transformers
115
+ previous_level = logging.root.manager.disable
116
+ logging.disable(logging.ERROR)
117
+ with warnings.catch_warnings():
118
+ warnings.simplefilter("ignore")
119
+ try:
120
+ import laion_clap
121
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
122
+
123
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
124
+
125
+ if self.finetune:
126
+ self.model = model
127
+ else:
128
+ self.__dict__["model"] = model
129
+
130
+ state_dict = clap_load_state_dict(clap_ckpt_path)
131
+ self.model.model.load_state_dict(state_dict, strict=False)
132
+
133
+ if self.finetune:
134
+ self.model.model.text_branch.requires_grad_(True)
135
+ self.model.model.text_branch.train()
136
+ else:
137
+ self.model.model.text_branch.requires_grad_(False)
138
+ self.model.model.text_branch.eval()
139
+
140
+ finally:
141
+ logging.disable(previous_level)
142
+
143
+ del self.model.model.audio_branch
144
+
145
+ gc.collect()
146
+ torch.cuda.empty_cache()
147
+
148
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
149
+ prompt_tokens = self.model.tokenizer(prompts)
150
+ attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
151
+ prompt_features = self.model.model.text_branch(
152
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
153
+ attention_mask=attention_mask,
154
+ output_hidden_states=True
155
+ )["hidden_states"][layer_ix]
156
+
157
+ return prompt_features, attention_mask
158
+
159
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
160
+ self.model.to(device)
161
+
162
+ if self.use_text_features:
163
+ if len(texts) == 1:
164
+ text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
165
+ text_features = text_features[:1, ...]
166
+ text_attention_mask = text_attention_mask[:1, ...]
167
+ else:
168
+ text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
169
+ return [self.proj_out(text_features), text_attention_mask]
170
+
171
+ # Fix for CLAP bug when only one text is passed
172
+ if len(texts) == 1:
173
+ text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
174
+ else:
175
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
176
+
177
+ text_embedding = text_embedding.unsqueeze(1).to(device)
178
+
179
+ return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
180
+
181
+ class CLAPAudioConditioner(Conditioner):
182
+ def __init__(self,
183
+ output_dim: int,
184
+ clap_ckpt_path,
185
+ audio_model_type="HTSAT-base",
186
+ enable_fusion=True,
187
+ project_out: bool = False):
188
+ super().__init__(512, output_dim, project_out=project_out)
+ self.finetune = False  # this conditioner never fine-tunes CLAP; the flag is checked below when registering the model
189
+
190
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
191
+
192
+ # Suppress logging from transformers
193
+ previous_level = logging.root.manager.disable
194
+ logging.disable(logging.ERROR)
195
+ with warnings.catch_warnings():
196
+ warnings.simplefilter("ignore")
197
+ try:
198
+ import laion_clap
199
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
200
+
201
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
202
+
203
+ if self.finetune:
204
+ self.model = model
205
+ else:
206
+ self.__dict__["model"] = model
207
+
208
+ state_dict = clap_load_state_dict(clap_ckpt_path)
209
+ self.model.model.load_state_dict(state_dict, strict=False)
210
+
211
+ if self.finetune:
212
+ self.model.model.audio_branch.requires_grad_(True)
213
+ self.model.model.audio_branch.train()
214
+ else:
215
+ self.model.model.audio_branch.requires_grad_(False)
216
+ self.model.model.audio_branch.eval()
217
+
218
+ finally:
219
+ logging.disable(previous_level)
220
+
221
+ del self.model.model.text_branch
222
+
223
+ gc.collect()
224
+ torch.cuda.empty_cache()
225
+
226
+ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
227
+
228
+ self.model.to(device)
229
+
230
+ if isinstance(audios, list) or isinstance(audios, tuple):
231
+ audios = torch.cat(audios, dim=0)
232
+
233
+ # Convert to mono
234
+ mono_audios = audios.mean(dim=1)
235
+
236
+ with torch.cuda.amp.autocast(enabled=False):
237
+ audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
238
+
239
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
240
+
241
+ return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
242
+
243
+
244
+ class CLIPConditioner(Conditioner):
245
+ CLIP_MODELS = ["clip-vit-base-patch32"]
246
+
247
+ def __init__(
248
+ self,
249
+ output_dim: int,
250
+ clip_model_name: str = "clip-vit-base-patch32",
251
+ video_fps: int = 5,
252
+ out_features: int = 128,
253
+ enable_grad: bool = False,
254
+ in_features: int = 5000,
255
+ project_out: bool = False,
256
+ ):
257
+ assert clip_model_name in self.CLIP_MODELS, f"Unknown clip model name: {clip_model_name}"
258
+ super().__init__(dim = 768, output_dim=output_dim, project_out=project_out)
259
+
260
+ sa_depth=4
261
+ num_heads=16
262
+ dim_head=64
263
+ hidden_scale=4
264
+ duration = 10
265
+
266
+ self.clip_model_name=clip_model_name
267
+
268
+ if self.clip_model_name=='clip-vit-base-patch32':
269
+ out_features = 128
270
+ temporal_dim=768
271
+
272
+ self.empty_visual_feat = nn.Parameter(torch.zeros(1, out_features, temporal_dim), requires_grad=True)
273
+ nn.init.constant_(self.empty_visual_feat, 0)
274
+
275
+ in_features = 50*video_fps*duration
276
+
277
+ self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained('openai/clip-vit-base-patch32')
278
+ self.proj = nn.Linear(in_features=in_features, out_features=out_features)
279
+
280
+ self.in_features = in_features
281
+ self.out_features = out_features
282
+
283
+ self.Temp_transformer = SA_Transformer(temporal_dim, sa_depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.)
284
+ self.Temp_pos_embedding = nn.Parameter(torch.randn(1, duration*video_fps, temporal_dim))
285
+
286
+ clip_mean = [0.48145466, 0.4578275, 0.40821073]
287
+ clip_std = [0.26862954, 0.26130258, 0.27577711]
288
+ self.preprocess_CLIP = transforms.Compose([
289
+ transforms.Normalize(mean=clip_mean, std=clip_std)
290
+ ])
291
+
292
+ def process_video_with_custom_preprocessing(self, video_tensor):
293
+ video_tensor = video_tensor / 255.0
294
+ video_tensor = self.preprocess_CLIP(video_tensor)
295
+ return video_tensor
296
+
297
+ def init_first_from_ckpt(self, path):
298
+ model = torch.load(path, map_location="cpu")
299
+ if "state_dict" in list(model.keys()):
300
+ model = model["state_dict"]
301
+ # Remove: module prefix
302
+ new_model = {}
303
+ for key in model.keys():
304
+ new_key = key.replace("module.","")
305
+ new_model[new_key] = model[key]
306
+ missing, unexpected = self.visual_encoder_model.load_state_dict(new_model, strict=False)
307
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
308
+ if len(missing) > 0:
309
+ print(f"Missing Keys: {missing}")
310
+ if len(unexpected) > 0:
311
+ print(f"Unexpected Keys: {unexpected}")
312
+
313
+ def forward(self, Video_tensors: tp.List[torch.Tensor], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
314
+ visual_encoder_model = self.visual_encoder_model.eval().to(device)
315
+ proj = self.proj.to(device)
316
+
317
+ original_videos = torch.cat(Video_tensors, dim=0).to(device)
318
+ batch_size, time_length, _, _, _ = original_videos.size()
319
+ is_zero = torch.all(original_videos == 0, dim=1)
320
+ is_zero = torch.all(is_zero, dim=1)
321
+ is_zero = torch.all(is_zero, dim=1)
322
+ is_zero = torch.all(is_zero, dim=1)
323
+ Video_tensors = original_videos
324
+ Video_tensors = einops.rearrange(Video_tensors, 'b t c h w -> (b t) c h w')
325
+
326
+ video_cond_pixel_values = self.process_video_with_custom_preprocessing(video_tensor=Video_tensors.to(device)).to(device)
327
+ if self.clip_model_name=='clip-vit-base-patch32':
328
+ with torch.no_grad():
329
+ outputs = visual_encoder_model(pixel_values=video_cond_pixel_values)
330
+ video_hidden = outputs.last_hidden_state
331
+
332
+ video_hidden = einops.rearrange(video_hidden, '(b t) q h -> (b q) t h',b=batch_size,t=time_length)
333
+ video_hidden += self.Temp_pos_embedding
334
+ video_hidden = self.Temp_transformer(video_hidden)
335
+ video_hidden = einops.rearrange(video_hidden, '(b q) t h -> b (t q) h',b=batch_size,t=time_length)
336
+
337
+ video_hidden = proj(video_hidden.view(-1, self.in_features))
338
+ video_hidden = video_hidden.view(batch_size, self.out_features, -1)
339
+
340
+ empty_visual_feat = self.empty_visual_feat.expand(batch_size, -1, -1)
341
+ is_zero_expanded = is_zero.view(batch_size, 1, 1)
342
+ video_hidden = torch.where(is_zero_expanded, empty_visual_feat, video_hidden)
343
+
344
+ return video_hidden, torch.ones(video_hidden.shape[0], 1).to(device)
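
A hedged shape walk-through for the CLIP video conditioner above. The frame count must equal `video_fps * duration` (50 with the defaults) for the temporal positional embedding to broadcast, and 224x224 RGB frames are assumed since ViT-B/32 yields the 50 patch tokens baked into `in_features`. The sketch downloads `openai/clip-vit-base-patch32`, so treat it as illustrative rather than a test:

```python
# Illustrative shape check (assumes defaults: video_fps=5, duration=10 -> 50 frames).
import torch

cond = CLIPConditioner(output_dim=768)
frames = torch.zeros(1, 50, 3, 224, 224)   # [B, T, C, H, W], pixel values in 0..255
feats, mask = cond([frames], device="cpu")
print(feats.shape)  # expected torch.Size([1, 128, 768]): 128 learned visual tokens per clip
print(mask.shape)   # expected torch.Size([1, 1])
```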
345
+
346
+
347
+
348
+ class T5Conditioner(Conditioner):
349
+
350
+ T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
351
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
352
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
353
+
354
+ T5_MODEL_DIMS = {
355
+ "t5-small": 512,
356
+ "t5-base": 768,
357
+ "t5-large": 1024,
358
+ "t5-3b": 1024,
359
+ "t5-11b": 1024,
360
+ "t5-xl": 2048,
361
+ "t5-xxl": 4096,
362
+ "google/flan-t5-small": 512,
363
+ "google/flan-t5-base": 768,
364
+ "google/flan-t5-large": 1024,
365
+ "google/flan-t5-3b": 1024,
366
+ "google/flan-t5-11b": 1024,
367
+ "google/flan-t5-xl": 2048,
368
+ "google/flan-t5-xxl": 4096,
369
+ }
370
+
371
+ def __init__(
372
+ self,
373
+ output_dim: int,
374
+ t5_model_name: str = "t5-base",
375
+ max_length: int = 128,
376
+ enable_grad: bool = False,
377
+ project_out: bool = False,
378
+ ):
379
+ assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
380
+ super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
381
+
382
+ from transformers import T5EncoderModel, AutoTokenizer
383
+
384
+ self.max_length = max_length
385
+ self.enable_grad = enable_grad
386
+ # Suppress logging from transformers
387
+ previous_level = logging.root.manager.disable
388
+ logging.disable(logging.ERROR)
389
+ with warnings.catch_warnings():
390
+ warnings.simplefilter("ignore")
391
+ try:
392
+ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
393
+ model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
394
+ finally:
395
+ logging.disable(previous_level)
396
+
397
+ if self.enable_grad:
398
+ self.model = model
399
+ else:
400
+ self.__dict__["model"] = model
401
+
402
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
403
+
404
+ self.model.to(device)
405
+ self.proj_out.to(device)
406
+
407
+ encoded = self.tokenizer(
408
+ texts,
409
+ truncation=True,
410
+ max_length=self.max_length,
411
+ padding="max_length",
412
+ return_tensors="pt",
413
+ )
414
+
415
+ input_ids = encoded["input_ids"].to(device)
416
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
417
+
418
+ self.model.eval()
419
+
420
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
421
+ embeddings = self.model(
422
+ input_ids=input_ids, attention_mask=attention_mask
423
+ )["last_hidden_state"]
424
+
425
+ embeddings = self.proj_out(embeddings.float())
426
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
427
+
428
+ return embeddings, attention_mask
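
A hedged usage sketch for the T5 conditioner above. The weights are cast to float16 and the forward pass runs under CUDA autocast, so a GPU is assumed here:

```python
# Illustrative only: per-token T5 embeddings plus the boolean attention mask.
t5 = T5Conditioner(output_dim=768, t5_model_name="t5-base")
emb, mask = t5(["a dog barking in the rain"], device="cuda")
print(emb.shape)   # expected torch.Size([1, 128, 768]) with the default max_length=128
print(mask.shape)  # expected torch.Size([1, 128])
```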
429
+
430
+ class PhonemeConditioner(Conditioner):
431
+ """
432
+ A conditioner that turns text into phonemes and embeds them using a lookup table
433
+ Only works for English text
434
+
435
+ Args:
436
+ output_dim: the dimension of the output embeddings
437
+ max_length: the maximum number of phonemes to embed
438
+ project_out: whether to add another linear projection to the output embeddings
439
+ """
440
+
441
+ def __init__(
442
+ self,
443
+ output_dim: int,
444
+ max_length: int = 1024,
445
+ project_out: bool = False,
446
+ ):
447
+ super().__init__(output_dim, output_dim, project_out=project_out)
448
+
449
+ from g2p_en import G2p
450
+ self.max_length = max_length
451
+ self.g2p = G2p()
452
+ # Reserving 0 for padding, 1 for ignored
453
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
454
+
455
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
456
+
457
+ self.phoneme_embedder.to(device)
458
+ self.proj_out.to(device)
459
+
460
+ batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
461
+ phoneme_ignore = [" ", *string.punctuation]
462
+ # Remove ignored phonemes and cut to max length
463
+ batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
464
+
465
+ # Convert to ids
466
+ phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
467
+
468
+ #Pad to match longest and make a mask tensor for the padding
469
+ longest = max([len(ids) for ids in phoneme_ids])
470
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
471
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
472
+
473
+ # Convert to embeddings
474
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
475
+ phoneme_embeds = self.proj_out(phoneme_embeds)
476
+
477
+ return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
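
A hedged sketch for the phoneme conditioner above (assumes the `g2p_en` package and its NLTK data are installed):

```python
# Illustrative only: text -> phonemes -> lookup-table embeddings.
phon = PhonemeConditioner(output_dim=768)
emb, mask = phon(["hello world"], device="cpu")
print(emb.shape)   # expected [1, n_phonemes, 768]; length depends on the g2p output
print(mask.shape)  # expected [1, n_phonemes]
```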
478
+
479
+
480
+
481
+ class TokenizerLUTConditioner(Conditioner):
482
+ """
483
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
484
+
485
+ Args:
486
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
487
+ output_dim: the dimension of the output embeddings
488
+ max_length: the maximum length of the text to embed
489
+ project_out: whether to add another linear projection to the output embeddings
490
+ """
491
+
492
+ def __init__(
493
+ self,
494
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
495
+ output_dim: int,
496
+ max_length: int = 1024,
497
+ project_out: bool = False,
498
+ ):
499
+ super().__init__(output_dim, output_dim, project_out=project_out)
500
+
501
+ from transformers import AutoTokenizer
502
+
503
+ # Suppress logging from transformers
504
+ previous_level = logging.root.manager.disable
505
+ logging.disable(logging.ERROR)
506
+ with warnings.catch_warnings():
507
+ warnings.simplefilter("ignore")
508
+ try:
509
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
510
+ finally:
511
+ logging.disable(previous_level)
512
+
513
+ self.max_length = max_length
514
+
515
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
516
+
517
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
518
+ self.proj_out.to(device)
519
+
520
+ encoded = self.tokenizer(
521
+ texts,
522
+ truncation=True,
523
+ max_length=self.max_length,
524
+ padding="max_length",
525
+ return_tensors="pt",
526
+ )
527
+
528
+ input_ids = encoded["input_ids"].to(device)
529
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
530
+
531
+ embeddings = self.token_embedder(input_ids)
532
+
533
+ embeddings = self.proj_out(embeddings)
534
+
535
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
536
+
537
+ return embeddings, attention_mask
538
+
539
+ class PretransformConditioner(Conditioner):
540
+ """
541
+ A conditioner that uses a pretransform's encoder for conditioning
542
+
543
+ Args:
544
+ pretransform: an instantiated pretransform to use for conditioning
545
+ output_dim: the dimension of the output embeddings
546
+ """
547
+ def __init__(self, pretransform: Pretransform, output_dim: int):
548
+ super().__init__(pretransform.encoded_channels, output_dim)
549
+
550
+ self.pretransform = pretransform
551
+
552
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
553
+
554
+ self.pretransform.to(device)
555
+ self.proj_out.to(device)
556
+
557
+ if isinstance(audio, list) or isinstance(audio, tuple):
558
+ audio = torch.cat(audio, dim=0)
559
+
560
+ # Convert audio to pretransform input channels
561
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
562
+
563
+ latents = self.pretransform.encode(audio)
564
+ latents = self.proj_out(latents)
565
+
566
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
567
+
568
+
569
+ class AudioAutoencoderConditioner(Conditioner):
570
+ """
571
+ A conditioner that uses an audio autoencoder's pretransform encoder for conditioning, with a learned embedding substituted for all-zero (empty) audio inputs
572
+
573
+ Args:
574
+ pretransform: an instantiated pretransform to use for conditioning
575
+ output_dim: the dimension of the output embeddings
576
+ """
577
+ def __init__(self, pretransform: Pretransform, output_dim: int):
578
+ super().__init__(pretransform.encoded_channels, output_dim)
579
+
580
+ self.pretransform = pretransform
581
+ self.empty_audio_feat = nn.Parameter(torch.zeros(1, 215, self.proj_out.out_features), requires_grad=True)
582
+ nn.init.constant_(self.empty_audio_feat, 0)
583
+
584
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
585
+
586
+ self.pretransform.to(device)
587
+ self.proj_out.to(device)
588
+
589
+ if isinstance(audio, list) or isinstance(audio, tuple):
590
+ original_audios = torch.cat(audio, dim=0).to(device)
591
+ is_zero = torch.all(original_audios == 0, dim=(1,2))
592
+ audio = original_audios
593
+
594
+ # Convert audio to pretransform input channels
595
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
596
+
597
+ latents = self.pretransform.encode(audio)
598
+ latents = latents.permute(0, 2, 1)
599
+ latents = self.proj_out(latents)
600
+
601
+ empty_audio_feat = self.empty_audio_feat.expand(latents.shape[0], -1, -1)
602
+ is_zero_expanded = is_zero.view(latents.shape[0], 1, 1)
603
+ latents = torch.where(is_zero_expanded, empty_audio_feat, latents)
604
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
605
+
606
+
607
+ class MultiConditioner(nn.Module):
608
+ """
609
+ A module that applies multiple conditioners to an input dictionary based on the keys
610
+
611
+ Args:
612
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
613
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
614
+ """
615
+ def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
616
+ super().__init__()
617
+
618
+ self.conditioners = nn.ModuleDict(conditioners)
619
+ self.default_keys = default_keys
620
+
621
+ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
622
+ output = {}
623
+
624
+ for key, conditioner in self.conditioners.items():
625
+ condition_key = key
626
+
627
+ conditioner_inputs = []
628
+
629
+ for x in batch_metadata:
630
+
631
+ if condition_key not in x:
632
+ if condition_key in self.default_keys:
633
+ condition_key = self.default_keys[condition_key]
634
+ else:
635
+ raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
636
+
637
+ if (isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple)) and len(x[condition_key]) == 1:
638
+ conditioner_input = x[condition_key][0]
639
+
640
+ else:
641
+ conditioner_input = x[condition_key]
642
+
643
+ conditioner_inputs.append(conditioner_input)
644
+
645
+ output[key] = conditioner(conditioner_inputs, device)
646
+
647
+ return output
648
+
649
+ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
650
+ """
651
+ Create a MultiConditioner from a conditioning config dictionary
652
+
653
+ Args:
654
+ config: the conditioning config dictionary
655
+ device: the device to put the conditioners on
656
+ """
657
+ conditioners = {}
658
+ cond_dim = config["cond_dim"]
659
+
660
+ default_keys = config.get("default_keys", {})
661
+
662
+ for conditioner_info in config["configs"]:
663
+ id = conditioner_info["id"]
664
+
665
+ conditioner_type = conditioner_info["type"]
666
+
667
+ conditioner_config = {"output_dim": cond_dim}
668
+
669
+ conditioner_config.update(conditioner_info["config"])
670
+
671
+ if conditioner_type == "t5":
672
+ conditioners[id] = T5Conditioner(**conditioner_config)
673
+ elif conditioner_type == "clip":
674
+ conditioners[id] = CLIPConditioner(**conditioner_config)
675
+ elif conditioner_type == "clap_text":
676
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
677
+ elif conditioner_type == "clap_audio":
678
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
679
+ elif conditioner_type == "int":
680
+ conditioners[id] = IntConditioner(**conditioner_config)
681
+ elif conditioner_type == "number":
682
+ conditioners[id] = NumberConditioner(**conditioner_config)
683
+ elif conditioner_type == "phoneme":
684
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
685
+ elif conditioner_type == "lut":
686
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
687
+ elif conditioner_type == "pretransform":
688
+ sample_rate = conditioner_config.pop("sample_rate", None)
689
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
690
+
691
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
692
+
693
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
694
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
695
+
696
+ conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
697
+
698
+ elif conditioner_type == "audio_autoencoder":
699
+ sample_rate = conditioner_config.pop("sample_rate", None)
700
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
701
+
702
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
703
+
704
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
705
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
706
+
707
+ conditioners[id] = AudioAutoencoderConditioner(pretransform, **conditioner_config)
708
+ else:
709
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
710
+
711
+ return MultiConditioner(conditioners, default_keys=default_keys)
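
A hedged example of the conditioning config consumed by the factory above. The ids, dims, and model names here are illustrative placeholders, not values taken from the AudioX configs:

```python
# Illustrative conditioning config: one T5 text conditioner and one number conditioner.
conditioning_config = {
    "cond_dim": 768,
    "default_keys": {"prompt_t5": "prompt"},   # fall back to "prompt" if "prompt_t5" is missing
    "configs": [
        {"id": "prompt_t5", "type": "t5",
         "config": {"t5_model_name": "t5-base", "max_length": 128}},
        {"id": "seconds_total", "type": "number",
         "config": {"min_val": 0, "max_val": 512}},
    ],
}

conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
```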
stable_audio_tools/models/diffusion.py ADDED
@@ -0,0 +1,704 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from functools import partial
5
+ import numpy as np
6
+ import typing as tp
7
+
8
+ from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
9
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
+ from .dit import DiffusionTransformer
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..inference.generation import generate_diffusion_cond
14
+
15
+ from .adp import UNetCFG1d, UNet1d
16
+
17
+ from time import time
18
+
19
+ class Profiler:
20
+
21
+ def __init__(self):
22
+ self.ticks = [[time(), None]]
23
+
24
+ def tick(self, msg):
25
+ self.ticks.append([time(), msg])
26
+
27
+ def __repr__(self):
28
+ rep = 80 * "=" + "\n"
29
+ for i in range(1, len(self.ticks)):
30
+ msg = self.ticks[i][1]
31
+ ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
32
+ rep += msg + f": {ellapsed*1000:.2f}ms\n"
33
+ rep += 80 * "=" + "\n\n\n"
34
+ return rep
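
The Profiler above is only used for ad-hoc timing; a minimal sketch of the intended usage:

```python
# Illustrative only: tick() records a named timestamp, printing the object
# reports per-step elapsed milliseconds via __repr__.
p = Profiler()
# ... some work, e.g. a model forward pass ...
p.tick("forward")
# ... more work ...
p.tick("postprocess")
print(p)
```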
35
+
36
+ class DiffusionModel(nn.Module):
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+ def forward(self, x, t, **kwargs):
41
+ raise NotImplementedError()
42
+
43
+ class DiffusionModelWrapper(nn.Module):
44
+ def __init__(
45
+ self,
46
+ model: DiffusionModel,
47
+ io_channels,
48
+ sample_size,
49
+ sample_rate,
50
+ min_input_length,
51
+ pretransform: tp.Optional[Pretransform] = None,
52
+ ):
53
+ super().__init__()
54
+ self.io_channels = io_channels
55
+ self.sample_size = sample_size
56
+ self.sample_rate = sample_rate
57
+ self.min_input_length = min_input_length
58
+
59
+ self.model = model
60
+
61
+ if pretransform is not None:
62
+ self.pretransform = pretransform
63
+ else:
64
+ self.pretransform = None
65
+
66
+ def forward(self, x, t, **kwargs):
67
+ return self.model(x, t, **kwargs)
68
+
69
+ class ConditionedDiffusionModel(nn.Module):
70
+ def __init__(self,
71
+ *args,
72
+ supports_cross_attention: bool = False,
73
+ supports_input_concat: bool = False,
74
+ supports_global_cond: bool = False,
75
+ supports_prepend_cond: bool = False,
76
+ **kwargs):
77
+ super().__init__(*args, **kwargs)
78
+ self.supports_cross_attention = supports_cross_attention
79
+ self.supports_input_concat = supports_input_concat
80
+ self.supports_global_cond = supports_global_cond
81
+ self.supports_prepend_cond = supports_prepend_cond
82
+
83
+ def forward(self,
84
+ x: torch.Tensor,
85
+ t: torch.Tensor,
86
+ cross_attn_cond: torch.Tensor = None,
87
+ cross_attn_mask: torch.Tensor = None,
88
+ input_concat_cond: torch.Tensor = None,
89
+ global_embed: torch.Tensor = None,
90
+ prepend_cond: torch.Tensor = None,
91
+ prepend_cond_mask: torch.Tensor = None,
92
+ cfg_scale: float = 1.0,
93
+ cfg_dropout_prob: float = 0.0,
94
+ batch_cfg: bool = False,
95
+ rescale_cfg: bool = False,
96
+ **kwargs):
97
+ raise NotImplementedError()
98
+
99
+ class ConditionedDiffusionModelWrapper(nn.Module):
100
+ """
101
+ A diffusion model that takes in conditioning
102
+ """
103
+ def __init__(
104
+ self,
105
+ model: ConditionedDiffusionModel,
106
+ conditioner: MultiConditioner,
107
+ io_channels,
108
+ sample_rate,
109
+ min_input_length: int,
110
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
111
+ pretransform: tp.Optional[Pretransform] = None,
112
+ cross_attn_cond_ids: tp.List[str] = [],
113
+ global_cond_ids: tp.List[str] = [],
114
+ input_concat_ids: tp.List[str] = [],
115
+ prepend_cond_ids: tp.List[str] = [],
116
+ ):
117
+ super().__init__()
118
+
119
+ self.model = model
120
+ self.conditioner = conditioner
121
+ self.io_channels = io_channels
122
+ self.sample_rate = sample_rate
123
+ self.diffusion_objective = diffusion_objective
124
+ self.pretransform = pretransform
125
+ self.cross_attn_cond_ids = cross_attn_cond_ids # ['prompt', 'seconds_start', 'seconds_total']
126
+ self.global_cond_ids = global_cond_ids # ['seconds_start', 'seconds_total']
127
+ self.input_concat_ids = input_concat_ids
128
+ self.prepend_cond_ids = prepend_cond_ids
129
+ self.min_input_length = min_input_length
130
+
131
+ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[torch.Tensor, tp.Any], negative=False):
132
+ cross_attention_input = None
133
+ cross_attention_masks = None
134
+ global_cond = None
135
+ input_concat_cond = None
136
+ prepend_cond = None
137
+ prepend_cond_mask = None
138
+
139
+ if len(self.cross_attn_cond_ids) > 0:
140
+ # Concatenate all cross-attention inputs over the sequence dimension
141
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
142
+ cross_attention_input = []
143
+ cross_attention_masks = []
144
+
145
+ for key in self.cross_attn_cond_ids:
146
+ cross_attn_in, cross_attn_mask = conditioning_tensors[key]
147
+
148
+ # Add sequence dimension if it's not there
149
+ if len(cross_attn_in.shape) == 2:
150
+ cross_attn_in = cross_attn_in.unsqueeze(1)
151
+ cross_attn_mask = cross_attn_mask.unsqueeze(1)
152
+
153
+ cross_attention_input.append(cross_attn_in)
154
+ cross_attention_masks.append(cross_attn_mask)
155
+
156
+ cross_attention_input = torch.cat(cross_attention_input, dim=1) # [1, 130, 768] (text feature:128)
157
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
158
+
159
+ if len(self.global_cond_ids) > 0:
160
+ # Concatenate all global conditioning inputs over the channel dimension
161
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
162
+ global_conds = []
163
+ for key in self.global_cond_ids:
164
+
165
+ global_cond_input = conditioning_tensors[key][0]
166
+
167
+ global_conds.append(global_cond_input)
168
+
169
+ # Concatenate over the channel dimension
170
+ global_cond = torch.cat(global_conds, dim=-1)
171
+
172
+ if len(global_cond.shape) == 3:
173
+ global_cond = global_cond.squeeze(1)
174
+
175
+ if len(self.input_concat_ids) > 0: # False
176
+ # Concatenate all input concat conditioning inputs over the channel dimension
177
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
178
+ input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
179
+
180
+ if len(self.prepend_cond_ids) > 0: # False
181
+ # Concatenate all prepend conditioning inputs over the sequence dimension
182
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
183
+ prepend_conds = []
184
+ prepend_cond_masks = []
185
+
186
+ for key in self.prepend_cond_ids:
187
+ prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
188
+ prepend_conds.append(prepend_cond_input)
189
+ prepend_cond_masks.append(prepend_cond_mask)
190
+
191
+ prepend_cond = torch.cat(prepend_conds, dim=1)
192
+ prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
193
+
194
+ if negative: # False
195
+ return {
196
+ "negative_cross_attn_cond": cross_attention_input,
197
+ "negative_cross_attn_mask": cross_attention_masks,
198
+ "negative_global_cond": global_cond,
199
+ "negative_input_concat_cond": input_concat_cond
200
+ }
201
+ else:
202
+ return {
203
+ "cross_attn_cond": cross_attention_input,
204
+ "cross_attn_mask": cross_attention_masks,
205
+ "global_cond": global_cond,
206
+ "input_concat_cond": input_concat_cond,
207
+ "prepend_cond": prepend_cond,
208
+ "prepend_cond_mask": prepend_cond_mask
209
+ }
210
+
211
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
212
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
213
+
214
+ def generate(self, *args, **kwargs):
215
+ return generate_diffusion_cond(self, *args, **kwargs)
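
A hedged sketch of how conditioning flows through the wrapper above: metadata dicts go through the MultiConditioner, and the resulting tensors are expanded into cross-attention / global inputs by `get_conditioning_inputs` inside `forward`. The `wrapper` instance and the metadata keys here are placeholders, not values from the AudioX configs:

```python
# Illustrative only: `wrapper` is assumed to be a ConditionedDiffusionModelWrapper
# whose conditioner was configured with "prompt" / "seconds_total" style keys.
import torch

metadata = [{"prompt": "a dog barking in the rain", "seconds_total": 10}]

cond = wrapper.conditioner(metadata, device="cuda")           # dict of (embedding, mask) pairs
x = torch.randn(1, wrapper.io_channels, 1024, device="cuda")  # noised latent / audio input
t = torch.rand(1, device="cuda")                              # diffusion timestep

pred = wrapper(x, t, cond)   # forward() expands cond via get_conditioning_inputs()
```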
216
+
217
+ class UNetCFG1DWrapper(ConditionedDiffusionModel):
218
+ def __init__(
219
+ self,
220
+ *args,
221
+ **kwargs
222
+ ):
223
+ super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
224
+
225
+ self.model = UNetCFG1d(*args, **kwargs)
226
+
227
+ with torch.no_grad():
228
+ for param in self.model.parameters():
229
+ param *= 0.5
230
+
231
+ def forward(self,
232
+ x,
233
+ t,
234
+ cross_attn_cond=None,
235
+ cross_attn_mask=None,
236
+ input_concat_cond=None,
237
+ global_cond=None,
238
+ cfg_scale=1.0,
239
+ cfg_dropout_prob: float = 0.0,
240
+ batch_cfg: bool = False,
241
+ rescale_cfg: bool = False,
242
+ negative_cross_attn_cond=None,
243
+ negative_cross_attn_mask=None,
244
+ negative_global_cond=None,
245
+ negative_input_concat_cond=None,
246
+ prepend_cond=None,
247
+ prepend_cond_mask=None,
248
+ **kwargs):
249
+ p = Profiler()
250
+
251
+ p.tick("start")
252
+
253
+ channels_list = None
254
+ if input_concat_cond is not None:
255
+ channels_list = [input_concat_cond]
256
+
257
+ outputs = self.model(
258
+ x,
259
+ t,
260
+ embedding=cross_attn_cond,
261
+ embedding_mask=cross_attn_mask,
262
+ features=global_cond,
263
+ channels_list=channels_list,
264
+ embedding_scale=cfg_scale,
265
+ embedding_mask_proba=cfg_dropout_prob,
266
+ batch_cfg=batch_cfg,
267
+ rescale_cfg=rescale_cfg,
268
+ negative_embedding=negative_cross_attn_cond,
269
+ negative_embedding_mask=negative_cross_attn_mask,
270
+ **kwargs)
271
+
272
+ p.tick("UNetCFG1D forward")
273
+
274
+ #print(f"Profiler: {p}")
275
+ return outputs
276
+
277
+ class UNet1DCondWrapper(ConditionedDiffusionModel):
278
+ def __init__(
279
+ self,
280
+ *args,
281
+ **kwargs
282
+ ):
283
+ super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
284
+
285
+ self.model = UNet1d(*args, **kwargs)
286
+
287
+ with torch.no_grad():
288
+ for param in self.model.parameters():
289
+ param *= 0.5
290
+
291
+ def forward(self,
292
+ x,
293
+ t,
294
+ input_concat_cond=None,
295
+ global_cond=None,
296
+ cross_attn_cond=None,
297
+ cross_attn_mask=None,
298
+ prepend_cond=None,
299
+ prepend_cond_mask=None,
300
+ cfg_scale=1.0,
301
+ cfg_dropout_prob: float = 0.0,
302
+ batch_cfg: bool = False,
303
+ rescale_cfg: bool = False,
304
+ negative_cross_attn_cond=None,
305
+ negative_cross_attn_mask=None,
306
+ negative_global_cond=None,
307
+ negative_input_concat_cond=None,
308
+ **kwargs):
309
+
310
+ channels_list = None
311
+ if input_concat_cond is not None:
312
+
313
+ # Interpolate input_concat_cond to the same length as x
314
+ if input_concat_cond.shape[2] != x.shape[2]:
315
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
316
+
317
+ channels_list = [input_concat_cond]
318
+
319
+ outputs = self.model(
320
+ x,
321
+ t,
322
+ features=global_cond,
323
+ channels_list=channels_list,
324
+ **kwargs)
325
+
326
+ return outputs
327
+
328
+ class UNet1DUncondWrapper(DiffusionModel):
329
+ def __init__(
330
+ self,
331
+ in_channels,
332
+ *args,
333
+ **kwargs
334
+ ):
335
+ super().__init__()
336
+
337
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
338
+
339
+ self.io_channels = in_channels
340
+
341
+ with torch.no_grad():
342
+ for param in self.model.parameters():
343
+ param *= 0.5
344
+
345
+ def forward(self, x, t, **kwargs):
346
+ return self.model(x, t, **kwargs)
347
+
348
+ class DAU1DCondWrapper(ConditionedDiffusionModel):
349
+ def __init__(
350
+ self,
351
+ *args,
352
+ **kwargs
353
+ ):
354
+ super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
355
+
356
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
357
+
358
+ with torch.no_grad():
359
+ for param in self.model.parameters():
360
+ param *= 0.5
361
+
362
+ def forward(self,
363
+ x,
364
+ t,
365
+ input_concat_cond=None,
366
+ cross_attn_cond=None,
367
+ cross_attn_mask=None,
368
+ global_cond=None,
369
+ cfg_scale=1.0,
370
+ cfg_dropout_prob: float = 0.0,
371
+ batch_cfg: bool = False,
372
+ rescale_cfg: bool = False,
373
+ negative_cross_attn_cond=None,
374
+ negative_cross_attn_mask=None,
375
+ negative_global_cond=None,
376
+ negative_input_concat_cond=None,
377
+ prepend_cond=None,
378
+ **kwargs):
379
+
380
+ return self.model(x, t, cond = input_concat_cond)
381
+
382
+ class DiffusionAttnUnet1D(nn.Module):
383
+ def __init__(
384
+ self,
385
+ io_channels = 2,
386
+ depth=14,
387
+ n_attn_layers = 6,
388
+ channels = [128, 128, 256, 256] + [512] * 10,
389
+ cond_dim = 0,
390
+ cond_noise_aug = False,
391
+ kernel_size = 5,
392
+ learned_resample = False,
393
+ strides = [2] * 13,
394
+ conv_bias = True,
395
+ use_snake = False
396
+ ):
397
+ super().__init__()
398
+
399
+ self.cond_noise_aug = cond_noise_aug
400
+
401
+ self.io_channels = io_channels
402
+
403
+ if self.cond_noise_aug:
404
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
405
+
406
+ self.timestep_embed = FourierFeatures(1, 16)
407
+
408
+ attn_layer = depth - n_attn_layers
409
+
410
+ strides = [1] + strides
411
+
412
+ block = nn.Identity()
413
+
414
+ conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
415
+
416
+ for i in range(depth, 0, -1):
417
+ c = channels[i - 1]
418
+ stride = strides[i-1]
419
+ if stride > 2 and not learned_resample:
420
+ raise ValueError("Must have stride 2 without learned resampling")
421
+
422
+ if i > 1:
423
+ c_prev = channels[i - 2]
424
+ add_attn = i >= attn_layer and n_attn_layers > 0
425
+ block = SkipBlock(
426
+ Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
427
+ conv_block(c_prev, c, c),
428
+ SelfAttention1d(
429
+ c, c // 32) if add_attn else nn.Identity(),
430
+ conv_block(c, c, c),
431
+ SelfAttention1d(
432
+ c, c // 32) if add_attn else nn.Identity(),
433
+ conv_block(c, c, c),
434
+ SelfAttention1d(
435
+ c, c // 32) if add_attn else nn.Identity(),
436
+ block,
437
+ conv_block(c * 2 if i != depth else c, c, c),
438
+ SelfAttention1d(
439
+ c, c // 32) if add_attn else nn.Identity(),
440
+ conv_block(c, c, c),
441
+ SelfAttention1d(
442
+ c, c // 32) if add_attn else nn.Identity(),
443
+ conv_block(c, c, c_prev),
444
+ SelfAttention1d(c_prev, c_prev //
445
+ 32) if add_attn else nn.Identity(),
446
+ Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
447
+ )
448
+ else:
449
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
450
+ block = nn.Sequential(
451
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
452
+ conv_block(c, c, c),
453
+ conv_block(c, c, c),
454
+ block,
455
+ conv_block(c * 2, c, c),
456
+ conv_block(c, c, c),
457
+ conv_block(c, c, io_channels, is_last=True),
458
+ )
459
+ self.net = block
460
+
461
+ with torch.no_grad():
462
+ for param in self.net.parameters():
463
+ param *= 0.5
464
+
465
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
466
+
467
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
468
+
469
+ inputs = [x, timestep_embed]
470
+
471
+ if cond is not None:
472
+ if cond.shape[2] != x.shape[2]:
473
+ cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
474
+
475
+ if self.cond_noise_aug:
476
+ # Draw a noise-augmentation level in [0, 1) from the quasi-random Sobol sequence
477
+ if cond_aug_scale is None:
478
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
479
+ else:
480
+ aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
481
+
482
+ # Add noise to the conditioning signal
483
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
484
+
485
+ # Get embedding for noise cond level, reusing timestep_embed
486
+ aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
487
+
488
+ inputs.append(aug_level_embed)
489
+
490
+ inputs.append(cond)
491
+
492
+ outputs = self.net(torch.cat(inputs, dim=1))
493
+
494
+ return outputs
495
+
496
+ class DiTWrapper(ConditionedDiffusionModel):
497
+ def __init__(
498
+ self,
499
+ *args,
500
+ **kwargs
501
+ ):
502
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
503
+
504
+ self.model = DiffusionTransformer(*args, **kwargs)
505
+
506
+ with torch.no_grad():
507
+ for param in self.model.parameters():
508
+ param *= 0.5
509
+
510
+ def forward(self,
511
+ x,
512
+ t,
513
+ cross_attn_cond=None,
514
+ cross_attn_mask=None,
515
+ negative_cross_attn_cond=None,
516
+ negative_cross_attn_mask=None,
517
+ input_concat_cond=None,
518
+ negative_input_concat_cond=None,
519
+ global_cond=None,
520
+ negative_global_cond=None,
521
+ prepend_cond=None,
522
+ prepend_cond_mask=None,
523
+ cfg_scale=1.0,
524
+ cfg_dropout_prob: float = 0.0,
525
+ batch_cfg: bool = True,
526
+ rescale_cfg: bool = False,
527
+ scale_phi: float = 0.0,
528
+ **kwargs):
529
+
530
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
531
+ #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
532
+
533
+ return self.model(
534
+ x,
535
+ t,
536
+ cross_attn_cond=cross_attn_cond,
537
+ cross_attn_cond_mask=cross_attn_mask,
538
+ negative_cross_attn_cond=negative_cross_attn_cond,
539
+ negative_cross_attn_mask=negative_cross_attn_mask,
540
+ input_concat_cond=input_concat_cond,
541
+ prepend_cond=prepend_cond,
542
+ prepend_cond_mask=prepend_cond_mask,
543
+ cfg_scale=cfg_scale,
544
+ cfg_dropout_prob=cfg_dropout_prob,
545
+ scale_phi=scale_phi,
546
+ global_embed=global_cond,
547
+ **kwargs)
548
+
549
+ class DiTUncondWrapper(DiffusionModel):
550
+ def __init__(
551
+ self,
552
+ in_channels,
553
+ *args,
554
+ **kwargs
555
+ ):
556
+ super().__init__()
557
+
558
+ self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
559
+
560
+ self.io_channels = in_channels
561
+
562
+ with torch.no_grad():
563
+ for param in self.model.parameters():
564
+ param *= 0.5
565
+
566
+ def forward(self, x, t, **kwargs):
567
+ return self.model(x, t, **kwargs)
568
+
569
+ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
570
+ diffusion_uncond_config = config["model"]
571
+
572
+ model_type = diffusion_uncond_config.get('type', None)
573
+
574
+ diffusion_config = diffusion_uncond_config.get('config', {})
575
+
576
+ assert model_type is not None, "Must specify model type in config"
577
+
578
+ pretransform = diffusion_uncond_config.get("pretransform", None)
579
+
580
+ sample_size = config.get("sample_size", None)
581
+ assert sample_size is not None, "Must specify sample size in config"
582
+
583
+ sample_rate = config.get("sample_rate", None)
584
+ assert sample_rate is not None, "Must specify sample rate in config"
585
+
586
+ if pretransform is not None:
587
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
588
+ min_input_length = pretransform.downsampling_ratio
589
+ else:
590
+ min_input_length = 1
591
+
592
+ if model_type == 'DAU1d':
593
+
594
+ model = DiffusionAttnUnet1D(
595
+ **diffusion_config
596
+ )
597
+
598
+ elif model_type == "adp_uncond_1d":
599
+
600
+ model = UNet1DUncondWrapper(
601
+ **diffusion_config
602
+ )
603
+
604
+ elif model_type == "dit":
605
+ model = DiTUncondWrapper(
606
+ **diffusion_config
607
+ )
608
+
609
+ else:
610
+ raise NotImplementedError(f'Unknown model type: {model_type}')
611
+
612
+ return DiffusionModelWrapper(model,
613
+ io_channels=model.io_channels,
614
+ sample_size=sample_size,
615
+ sample_rate=sample_rate,
616
+ pretransform=pretransform,
617
+ min_input_length=min_input_length)
618
+
619
+ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
620
+
621
+ model_config = config["model"]
622
+
623
+ model_type = config["model_type"]
624
+
625
+ diffusion_config = model_config.get('diffusion', None)
626
+ assert diffusion_config is not None, "Must specify diffusion config"
627
+
628
+ diffusion_model_type = diffusion_config.get('type', None)
629
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
630
+
631
+ diffusion_model_config = diffusion_config.get('config', None)
632
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
633
+ if diffusion_model_config.get('video_fps', None) is not None:
634
+ diffusion_model_config.pop('video_fps')
635
+
636
+ if diffusion_model_type == 'adp_cfg_1d':
637
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
638
+ elif diffusion_model_type == 'adp_1d':
639
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
640
+ elif diffusion_model_type == 'dit':
641
+ diffusion_model = DiTWrapper(**diffusion_model_config)
642
+ else:
+ raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
+
643
+ io_channels = model_config.get('io_channels', None)
644
+ assert io_channels is not None, "Must specify io_channels in model config"
645
+
646
+ sample_rate = config.get('sample_rate', None)
647
+ assert sample_rate is not None, "Must specify sample_rate in config"
648
+
649
+ diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
650
+
651
+ conditioning_config = model_config.get('conditioning', None)
652
+
653
+ conditioner = None
654
+ if conditioning_config is not None:
655
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
656
+
657
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
658
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
659
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
660
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
661
+
662
+ pretransform = model_config.get("pretransform", None)
663
+
664
+ if pretransform is not None:
665
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
666
+ min_input_length = pretransform.downsampling_ratio
667
+ else:
668
+ min_input_length = 1
669
+
670
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
671
+ min_input_length *= np.prod(diffusion_model_config["factors"])
672
+ elif diffusion_model_type == "dit":
673
+ min_input_length *= diffusion_model.model.patch_size
674
+
675
+ # Get the proper wrapper class
676
+
677
+ extra_kwargs = {}
678
+
679
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
680
+ wrapper_fn = ConditionedDiffusionModelWrapper
681
+
682
+ extra_kwargs["diffusion_objective"] = diffusion_objective
683
+
684
+ elif model_type == "diffusion_prior":
685
+ prior_type = model_config.get("prior_type", None)
686
+ assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
687
+
688
+ if prior_type == "mono_stereo":
689
+ from .diffusion_prior import MonoToStereoDiffusionPrior
690
+ wrapper_fn = MonoToStereoDiffusionPrior
691
+
692
+ return wrapper_fn(
693
+ diffusion_model,
694
+ conditioner,
695
+ min_input_length=min_input_length,
696
+ sample_rate=sample_rate,
697
+ cross_attn_cond_ids=cross_attention_ids,
698
+ global_cond_ids=global_cond_ids,
699
+ input_concat_ids=input_concat_ids,
700
+ prepend_cond_ids=prepend_cond_ids,
701
+ pretransform=pretransform,
702
+ io_channels=io_channels,
703
+ **extra_kwargs
704
+ )
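
For orientation, the sketch below shows the shape of config dict that create_diffusion_cond_from_config reads. Only keys that the function above actually accesses are included, and every concrete value (channel counts, depth, conditioning ids) is an illustrative placeholder rather than a config shipped with this commit.

    # Minimal illustrative config for create_diffusion_cond_from_config.
    # Keys mirror the lookups in the function above; values are placeholders.
    example_config = {
        "model_type": "diffusion_cond",
        "sample_rate": 44100,
        "model": {
            "io_channels": 64,
            "diffusion": {
                "type": "dit",
                "diffusion_objective": "v",
                "cross_attention_cond_ids": [],
                "global_cond_ids": [],
                "input_concat_ids": [],
                "config": {
                    "io_channels": 64,
                    "embed_dim": 768,
                    "depth": 12,
                    "num_heads": 8,
                    "transformer_type": "continuous_transformer",
                },
            },
            # "conditioning": {...}   # handled by create_multi_conditioner_from_conditioning_config
            # "pretransform": {...}   # handled by create_pretransform_from_config
        },
    }
    # model = create_diffusion_cond_from_config(example_config)
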
stable_audio_tools/models/discriminators.py ADDED
@@ -0,0 +1,546 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from functools import reduce
6
+ import typing as tp
7
+ from einops import rearrange
8
+ from audiotools import AudioSignal, STFTParams
9
+ from dac.model.discriminator import WNConv1d, WNConv2d
10
+
11
+ def get_hinge_losses(score_real, score_fake):
12
+ gen_loss = -score_fake.mean()
13
+ dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean()
14
+ return dis_loss, gen_loss
15
+
16
+ class EncodecDiscriminator(nn.Module):
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ super().__init__()
20
+
21
+ from encodec.msstftd import MultiScaleSTFTDiscriminator
22
+
23
+ self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs)
24
+
25
+ def forward(self, x):
26
+ logits, features = self.discriminators(x)
27
+ return logits, features
28
+
29
+ def loss(self, x, y):
30
+ feature_matching_distance = 0.
31
+ logits_true, feature_true = self.forward(x)
32
+ logits_fake, feature_fake = self.forward(y)
33
+
34
+ dis_loss = torch.tensor(0.)
35
+ adv_loss = torch.tensor(0.)
36
+
37
+ for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)):
38
+
39
+ feature_matching_distance = feature_matching_distance + sum(
40
+ map(
41
+ lambda x, y: abs(x - y).mean(),
42
+ scale_true,
43
+ scale_fake,
44
+ )) / len(scale_true)
45
+
46
+ _dis, _adv = get_hinge_losses(
47
+ logits_true[i],
48
+ logits_fake[i],
49
+ )
50
+
51
+ dis_loss = dis_loss + _dis
52
+ adv_loss = adv_loss + _adv
53
+
54
+ return dis_loss, adv_loss, feature_matching_distance
55
+
56
+ # Discriminators from oobleck
57
+
58
+ IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]]
59
+
60
+ TensorDict = tp.Dict[str, torch.Tensor]
61
+
62
+ class SharedDiscriminatorConvNet(nn.Module):
63
+
64
+ def __init__(
65
+ self,
66
+ in_size: int,
67
+ convolution: tp.Union[nn.Conv1d, nn.Conv2d],
68
+ out_size: int = 1,
69
+ capacity: int = 32,
70
+ n_layers: int = 4,
71
+ kernel_size: int = 15,
72
+ stride: int = 4,
73
+ activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(),
74
+ normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm,
75
+ ) -> None:
76
+ super().__init__()
77
+ channels = [in_size]
78
+ channels += list(capacity * 2**np.arange(n_layers))
79
+
80
+ if isinstance(stride, int):
81
+ stride = n_layers * [stride]
82
+
83
+ net = []
84
+ for i in range(n_layers):
85
+ if isinstance(kernel_size, int):
86
+ pad = kernel_size // 2
87
+ s = stride[i]
88
+ else:
89
+ pad = kernel_size[0] // 2
90
+ s = (stride[i], 1)
91
+
92
+ net.append(
93
+ normalization(
94
+ convolution(
95
+ channels[i],
96
+ channels[i + 1],
97
+ kernel_size,
98
+ stride=s,
99
+ padding=pad,
100
+ )))
101
+ net.append(activation())
102
+
103
+ net.append(convolution(channels[-1], out_size, 1))
104
+
105
+ self.net = nn.ModuleList(net)
106
+
107
+ def forward(self, x) -> IndividualDiscriminatorOut:
108
+ features = []
109
+ for layer in self.net:
110
+ x = layer(x)
111
+ if isinstance(layer, nn.modules.conv._ConvNd):
112
+ features.append(x)
113
+ score = x.reshape(x.shape[0], -1).mean(-1)
114
+ return score, features
115
+
116
+
117
+ class MultiScaleDiscriminator(nn.Module):
118
+
119
+ def __init__(self,
120
+ in_channels: int,
121
+ n_scales: int,
122
+ **conv_kwargs) -> None:
123
+ super().__init__()
124
+ layers = []
125
+ for _ in range(n_scales):
126
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs))
127
+ self.layers = nn.ModuleList(layers)
128
+
129
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
130
+ score = 0
131
+ features = []
132
+ for layer in self.layers:
133
+ s, f = layer(x)
134
+ score = score + s
135
+ features.extend(f)
136
+ x = nn.functional.avg_pool1d(x, 2)
137
+ return score, features
138
+
139
+ class MultiPeriodDiscriminator(nn.Module):
140
+
141
+ def __init__(self,
142
+ in_channels: int,
143
+ periods: tp.Sequence[int],
144
+ **conv_kwargs) -> None:
145
+ super().__init__()
146
+ layers = []
147
+ self.periods = periods
148
+
149
+ for _ in periods:
150
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs))
151
+
152
+ self.layers = nn.ModuleList(layers)
153
+
154
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
155
+ score = 0
156
+ features = []
157
+ for layer, n in zip(self.layers, self.periods):
158
+ s, f = layer(self.fold(x, n))
159
+ score = score + s
160
+ features.extend(f)
161
+ return score, features
162
+
163
+ def fold(self, x: torch.Tensor, n: int) -> torch.Tensor:
164
+ pad = (n - (x.shape[-1] % n)) % n
165
+ x = nn.functional.pad(x, (0, pad))
166
+ return x.reshape(*x.shape[:2], -1, n)
167
+
168
+
169
+ class MultiDiscriminator(nn.Module):
170
+ """
171
+ Individual discriminators should take a single tensor as input (of shape NxB, C, T) and
172
+ return a tuple of a score tensor (of shape NxB) and a sequence of feature maps
173
+ (each of shape NxB, C', T'). A usage sketch follows this class.
174
+ """
175
+
176
+ def __init__(self, discriminator_list: tp.Sequence[nn.Module],
177
+ keys: tp.Sequence[str]) -> None:
178
+ super().__init__()
179
+ self.discriminators = nn.ModuleList(discriminator_list)
180
+ self.keys = keys
181
+
182
+ def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict:
183
+ features = features.chunk(len(self.keys), 0)
184
+ return {k: features[i] for i, k in enumerate(self.keys)}
185
+
186
+ @staticmethod
187
+ def concat_dicts(dict_a, dict_b):
188
+ out_dict = {}
189
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
190
+ for k in keys:
191
+ out_dict[k] = []
192
+ if k in dict_a:
193
+ if isinstance(dict_a[k], list):
194
+ out_dict[k].extend(dict_a[k])
195
+ else:
196
+ out_dict[k].append(dict_a[k])
197
+ if k in dict_b:
198
+ if isinstance(dict_b[k], list):
199
+ out_dict[k].extend(dict_b[k])
200
+ else:
201
+ out_dict[k].append(dict_b[k])
202
+ return out_dict
203
+
204
+ @staticmethod
205
+ def sum_dicts(dict_a, dict_b):
206
+ out_dict = {}
207
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
208
+ for k in keys:
209
+ out_dict[k] = 0.
210
+ if k in dict_a:
211
+ out_dict[k] = out_dict[k] + dict_a[k]
212
+ if k in dict_b:
213
+ out_dict[k] = out_dict[k] + dict_b[k]
214
+ return out_dict
215
+
216
+ def forward(self, inputs: TensorDict) -> TensorDict:
217
+ discriminator_input = torch.cat([inputs[k] for k in self.keys], 0)
218
+ all_scores = []
219
+ all_features = []
220
+
221
+ for discriminator in self.discriminators:
222
+ score, features = discriminator(discriminator_input)
223
+ scores = self.unpack_tensor_to_dict(score)
224
+ scores = {f"score_{k}": scores[k] for k in scores.keys()}
225
+ all_scores.append(scores)
226
+
227
+ features = map(self.unpack_tensor_to_dict, features)
228
+ features = reduce(self.concat_dicts, features)
229
+ features = {f"features_{k}": features[k] for k in features.keys()}
230
+ all_features.append(features)
231
+
232
+ all_scores = reduce(self.sum_dicts, all_scores)
233
+ all_features = reduce(self.concat_dicts, all_features)
234
+
235
+ inputs.update(all_scores)
236
+ inputs.update(all_features)
237
+
238
+ return inputs
239
+
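
A minimal sketch of the dict contract described in the docstring above, assuming a single MultiScaleDiscriminator as the only sub-discriminator; the batch size and audio length are arbitrary.

    import torch

    # Reals and fakes are concatenated on the batch dim internally, and the
    # results come back under "score_*" and "features_*" keys per input key.
    disc = MultiDiscriminator(
        [MultiScaleDiscriminator(in_channels=1, n_scales=3)],
        keys=["reals", "fakes"],
    )
    batch = {"reals": torch.randn(2, 1, 16384), "fakes": torch.randn(2, 1, 16384)}
    out = disc(batch)
    # out["score_reals"], out["score_fakes"]: per-example scores of shape (2,)
    # out["features_reals"], out["features_fakes"]: lists of intermediate feature maps
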
240
+ class OobleckDiscriminator(nn.Module):
241
+
242
+ def __init__(
243
+ self,
244
+ in_channels=1,
245
+ ):
246
+ super().__init__()
247
+
248
+ multi_scale_discriminator = MultiScaleDiscriminator(
249
+ in_channels=in_channels,
250
+ n_scales=3,
251
+ )
252
+
253
+ multi_period_discriminator = MultiPeriodDiscriminator(
254
+ in_channels=in_channels,
255
+ periods=[2, 3, 5, 7, 11]
256
+ )
257
+
258
+ # multi_resolution_discriminator = MultiScaleSTFTDiscriminator(
259
+ # filters=32,
260
+ # in_channels = in_channels,
261
+ # out_channels = 1,
262
+ # n_ffts = [2048, 1024, 512, 256, 128],
263
+ # hop_lengths = [512, 256, 128, 64, 32],
264
+ # win_lengths = [2048, 1024, 512, 256, 128]
265
+ # )
266
+
267
+ self.multi_discriminator = MultiDiscriminator(
268
+ [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator],
269
+ ["reals", "fakes"]
270
+ )
271
+
272
+ def loss(self, reals, fakes):
273
+ inputs = {
274
+ "reals": reals,
275
+ "fakes": fakes,
276
+ }
277
+
278
+ inputs = self.multi_discriminator(inputs)
279
+
280
+ scores_real = inputs["score_reals"]
281
+ scores_fake = inputs["score_fakes"]
282
+
283
+ features_real = inputs["features_reals"]
284
+ features_fake = inputs["features_fakes"]
285
+
286
+ dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake)
287
+
288
+ feature_matching_distance = torch.tensor(0.)
289
+
290
+ for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)):
291
+
292
+ feature_matching_distance = feature_matching_distance + sum(
293
+ map(
294
+ lambda real, fake: abs(real - fake).mean(),
295
+ scale_real,
296
+ scale_fake,
297
+ )) / len(scale_real)
298
+
299
+ return dis_loss, gen_loss, feature_matching_distance
300
+
301
+
302
+ ## Discriminators from Descript Audio Codec repo
303
+ ## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt
304
+ class MPD(nn.Module):
305
+ def __init__(self, period, channels=1):
306
+ super().__init__()
307
+
308
+ self.period = period
309
+ self.convs = nn.ModuleList(
310
+ [
311
+ WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)),
312
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
313
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
314
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
315
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
316
+ ]
317
+ )
318
+ self.conv_post = WNConv2d(
319
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
320
+ )
321
+
322
+ def pad_to_period(self, x):
323
+ t = x.shape[-1]
324
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
325
+ return x
326
+
327
+ def forward(self, x):
328
+ fmap = []
329
+
330
+ x = self.pad_to_period(x)
331
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
332
+
333
+ for layer in self.convs:
334
+ x = layer(x)
335
+ fmap.append(x)
336
+
337
+ x = self.conv_post(x)
338
+ fmap.append(x)
339
+
340
+ return fmap
341
+
342
+
343
+ class MSD(nn.Module):
344
+ def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1):
345
+ super().__init__()
346
+
347
+ self.convs = nn.ModuleList(
348
+ [
349
+ WNConv1d(channels, 16, 15, 1, padding=7),
350
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
351
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
352
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
353
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
354
+ WNConv1d(1024, 1024, 5, 1, padding=2),
355
+ ]
356
+ )
357
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
358
+ self.sample_rate = sample_rate
359
+ self.rate = rate
360
+
361
+ def forward(self, x):
362
+ x = AudioSignal(x, self.sample_rate)
363
+ x.resample(self.sample_rate // self.rate)
364
+ x = x.audio_data
365
+
366
+ fmap = []
367
+
368
+ for l in self.convs:
369
+ x = l(x)
370
+ fmap.append(x)
371
+ x = self.conv_post(x)
372
+ fmap.append(x)
373
+
374
+ return fmap
375
+
376
+
377
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
378
+
379
+
380
+ class MRD(nn.Module):
381
+ def __init__(
382
+ self,
383
+ window_length: int,
384
+ hop_factor: float = 0.25,
385
+ sample_rate: int = 44100,
386
+ bands: list = BANDS,
387
+ channels: int = 1
388
+ ):
389
+ """Complex multi-band spectrogram discriminator.
390
+ Parameters
391
+ ----------
392
+ window_length : int
393
+ Window length of STFT.
394
+ hop_factor : float, optional
395
+ Fraction of window_length used as the STFT hop length, by default 0.25.
396
+ sample_rate : int, optional
397
+ Sampling rate of audio in Hz, by default 44100
398
+ bands : list, optional
399
+ Bands to run discriminator over.
400
+ """
401
+ super().__init__()
402
+
403
+ self.window_length = window_length
404
+ self.hop_factor = hop_factor
405
+ self.sample_rate = sample_rate
406
+ self.stft_params = STFTParams(
407
+ window_length=window_length,
408
+ hop_length=int(window_length * hop_factor),
409
+ match_stride=True,
410
+ )
411
+
412
+ self.channels = channels
413
+
414
+ n_fft = window_length // 2 + 1
415
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
416
+ self.bands = bands
417
+
418
+ ch = 32
419
+ convs = lambda: nn.ModuleList(
420
+ [
421
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
422
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
423
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
424
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
425
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
426
+ ]
427
+ )
428
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
429
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
430
+
431
+ def spectrogram(self, x):
432
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
433
+ x = torch.view_as_real(x.stft())
434
+ x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels)
435
+ # Split into bands
436
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
437
+ return x_bands
438
+
439
+ def forward(self, x):
440
+ x_bands = self.spectrogram(x)
441
+ fmap = []
442
+
443
+ x = []
444
+ for band, stack in zip(x_bands, self.band_convs):
445
+ for layer in stack:
446
+ band = layer(band)
447
+ fmap.append(band)
448
+ x.append(band)
449
+
450
+ x = torch.cat(x, dim=-1)
451
+ x = self.conv_post(x)
452
+ fmap.append(x)
453
+
454
+ return fmap
455
+
456
+
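
As a quick worked example of the band split performed in MRD.__init__ above: with window_length=2048 there are 2048 // 2 + 1 = 1025 frequency bins, so the fractional BANDS map to the bin ranges shown below.

    # Band fractions -> STFT bin ranges for window_length = 2048 (follows the code above).
    window_length = 2048
    n_fft = window_length // 2 + 1      # 1025 bins
    bands = [(int(lo * n_fft), int(hi * n_fft)) for lo, hi in BANDS]
    # -> [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
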
457
+ class DACDiscriminator(nn.Module):
458
+ def __init__(
459
+ self,
460
+ channels: int = 1,
461
+ rates: list = [],
462
+ periods: list = [2, 3, 5, 7, 11],
463
+ fft_sizes: list = [2048, 1024, 512],
464
+ sample_rate: int = 44100,
465
+ bands: list = BANDS,
466
+ ):
467
+ """Discriminator that combines multiple discriminators.
468
+
469
+ Parameters
470
+ ----------
471
+ rates : list, optional
472
+ sampling rates (in Hz) to run MSD at, by default []
473
+ If empty, MSD is not used.
474
+ periods : list, optional
475
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
476
+ fft_sizes : list, optional
477
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
478
+ sample_rate : int, optional
479
+ Sampling rate of audio in Hz, by default 44100
480
+ bands : list, optional
481
+ Bands to run MRD at, by default `BANDS`
482
+ """
483
+ super().__init__()
484
+ discs = []
485
+ discs += [MPD(p, channels=channels) for p in periods]
486
+ discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates]
487
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes]
488
+ self.discriminators = nn.ModuleList(discs)
489
+
490
+ def preprocess(self, y):
491
+ # Remove DC offset
492
+ y = y - y.mean(dim=-1, keepdims=True)
493
+ # Peak normalize the volume of input audio
494
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
495
+ return y
496
+
497
+ def forward(self, x):
498
+ x = self.preprocess(x)
499
+ fmaps = [d(x) for d in self.discriminators]
500
+ return fmaps
501
+
502
+ class DACGANLoss(nn.Module):
503
+ """
504
+ Computes GAN losses for generated waveforms/spectrograms against
505
+ ground-truth waveforms/spectrograms using the combined DAC discriminator.
506
+ The discriminator loss and the generator loss are computed in
507
+ separate methods; a usage sketch follows this class.
508
+ """
509
+
510
+ def __init__(self, **discriminator_kwargs):
511
+ super().__init__()
512
+ self.discriminator = DACDiscriminator(**discriminator_kwargs)
513
+
514
+ def forward(self, fake, real):
515
+ d_fake = self.discriminator(fake)
516
+ d_real = self.discriminator(real)
517
+ return d_fake, d_real
518
+
519
+ def discriminator_loss(self, fake, real):
520
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
521
+
522
+ loss_d = 0
523
+ for x_fake, x_real in zip(d_fake, d_real):
524
+ loss_d += torch.mean(x_fake[-1] ** 2)
525
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
526
+ return loss_d
527
+
528
+ def generator_loss(self, fake, real):
529
+ d_fake, d_real = self.forward(fake, real)
530
+
531
+ loss_g = 0
532
+ for x_fake in d_fake:
533
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
534
+
535
+ loss_feature = 0
536
+
537
+ for i in range(len(d_fake)):
538
+ for j in range(len(d_fake[i]) - 1):
539
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
540
+ return loss_g, loss_feature
541
+
542
+ def loss(self, fake, real):
543
+ gen_loss, feature_distance = self.generator_loss(fake, real)
544
+ dis_loss = self.discriminator_loss(fake, real)
545
+
546
+ return dis_loss, gen_loss, feature_distance
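
A minimal sketch of how DACGANLoss might be driven from a training step, assuming mono audio at 44.1 kHz; the tensors stand in for a generator output and a dataloader batch, and the surrounding optimizer logic is omitted.

    import torch

    gan_loss = DACGANLoss(channels=1, sample_rate=44100)
    real = torch.randn(2, 1, 44100)           # ground-truth audio
    fake = torch.randn(2, 1, 44100)           # stand-in for generator output

    # Discriminator step: the fake input is detached inside discriminator_loss
    loss_d = gan_loss.discriminator_loss(fake, real)

    # Generator step: adversarial term plus L1 feature-matching term
    loss_adv, loss_feat = gan_loss.generator_loss(fake, real)
    loss_g = loss_adv + loss_feat
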
stable_audio_tools/models/dit.py ADDED
@@ -0,0 +1,379 @@
1
+ import typing as tp
2
+
3
+ import torch
4
+
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from x_transformers import ContinuousTransformerWrapper, Encoder
9
+
10
+ from .blocks import FourierFeatures
11
+ from .transformer import ContinuousTransformer
12
+
13
+ class DiffusionTransformer(nn.Module):
14
+ def __init__(self,
15
+ io_channels=32,
16
+ patch_size=1,
17
+ embed_dim=768,
18
+ cond_token_dim=0,
19
+ project_cond_tokens=True,
20
+ global_cond_dim=0,
21
+ project_global_cond=True,
22
+ input_concat_dim=0,
23
+ prepend_cond_dim=0,
24
+ depth=12,
25
+ num_heads=8,
26
+ transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
27
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
28
+ **kwargs):
29
+
30
+ super().__init__()
31
+
32
+ self.cond_token_dim = cond_token_dim
33
+
34
+ # Timestep embeddings
35
+ timestep_features_dim = 256
36
+
37
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
38
+
39
+ self.to_timestep_embed = nn.Sequential(
40
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
41
+ nn.SiLU(),
42
+ nn.Linear(embed_dim, embed_dim, bias=True),
43
+ )
44
+
45
+ if cond_token_dim > 0:
46
+ # Conditioning tokens
47
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
48
+ self.to_cond_embed = nn.Sequential(
49
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
50
+ nn.SiLU(),
51
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
52
+ )
53
+ else:
54
+ cond_embed_dim = 0
55
+
56
+ if global_cond_dim > 0:
57
+ # Global conditioning
58
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
59
+ self.to_global_embed = nn.Sequential(
60
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
61
+ nn.SiLU(),
62
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False)
63
+ )
64
+
65
+ if prepend_cond_dim > 0:
66
+ # Prepend conditioning
67
+ self.to_prepend_embed = nn.Sequential(
68
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
69
+ nn.SiLU(),
70
+ nn.Linear(embed_dim, embed_dim, bias=False)
71
+ )
72
+
73
+ self.input_concat_dim = input_concat_dim
74
+
75
+ dim_in = io_channels + self.input_concat_dim
76
+
77
+ self.patch_size = patch_size
78
+
79
+ # Transformer
80
+
81
+ self.transformer_type = transformer_type
82
+
83
+ self.global_cond_type = global_cond_type
84
+
85
+ if self.transformer_type == "x-transformers":
86
+ self.transformer = ContinuousTransformerWrapper(
87
+ dim_in=dim_in * patch_size,
88
+ dim_out=io_channels * patch_size,
89
+ max_seq_len=0, #Not relevant without absolute positional embeds
90
+ attn_layers = Encoder(
91
+ dim=embed_dim,
92
+ depth=depth,
93
+ heads=num_heads,
94
+ attn_flash = True,
95
+ cross_attend = cond_token_dim > 0,
96
+ dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
97
+ zero_init_branch_output=True,
98
+ use_abs_pos_emb = False,
99
+ rotary_pos_emb=True,
100
+ ff_swish = True,
101
+ ff_glu = True,
102
+ **kwargs
103
+ )
104
+ )
105
+
106
+ elif self.transformer_type == "continuous_transformer":
107
+
108
+ global_dim = None
109
+
110
+ if self.global_cond_type == "adaLN":
111
+ # The global conditioning is projected to the embed_dim already at this point
112
+ global_dim = embed_dim
113
+
114
+ self.transformer = ContinuousTransformer(
115
+ dim=embed_dim,
116
+ depth=depth,
117
+ dim_heads=embed_dim // num_heads,
118
+ dim_in=dim_in * patch_size,
119
+ dim_out=io_channels * patch_size,
120
+ cross_attend = cond_token_dim > 0,
121
+ cond_token_dim = cond_embed_dim,
122
+ global_cond_dim=global_dim,
123
+ **kwargs
124
+ )
125
+
126
+ else:
127
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
128
+
129
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
130
+ nn.init.zeros_(self.preprocess_conv.weight)
131
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
132
+ nn.init.zeros_(self.postprocess_conv.weight)
133
+
134
+ def _forward(
135
+ self,
136
+ x,
137
+ t,
138
+ mask=None,
139
+ cross_attn_cond=None,
140
+ cross_attn_cond_mask=None,
141
+ input_concat_cond=None,
142
+ global_embed=None,
143
+ prepend_cond=None,
144
+ prepend_cond_mask=None,
145
+ return_info=False,
146
+ **kwargs):
147
+
148
+ if cross_attn_cond is not None:
149
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond) # project conditioning tokens to (batch, seq, embed_dim)
150
+
151
+ if global_embed is not None:
152
+ # Project the global conditioning to the embedding dimension
153
+ global_embed = self.to_global_embed(global_embed)
154
+
155
+ prepend_inputs = None
156
+ prepend_mask = None
157
+ prepend_length = 0
158
+ if prepend_cond is not None:
159
+ # Project the prepend conditioning to the embedding dimension
160
+ prepend_cond = self.to_prepend_embed(prepend_cond)
161
+
162
+ prepend_inputs = prepend_cond
163
+ if prepend_cond_mask is not None:
164
+ prepend_mask = prepend_cond_mask
165
+
166
+ if input_concat_cond is not None:
167
+
168
+ # Interpolate input_concat_cond to the same length as x
169
+ if input_concat_cond.shape[2] != x.shape[2]:
170
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
171
+
172
+ x = torch.cat([x, input_concat_cond], dim=1)
173
+
174
+ # Get the batch of timestep embeddings
175
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
176
+
177
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
178
+ if global_embed is not None:
179
+ global_embed = global_embed + timestep_embed
180
+ else:
181
+ global_embed = timestep_embed
182
+
183
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
184
+ if self.global_cond_type == "prepend": # True
185
+ if prepend_inputs is None: # True
186
+ # Prepend inputs are just the global embed, and the mask is all ones
187
+ prepend_inputs = global_embed.unsqueeze(1) # [1, 1, 1536]
188
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
189
+ else:
190
+ # Prepend inputs are the prepend conditioning + the global embed
191
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
192
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
193
+
194
+ prepend_length = prepend_inputs.shape[1]
195
+
196
+ x = self.preprocess_conv(x) + x
197
+
198
+ x = rearrange(x, "b c t -> b t c") # (batch, seq_len, channels)
199
+
200
+ extra_args = {}
201
+
202
+ if self.global_cond_type == "adaLN": # 'prepend'
203
+ extra_args["global_cond"] = global_embed
204
+
205
+ if self.patch_size > 1:
206
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
207
+
208
+ if self.transformer_type == "x-transformers":
209
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
210
+ elif self.transformer_type == "continuous_transformer":
211
+
212
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
213
+
214
+ if return_info:
215
+ output, info = output
216
+ elif self.transformer_type == "mm_transformer":
217
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
218
+
219
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
220
+
221
+ if self.patch_size > 1:
222
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
223
+
224
+ output = self.postprocess_conv(output) + output
225
+
226
+ if return_info:
227
+ return output, info
228
+
229
+ return output
230
+
231
+ def forward(
232
+ self,
233
+ x,
234
+ t,
235
+ cross_attn_cond=None,
236
+ cross_attn_cond_mask=None,
237
+ negative_cross_attn_cond=None,
238
+ negative_cross_attn_mask=None,
239
+ input_concat_cond=None,
240
+ global_embed=None,
241
+ negative_global_embed=None,
242
+ prepend_cond=None,
243
+ prepend_cond_mask=None,
244
+ cfg_scale=1.0,
245
+ cfg_dropout_prob=0.0,
246
+ causal=False,
247
+ scale_phi=0.0,
248
+ mask=None,
249
+ return_info=False,
250
+ **kwargs):
251
+
252
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
253
+
254
+ if cross_attn_cond_mask is not None:
255
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
256
+
257
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
258
+
259
+ if prepend_cond_mask is not None:
260
+ prepend_cond_mask = prepend_cond_mask.bool()
261
+
262
+ # CFG dropout
263
+ if cfg_dropout_prob > 0.0:
264
+ if cross_attn_cond is not None:
265
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
266
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
267
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
268
+
269
+ if prepend_cond is not None:
270
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
271
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
272
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
273
+
274
+
275
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
276
+ # Classifier-free guidance
277
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
278
+ batch_inputs = torch.cat([x, x], dim=0)
279
+ batch_timestep = torch.cat([t, t], dim=0)
280
+
281
+ if global_embed is not None:
282
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
283
+ else:
284
+ batch_global_cond = None
285
+
286
+ if input_concat_cond is not None:
287
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
288
+ else:
289
+ batch_input_concat_cond = None
290
+
291
+ batch_cond = None
292
+ batch_cond_masks = None
293
+
294
+ # Handle CFG for cross-attention conditioning
295
+ if cross_attn_cond is not None:
296
+
297
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
298
+
299
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
300
+ if negative_cross_attn_cond is not None:
301
+
302
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
303
+ if negative_cross_attn_mask is not None:
304
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
305
+
306
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
307
+
308
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
309
+
310
+ else:
311
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
312
+
313
+ if cross_attn_cond_mask is not None:
314
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
315
+
316
+ batch_prepend_cond = None
317
+ batch_prepend_cond_mask = None
318
+
319
+ if prepend_cond is not None:
320
+
321
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
322
+
323
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
324
+
325
+ if prepend_cond_mask is not None:
326
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
327
+
328
+
329
+ if mask is not None:
330
+ batch_masks = torch.cat([mask, mask], dim=0)
331
+ else:
332
+ batch_masks = None
333
+
334
+ batch_output = self._forward(
335
+ batch_inputs,
336
+ batch_timestep,
337
+ cross_attn_cond=batch_cond,
338
+ cross_attn_cond_mask=batch_cond_masks,
339
+ mask = batch_masks,
340
+ input_concat_cond=batch_input_concat_cond,
341
+ global_embed = batch_global_cond,
342
+ prepend_cond = batch_prepend_cond,
343
+ prepend_cond_mask = batch_prepend_cond_mask,
344
+ return_info = return_info,
345
+ **kwargs)
346
+
347
+ if return_info:
348
+ batch_output, info = batch_output
349
+
350
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
351
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
352
+
353
+ # CFG Rescale
354
+ if scale_phi != 0.0:
355
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
356
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
357
+ output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
358
+ else:
359
+ output = cfg_output
360
+
361
+ if return_info:
362
+ return output, info
363
+
364
+ return output
365
+
366
+ else:
367
+ return self._forward(
368
+ x,
369
+ t,
370
+ cross_attn_cond=cross_attn_cond,
371
+ cross_attn_cond_mask=cross_attn_cond_mask,
372
+ input_concat_cond=input_concat_cond,
373
+ global_embed=global_embed,
374
+ prepend_cond=prepend_cond,
375
+ prepend_cond_mask=prepend_cond_mask,
376
+ mask=mask,
377
+ return_info=return_info,
378
+ **kwargs
379
+ )
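
To make the guidance branch above concrete: the conditioned and unconditioned passes are run in one batch, split, and recombined as uncond + (cond - uncond) * cfg_scale, with an optional rescale toward the conditioned output's scale when scale_phi > 0. The stand-alone sketch below repeats that arithmetic on dummy tensors; the cfg_scale and scale_phi values are arbitrary.

    import torch

    cond_output = torch.randn(1, 64, 1024)
    uncond_output = torch.randn(1, 64, 1024)
    cfg_scale, scale_phi = 6.0, 0.75

    cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

    # CFG rescale: match the guided output's per-position std to the conditioned pass
    cond_out_std = cond_output.std(dim=1, keepdim=True)
    out_cfg_std = cfg_output.std(dim=1, keepdim=True)
    output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
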
stable_audio_tools/models/factory.py ADDED
@@ -0,0 +1,153 @@
1
+ import json
2
+
3
+ def create_model_from_config(model_config):
4
+ model_type = model_config.get('model_type', None)
5
+
6
+ assert model_type is not None, 'model_type must be specified in model config'
7
+
8
+ if model_type == 'autoencoder':
9
+ from .autoencoders import create_autoencoder_from_config
10
+ return create_autoencoder_from_config(model_config)
11
+ elif model_type == 'diffusion_uncond':
12
+ from .diffusion import create_diffusion_uncond_from_config
13
+ return create_diffusion_uncond_from_config(model_config)
14
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
15
+ from .diffusion import create_diffusion_cond_from_config
16
+ return create_diffusion_cond_from_config(model_config)
17
+ elif model_type == 'diffusion_autoencoder':
18
+ from .autoencoders import create_diffAE_from_config
19
+ return create_diffAE_from_config(model_config)
20
+ elif model_type == 'lm':
21
+ from .lm import create_audio_lm_from_config
22
+ return create_audio_lm_from_config(model_config)
23
+ else:
24
+ raise NotImplementedError(f'Unknown model type: {model_type}')
25
+
26
+ def create_model_from_config_path(model_config_path):
27
+ with open(model_config_path) as f:
28
+ model_config = json.load(f)
29
+
30
+ return create_model_from_config(model_config)
31
+
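
The two helpers above compose as shown below; the path is a hypothetical placeholder for a JSON file containing a top-level "model_type" key.

    # Hypothetical usage; "configs/model_config.json" is a placeholder path.
    model = create_model_from_config_path("configs/model_config.json")
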
32
+ def create_pretransform_from_config(pretransform_config, sample_rate):
33
+ pretransform_type = pretransform_config.get('type', None)
34
+
35
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
36
+
37
+ if pretransform_type == 'autoencoder':
38
+ from .autoencoders import create_autoencoder_from_config
39
+ from .pretransforms import AutoencoderPretransform
40
+
41
+ # Create fake top-level config to pass sample rate to autoencoder constructor
42
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
43
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
44
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
45
+
46
+ scale = pretransform_config.get("scale", 1.0)
47
+ model_half = pretransform_config.get("model_half", False)
48
+ iterate_batch = pretransform_config.get("iterate_batch", False)
49
+ chunked = pretransform_config.get("chunked", False)
50
+
51
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
52
+ elif pretransform_type == 'wavelet':
53
+ from .pretransforms import WaveletPretransform
54
+
55
+ wavelet_config = pretransform_config["config"]
56
+ channels = wavelet_config["channels"]
57
+ levels = wavelet_config["levels"]
58
+ wavelet = wavelet_config["wavelet"]
59
+
60
+ pretransform = WaveletPretransform(channels, levels, wavelet)
61
+ elif pretransform_type == 'pqmf':
62
+ from .pretransforms import PQMFPretransform
63
+ pqmf_config = pretransform_config["config"]
64
+ pretransform = PQMFPretransform(**pqmf_config)
65
+ elif pretransform_type == 'dac_pretrained':
66
+ from .pretransforms import PretrainedDACPretransform
67
+ pretrained_dac_config = pretransform_config["config"]
68
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
69
+ elif pretransform_type == "audiocraft_pretrained":
70
+ from .pretransforms import AudiocraftCompressionPretransform
71
+
72
+ audiocraft_config = pretransform_config["config"]
73
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
74
+ else:
75
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
76
+
77
+ enable_grad = pretransform_config.get('enable_grad', False)
78
+ pretransform.enable_grad = enable_grad
79
+
80
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
81
+
82
+ return pretransform
83
+
84
+ def create_bottleneck_from_config(bottleneck_config):
85
+ bottleneck_type = bottleneck_config.get('type', None)
86
+
87
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
88
+
89
+ if bottleneck_type == 'tanh':
90
+ from .bottleneck import TanhBottleneck
91
+ bottleneck = TanhBottleneck()
92
+ elif bottleneck_type == 'vae':
93
+ from .bottleneck import VAEBottleneck
94
+ bottleneck = VAEBottleneck()
95
+ elif bottleneck_type == 'rvq':
96
+ from .bottleneck import RVQBottleneck
97
+
98
+ quantizer_params = {
99
+ "dim": 128,
100
+ "codebook_size": 1024,
101
+ "num_quantizers": 8,
102
+ "decay": 0.99,
103
+ "kmeans_init": True,
104
+ "kmeans_iters": 50,
105
+ "threshold_ema_dead_code": 2,
106
+ }
107
+
108
+ quantizer_params.update(bottleneck_config["config"])
109
+
110
+ bottleneck = RVQBottleneck(**quantizer_params)
111
+ elif bottleneck_type == "dac_rvq":
112
+ from .bottleneck import DACRVQBottleneck
113
+
114
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
115
+
116
+ elif bottleneck_type == 'rvq_vae':
117
+ from .bottleneck import RVQVAEBottleneck
118
+
119
+ quantizer_params = {
120
+ "dim": 128,
121
+ "codebook_size": 1024,
122
+ "num_quantizers": 8,
123
+ "decay": 0.99,
124
+ "kmeans_init": True,
125
+ "kmeans_iters": 50,
126
+ "threshold_ema_dead_code": 2,
127
+ }
128
+
129
+ quantizer_params.update(bottleneck_config["config"])
130
+
131
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
132
+
133
+ elif bottleneck_type == 'dac_rvq_vae':
134
+ from .bottleneck import DACRVQVAEBottleneck
135
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
136
+ elif bottleneck_type == 'l2_norm':
137
+ from .bottleneck import L2Bottleneck
138
+ bottleneck = L2Bottleneck()
139
+ elif bottleneck_type == "wasserstein":
140
+ from .bottleneck import WassersteinBottleneck
141
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
142
+ elif bottleneck_type == "fsq":
143
+ from .bottleneck import FSQBottleneck
144
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
145
+ else:
146
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
147
+
148
+ requires_grad = bottleneck_config.get('requires_grad', True)
149
+ if not requires_grad:
150
+ for param in bottleneck.parameters():
151
+ param.requires_grad = False
152
+
153
+ return bottleneck
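
As an example of the override behaviour in the 'rvq' branch above, the sketch below supplies only a few fields; anything left out falls back to the hard-coded quantizer_params defaults. The values are illustrative, not taken from a shipped config.

    # Illustrative RVQ bottleneck config; omitted fields use the defaults above.
    example_bottleneck_config = {
        "type": "rvq",
        "config": {
            "dim": 64,
            "codebook_size": 2048,
            "num_quantizers": 4,
        },
    }
    # bottleneck = create_bottleneck_from_config(example_bottleneck_config)
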
stable_audio_tools/models/lm.py ADDED
@@ -0,0 +1,542 @@
1
+ from dataclasses import dataclass
2
+ import torch
3
+ from tqdm.auto import trange
4
+ import typing as tp
5
+ from einops import rearrange
6
+ from torch import nn
7
+
8
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
9
+ from .factory import create_pretransform_from_config
10
+ from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
11
+ from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
12
+ from .utils import multinomial, sample_top_k, sample_top_p
13
+
14
+ from .codebook_patterns import (
15
+ CodebooksPatternProvider,
16
+ DelayedPatternProvider,
17
+ MusicLMPattern,
18
+ ParallelPatternProvider,
19
+ UnrolledPatternProvider
20
+ )
21
+
22
+ # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
23
+ # License can be found in LICENSES/LICENSE_META.txt
24
+
25
+ @dataclass
26
+ class LMOutput:
27
+ # The logits are already re-aligned with the input codes
28
+ # hence no extra shift is required, e.g. when computing CE
29
+ logits: torch.Tensor # [B, K, T, card]
30
+ mask: torch.Tensor # [B, K, T]
31
+
32
+ # Wrapper for a multi-codebook language model
33
+ # Handles patterns and quantizer heads
34
+ class AudioLanguageModel(nn.Module):
35
+ def __init__(
36
+ self,
37
+ pattern_provider: CodebooksPatternProvider,
38
+ backbone: AudioLMBackbone,
39
+ num_quantizers: int,
40
+ codebook_size: int
41
+ ):
42
+ super().__init__()
43
+
44
+ self.pattern_provider = pattern_provider
45
+ self.backbone = backbone
46
+ self.num_quantizers = num_quantizers
47
+ self.codebook_size = codebook_size
48
+
49
+ self.masked_token_id = codebook_size
50
+
51
+ # Per-quantizer embedders
52
+ # Add one for the mask embed
53
+ self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])
54
+
55
+ # Per-quantizer output heads
56
+ self.quantizer_heads = nn.ModuleList([
57
+ nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
58
+ ])
59
+
60
+ def forward(self,
61
+ sequence: torch.Tensor, # [batch, num_quantizers, seq_len]
62
+ prepend_cond=None, #[batch, seq, channels]
63
+ prepend_cond_mask=None,
64
+ cross_attn_cond=None, #[batch, seq, channels],
65
+ **kwargs
66
+ ):
67
+
68
+
69
+ batch, num_quantizers, seq_len = sequence.shape
70
+
71
+ assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"
72
+
73
+ backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim]
74
+
75
+ dtype = next(self.parameters()).dtype
76
+
77
+ if cross_attn_cond is not None:
78
+ cross_attn_cond = cross_attn_cond.to(dtype)
79
+
80
+ if prepend_cond is not None:
81
+ prepend_cond = prepend_cond.to(dtype)
82
+
83
+ if prepend_cond_mask is not None:
84
+ prepend_cond_mask = prepend_cond_mask.to(dtype)
85
+
86
+ backbone_input = backbone_input.to(dtype)
87
+
88
+ output = self.backbone(
89
+ backbone_input,
90
+ cross_attn_cond=cross_attn_cond,
91
+ prepend_cond=prepend_cond,
92
+ prepend_cond_mask=prepend_cond_mask,
93
+ **kwargs
94
+ ) # [batch, seq_len, embed_dim]
95
+
96
+ # Run output through quantizer heads
97
+ logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size]
98
+
99
+ return logits
100
+
101
+ def compute_logits(
102
+ self,
103
+ codes, #[batch, num_quantizers, seq_len]
104
+ **kwargs):
105
+ """
106
+ Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning
107
+ Handles translation between input sequence and pattern-shifted sequence
108
+ Only used during training
109
+ """
110
+
111
+ batch, _, seq_len = codes.shape
112
+
113
+ pattern = self.pattern_provider.get_pattern(seq_len)
114
+
115
+ # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps
116
+ shifted_codes, _, _ = pattern.build_pattern_sequence(
117
+ codes,
118
+ self.masked_token_id,
119
+ keep_only_valid_steps=True
120
+ )
121
+
122
+ # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size]
123
+ logits = self(shifted_codes, **kwargs)
124
+
125
+ # Rearrange logits to prepare to revert pattern
126
+ logits = rearrange(logits, "b n s c -> b c n s")
127
+
128
+ # Revert sequence logits back to original sequence length, removing masked steps
129
+ logits, _, logits_mask = pattern.revert_pattern_logits(
130
+ logits, float('nan'), keep_only_valid_steps=True
131
+ )
132
+
133
+ logits = rearrange(logits, "b c n t -> b n t c")
134
+
135
+ logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len]
136
+
137
+ return LMOutput(logits=logits, mask=logits_mask)
138
+
139
+ # Conditioning and generation wrapper for a multi-codebook language model
140
+ # Handles conditioning, CFG, generation, and encoding/decoding
141
+ class AudioLanguageModelWrapper(nn.Module):
142
+ def __init__(
143
+ self,
144
+ pretransform: Pretransform,
145
+ lm: AudioLanguageModel,
146
+ sample_rate: int,
147
+ min_input_length: int,
148
+ conditioner: MultiConditioner = None,
149
+ cross_attn_cond_ids: tp.List[str] = [],
150
+ prepend_cond_ids: tp.List[str] = [],
151
+ global_cond_ids: tp.List[str] = []
152
+ ):
153
+ super().__init__()
154
+
155
+ assert pretransform.is_discrete, "Pretransform must be discrete"
156
+ self.pretransform = pretransform
157
+
158
+ self.pretransform.requires_grad_(False)
159
+ self.pretransform.eval()
160
+
161
+ if isinstance(self.pretransform, AutoencoderPretransform):
162
+ self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
163
+ self.codebook_size = self.pretransform.model.bottleneck.codebook_size
164
+ elif isinstance(self.pretransform, PretrainedDACPretransform):
165
+ self.num_quantizers = self.pretransform.model.num_quantizers
166
+ self.codebook_size = self.pretransform.model.codebook_size
167
+ elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
168
+ self.num_quantizers = self.pretransform.num_quantizers
169
+ self.codebook_size = self.pretransform.codebook_size
170
+ else:
171
+ raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
172
+
173
+ self.conditioner = conditioner
174
+
175
+ self.lm = lm
176
+
177
+ self.sample_rate = sample_rate
178
+ self.min_input_length = min_input_length
179
+
180
+ self.cross_attn_cond_ids = cross_attn_cond_ids
181
+ self.prepend_cond_ids = prepend_cond_ids
182
+ self.global_cond_ids = global_cond_ids
183
+
184
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
185
+ cross_attention_input = None
186
+ prepend_cond = None
187
+ prepend_cond_mask = None
188
+ global_cond = None
189
+
190
+ if len(self.cross_attn_cond_ids) > 0:
191
+ # Concatenate all cross-attention inputs over the sequence dimension
192
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
193
+ cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
194
+
195
+ if len(self.prepend_cond_ids) > 0:
196
+ # Concatenate all prepend conditioning inputs over the sequence dimension
197
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
198
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
199
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
200
+
201
+ if len(self.global_cond_ids) > 0:
202
+ # Concatenate all global conditioning inputs over the channel dimension
203
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
204
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
205
+ if len(global_cond.shape) == 3:
206
+ global_cond = global_cond.squeeze(1)
207
+
208
+ if negative:
209
+ return {
210
+ "negative_cross_attn_cond": cross_attention_input,
211
+ "negative_prepend_cond": prepend_cond,
212
+ "negative_prepend_cond_mask": prepend_cond_mask,
213
+ "negative_global_cond": global_cond
214
+ }
215
+ else:
216
+ return {
217
+ "cross_attn_cond": cross_attention_input,
218
+ "prepend_cond": prepend_cond,
219
+ "prepend_cond_mask": prepend_cond_mask,
220
+ "global_cond": global_cond
221
+ }
222
+
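For reference, a hypothetical shape of the dictionary this method consumes; the keys and sizes below are illustrative only, not the model's actual conditioner ids. Each conditioner id maps to a `(tensor, mask)` pair:

```python
import torch

# Hypothetical conditioner output; keys and sizes are illustrative only.
cond = {
    "prompt": (torch.randn(2, 77, 768), torch.ones(2, 77, dtype=torch.bool)),
    "seconds_total": (torch.randn(2, 1, 768), torch.ones(2, 1, dtype=torch.bool)),
}
# With cross_attn_cond_ids=["prompt", "seconds_total"], get_conditioning_inputs
# would concatenate the two tensors along the sequence axis into (2, 78, 768).
```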
223
+ def compute_logits(
224
+ self,
225
+ codes,
226
+ condition_tensors=None,
227
+ cfg_dropout_prob=0.0,
228
+ **kwargs
229
+ ):
230
+ """
231
+ Compute logits for a batch of codes, translating conditioning inputs into model inputs
232
+ Handles CFG dropout
233
+ """
234
+
235
+ if condition_tensors is None:
236
+ condition_tensors = {}
237
+
238
+ conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
239
+
240
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
241
+ prepend_cond = conditioning_inputs["prepend_cond"]
242
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
243
+ global_cond = conditioning_inputs["global_cond"]
244
+
245
+ if cfg_dropout_prob > 0.0:
246
+ if cross_attn_cond is not None:
247
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
248
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
249
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
250
+
251
+ if prepend_cond is not None:
252
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
253
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
254
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
255
+
256
+ if global_cond is not None:
257
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
258
+ dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
259
+ global_cond = torch.where(dropout_mask, null_embed, global_cond)
260
+
261
+ return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
262
+
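The dropout logic above can be read as one small helper applied independently to each conditioning tensor; a minimal sketch, assuming `[batch, ...]`-shaped conditioning and zeros standing in for the null embedding:

```python
import torch

def cfg_dropout(cond: torch.Tensor, prob: float) -> torch.Tensor:
    # With probability `prob`, replace a batch item's conditioning with a
    # zero "null" embedding so the model also learns the unconditional case.
    if prob <= 0.0:
        return cond
    mask_shape = (cond.shape[0],) + (1,) * (cond.ndim - 1)
    dropout_mask = torch.bernoulli(
        torch.full(mask_shape, prob, device=cond.device)
    ).to(torch.bool)
    return torch.where(dropout_mask, torch.zeros_like(cond), cond)
```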
263
+ def _sample_next_token(
264
+ self,
265
+ sequence, #[batch, num_quantizers, seq_len]
266
+ conditioning_tensors=None,
267
+ cross_attn_use_cfg=True,
268
+ prepend_use_cfg=True,
269
+ global_use_cfg=True,
270
+ cfg_scale=1.0,
271
+ top_k=250,
272
+ top_p=0.0,
273
+ temp=1.0,
274
+ **kwargs
275
+ ):
276
+ """
277
+ Sample the next token for a batch of codes, translating conditioning inputs into model inputs
278
+ Handles CFG inference
279
+ """
280
+
281
+ if conditioning_tensors is None:
282
+ conditioning_tensors = {}
283
+
284
+ conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
285
+
286
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
287
+ prepend_cond = conditioning_inputs["prepend_cond"]
288
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
289
+ global_cond = conditioning_inputs["global_cond"]
290
+
291
+ if cfg_scale != 1.0:
292
+
293
+ # Batch size is doubled to account for negative samples
294
+ sequence = torch.cat([sequence, sequence], dim=0)
295
+
296
+ if cross_attn_cond is not None and cross_attn_use_cfg:
297
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
298
+
299
+ cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
300
+
301
+ if prepend_cond is not None and prepend_use_cfg:
302
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
303
+
304
+ prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
305
+
306
+ if prepend_cond_mask is not None:
307
+ prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
308
+
309
+ if global_cond is not None and global_use_cfg:
310
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
311
+
312
+ global_cond = torch.cat([global_cond, null_embed], dim=0)
313
+
314
+ logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
315
+
316
+ if cfg_scale != 1.0:
317
+ cond_logits, uncond_logits = logits.chunk(2, dim=0)
318
+
319
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
320
+
321
+ logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
322
+
323
+ # Grab the logits for the last step
324
+ logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
325
+
326
+ # Apply top-k or top-p sampling
327
+
328
+ if temp > 0:
329
+ probs = torch.softmax(logits / temp, dim=-1)
330
+
331
+ if top_p > 0.0:
332
+ next_token = sample_top_p(probs, p=top_p)
333
+ elif top_k > 0:
334
+ next_token = sample_top_k(probs, k=top_k)
335
+ else:
336
+ next_token = multinomial(probs, num_samples=1)
337
+
338
+ else:
339
+ next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
340
+
341
+ return next_token
342
+
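`sample_top_k`, `sample_top_p` and `multinomial` are imported from elsewhere in the package; a rough sketch of what the top-k variant does with the `[batch, num_quantizers, codebook_size]` probabilities (an illustrative stand-in, not the package's exact implementation):

```python
import torch

def sample_top_k_sketch(probs: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k most likely tokens per distribution, renormalize, sample one.
    top_probs, top_idx = torch.topk(probs, k, dim=-1)
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    flat_choice = torch.multinomial(top_probs.reshape(-1, k), num_samples=1)
    next_token = top_idx.reshape(-1, k).gather(-1, flat_choice)
    return next_token.reshape(*probs.shape[:-1], 1)  # [batch, num_quantizers, 1]
```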
343
+ @torch.no_grad()
344
+ def generate(
345
+ self,
346
+ max_gen_len: int = 256,
347
+ batch_size: tp.Optional[int] = None,
348
+ init_data: tp.Optional[torch.Tensor] = None,
349
+ conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
350
+ conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
351
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
352
+ use_cache: bool = True,
353
+ cfg_scale: float = 1.0,
354
+ **kwargs
355
+ ):
356
+ device = next(self.parameters()).device
357
+
358
+ if conditioning_tensors is None and conditioning is not None:
359
+ # Convert conditioning inputs to conditioning tensors
360
+ conditioning_tensors = self.conditioner(conditioning, device)
361
+
362
+ # Check that batch size is consistent across inputs
363
+ possible_batch_sizes = []
364
+
365
+ if batch_size is not None:
366
+ possible_batch_sizes.append(batch_size)
367
+ elif init_data is not None:
368
+ possible_batch_sizes.append(init_data.shape[0])
369
+ elif conditioning_tensors is not None:
370
+ # Assume that the first conditioning tensor has the batch dimension
371
+ possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
372
+ else:
373
+ possible_batch_sizes.append(1)
374
+
375
+ assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"
376
+
377
+ batch_size = possible_batch_sizes[0]
378
+
379
+ if init_data is None:
380
+ # Initialize with zeros
381
+ assert batch_size > 0
382
+ init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
383
+
384
+ batch_size, num_quantizers, seq_len = init_data.shape
385
+
386
+ start_offset = seq_len
387
+ assert start_offset < max_gen_len, "init data longer than max gen length"
388
+
389
+ pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
390
+
391
+ unknown_token = -1
392
+
393
+ # Initialize the generated codes with the init data, padded with unknown tokens
394
+ gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
395
+ gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
396
+
397
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
398
+
399
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
400
+ assert start_offset_sequence is not None
401
+
402
+ # Generation
403
+ prev_offset = 0
404
+ gen_sequence_len = gen_sequence.shape[-1]
405
+
406
+ # Reset generation cache
407
+ if use_cache and self.lm.backbone.use_generation_cache:
408
+ self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
409
+
410
+ for offset in trange(start_offset_sequence, gen_sequence_len):
411
+
412
+ # Get the full sequence up to the current offset
413
+ curr_sequence = gen_sequence[..., prev_offset:offset]
414
+
415
+ next_token = self._sample_next_token(
416
+ curr_sequence,
417
+ conditioning_tensors=conditioning_tensors,
418
+ use_cache=use_cache,
419
+ cfg_scale=cfg_scale,
420
+ **kwargs
421
+ )
422
+
423
+ valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
424
+ next_token[~valid_mask] = self.lm.masked_token_id
425
+
426
+ # Update the generated sequence with the next token
427
+ gen_sequence[..., offset:offset+1] = torch.where(
428
+ gen_sequence[..., offset:offset+1] == unknown_token,
429
+ next_token,
430
+ gen_sequence[..., offset:offset+1]
431
+ )
432
+
433
+ if use_cache and self.lm.backbone.use_generation_cache:
434
+ # Only update the offset if caching is being used
435
+ prev_offset = offset
436
+
437
+ self.lm.backbone.update_generation_cache(offset)
438
+
439
+ if callback is not None:
440
+ # Callback to report progress
441
+ # Pass in the offset relative to the start of the sequence, and the length of the current sequence
442
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
443
+
444
+ assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
445
+
446
+ out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
447
+
448
+ # sanity checks over the returned codes and corresponding masks
449
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
450
+ assert (out_mask[..., :max_gen_len] == 1).all()
451
+
452
+ #out_codes = out_codes[..., 0:max_gen_len]
453
+
454
+ return out_codes
455
+
456
+
457
+ def generate_audio(
458
+ self,
459
+ **kwargs
460
+ ):
461
+ """
462
+ Generate codes with the language model and decode them to audio with the pretransform
463
+ """
464
+
465
+ codes = self.generate(**kwargs)
466
+
467
+ audio = self.pretransform.decode_tokens(codes)
468
+
469
+ return audio
470
+
471
+
472
+ def create_audio_lm_from_config(config):
473
+ model_config = config.get('model', None)
474
+ assert model_config is not None, 'model config must be specified in config'
475
+
476
+ sample_rate = config.get('sample_rate', None)
477
+ assert sample_rate is not None, "Must specify sample_rate in config"
478
+
479
+ lm_config = model_config.get('lm', None)
480
+ assert lm_config is not None, 'lm config must be specified in model config'
481
+
482
+ codebook_pattern = lm_config.get("codebook_pattern", "delay")
483
+
484
+ pattern_providers = {
485
+ 'parallel': ParallelPatternProvider,
486
+ 'delay': DelayedPatternProvider,
487
+ 'unroll': UnrolledPatternProvider,
488
+ 'musiclm': MusicLMPattern,
489
+ }
490
+
491
+ pretransform_config = model_config.get("pretransform", None)
492
+
493
+ pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
494
+
495
+ assert pretransform.is_discrete, "Pretransform must be discrete"
496
+
497
+ min_input_length = pretransform.downsampling_ratio
498
+
499
+ pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)
500
+
501
+ conditioning_config = model_config.get('conditioning', None)
502
+
503
+ conditioner = None
504
+ if conditioning_config is not None:
505
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
506
+
507
+ cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
508
+ prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
509
+ global_cond_ids = lm_config.get('global_cond_ids', [])
510
+
511
+ lm_type = lm_config.get("type", None)
512
+ lm_model_config = lm_config.get("config", None)
513
+
514
+ assert lm_type is not None, "Must specify lm type in lm config"
515
+ assert lm_model_config is not None, "Must specify lm model config in lm config"
516
+
517
+ if lm_type == "x-transformers":
518
+ backbone = XTransformersAudioLMBackbone(**lm_model_config)
519
+ elif lm_type == "continuous_transformer":
520
+ backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
521
+ else:
522
+ raise NotImplementedError(f"Unrecognized lm type {lm_type}")
523
+
524
+ lm = AudioLanguageModel(
525
+ pattern_provider=pattern_provider,
526
+ backbone=backbone,
527
+ num_quantizers=pretransform.num_quantizers,
528
+ codebook_size=pretransform.codebook_size
529
+ )
530
+
531
+ model = AudioLanguageModelWrapper(
532
+ pretransform=pretransform,
533
+ lm=lm,
534
+ conditioner=conditioner,
535
+ sample_rate=sample_rate,
536
+ min_input_length=min_input_length,
537
+ cross_attn_cond_ids=cross_attn_cond_ids,
538
+ prepend_cond_ids=prepend_cond_ids,
539
+ global_cond_ids=global_cond_ids
540
+ )
541
+
542
+ return model
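A hypothetical end-to-end call, assuming a model config JSON has already been obtained; the path, conditioning keys, and sampling values below are illustrative and depend entirely on the config's conditioner ids:

```python
import json
import torch

with open("model_config.json") as f:   # hypothetical path
    config = json.load(f)

model = create_audio_lm_from_config(config).eval()

with torch.no_grad():
    audio = model.generate_audio(
        max_gen_len=512,
        conditioning=[{"prompt": "rain on a tin roof", "seconds_total": 10}],
        cfg_scale=3.0,
        top_k=250,
        temp=1.0,
    )
```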
stable_audio_tools/models/local_attention.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ from .blocks import AdaRMSNorm
7
+ from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
8
+
9
+ def checkpoint(function, *args, **kwargs):
10
+ kwargs.setdefault("use_reentrant", False)
11
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
12
+
13
+ # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
14
+ class ContinuousLocalTransformer(nn.Module):
15
+ def __init__(
16
+ self,
17
+ *,
18
+ dim,
19
+ depth,
20
+ dim_in = None,
21
+ dim_out = None,
22
+ causal = False,
23
+ local_attn_window_size = 64,
24
+ heads = 8,
25
+ ff_mult = 2,
26
+ cond_dim = 0,
27
+ cross_attn_cond_dim = 0,
28
+ **kwargs
29
+ ):
30
+ super().__init__()
31
+
32
+ dim_head = dim//heads
33
+
34
+ self.layers = nn.ModuleList([])
35
+
36
+ self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
37
+
38
+ self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
39
+
40
+ self.local_attn_window_size = local_attn_window_size
41
+
42
+ self.cond_dim = cond_dim
43
+
44
+ self.cross_attn_cond_dim = cross_attn_cond_dim
45
+
46
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
47
+
48
+ for _ in range(depth):
49
+
50
+ self.layers.append(nn.ModuleList([
51
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
52
+ Attention(
53
+ dim=dim,
54
+ dim_heads=dim_head,
55
+ causal=causal,
56
+ zero_init_output=True,
57
+ natten_kernel_size=local_attn_window_size,
58
+ ),
59
+ Attention(
60
+ dim=dim,
61
+ dim_heads=dim_head,
62
+ dim_context = cross_attn_cond_dim,
63
+ zero_init_output=True
64
+ ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
65
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
66
+ FeedForward(dim = dim, mult = ff_mult, no_bias=True)
67
+ ]))
68
+
69
+ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
70
+
71
+ x = checkpoint(self.project_in, x)
72
+
73
+ if prepend_cond is not None:
74
+ x = torch.cat([prepend_cond, x], dim=1)
75
+
76
+ pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
77
+
78
+ for attn_norm, attn, xattn, ff_norm, ff in self.layers:
79
+
80
+ residual = x
81
+ if cond is not None:
82
+ x = checkpoint(attn_norm, x, cond)
83
+ else:
84
+ x = checkpoint(attn_norm, x)
85
+
86
+ x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
87
+
88
+ if cross_attn_cond is not None:
89
+ x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
90
+
91
+ residual = x
92
+
93
+ if cond is not None:
94
+ x = checkpoint(ff_norm, x, cond)
95
+ else:
96
+ x = checkpoint(ff_norm, x)
97
+
98
+ x = checkpoint(ff, x) + residual
99
+
100
+ return checkpoint(self.project_out, x)
101
+
102
+ class TransformerDownsampleBlock1D(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_channels,
106
+ embed_dim = 768,
107
+ depth = 3,
108
+ heads = 12,
109
+ downsample_ratio = 2,
110
+ local_attn_window_size = 64,
111
+ **kwargs
112
+ ):
113
+ super().__init__()
114
+
115
+ self.downsample_ratio = downsample_ratio
116
+
117
+ self.transformer = ContinuousLocalTransformer(
118
+ dim=embed_dim,
119
+ depth=depth,
120
+ heads=heads,
121
+ local_attn_window_size=local_attn_window_size,
122
+ **kwargs
123
+ )
124
+
125
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
126
+
127
+ self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
128
+
129
+
130
+ def forward(self, x):
131
+
132
+ x = checkpoint(self.project_in, x)
133
+
134
+ # Compute
135
+ x = self.transformer(x)
136
+
137
+ # Trade sequence length for channels
138
+ x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
139
+
140
+ # Project back to embed dim
141
+ x = checkpoint(self.project_down, x)
142
+
143
+ return x
144
+
145
+ class TransformerUpsampleBlock1D(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels,
149
+ embed_dim,
150
+ depth = 3,
151
+ heads = 12,
152
+ upsample_ratio = 2,
153
+ local_attn_window_size = 64,
154
+ **kwargs
155
+ ):
156
+ super().__init__()
157
+
158
+ self.upsample_ratio = upsample_ratio
159
+
160
+ self.transformer = ContinuousLocalTransformer(
161
+ dim=embed_dim,
162
+ depth=depth,
163
+ heads=heads,
164
+ local_attn_window_size = local_attn_window_size,
165
+ **kwargs
166
+ )
167
+
168
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
169
+
170
+ self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
171
+
172
+ def forward(self, x):
173
+
174
+ # Project to embed dim
175
+ x = checkpoint(self.project_in, x)
176
+
177
+ # Project to increase channel dim
178
+ x = checkpoint(self.project_up, x)
179
+
180
+ # Trade channels for sequence length
181
+ x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
182
+
183
+ # Compute
184
+ x = self.transformer(x)
185
+
186
+ return x
187
+
188
+
189
+ class TransformerEncoder1D(nn.Module):
190
+ def __init__(
191
+ self,
192
+ in_channels,
193
+ out_channels,
194
+ embed_dims = [96, 192, 384, 768],
195
+ heads = [12, 12, 12, 12],
196
+ depths = [3, 3, 3, 3],
197
+ ratios = [2, 2, 2, 2],
198
+ local_attn_window_size = 64,
199
+ **kwargs
200
+ ):
201
+ super().__init__()
202
+
203
+ layers = []
204
+
205
+ for layer in range(len(depths)):
206
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
207
+
208
+ layers.append(
209
+ TransformerDownsampleBlock1D(
210
+ in_channels = prev_dim,
211
+ embed_dim = embed_dims[layer],
212
+ heads = heads[layer],
213
+ depth = depths[layer],
214
+ downsample_ratio = ratios[layer],
215
+ local_attn_window_size = local_attn_window_size,
216
+ **kwargs
217
+ )
218
+ )
219
+
220
+ self.layers = nn.Sequential(*layers)
221
+
222
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
223
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
224
+
225
+ def forward(self, x):
226
+ x = rearrange(x, "b c n -> b n c")
227
+ x = checkpoint(self.project_in, x)
228
+ x = self.layers(x)
229
+ x = checkpoint(self.project_out, x)
230
+ x = rearrange(x, "b n c -> b c n")
231
+
232
+ return x
233
+
234
+
235
+ class TransformerDecoder1D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels,
239
+ out_channels,
240
+ embed_dims = [768, 384, 192, 96],
241
+ heads = [12, 12, 12, 12],
242
+ depths = [3, 3, 3, 3],
243
+ ratios = [2, 2, 2, 2],
244
+ local_attn_window_size = 64,
245
+ **kwargs
246
+ ):
247
+
248
+ super().__init__()
249
+
250
+ layers = []
251
+
252
+ for layer in range(len(depths)):
253
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
254
+
255
+ layers.append(
256
+ TransformerUpsampleBlock1D(
257
+ in_channels = prev_dim,
258
+ embed_dim = embed_dims[layer],
259
+ heads = heads[layer],
260
+ depth = depths[layer],
261
+ upsample_ratio = ratios[layer],
262
+ local_attn_window_size = local_attn_window_size,
263
+ **kwargs
264
+ )
265
+ )
266
+
267
+ self.layers = nn.Sequential(*layers)
268
+
269
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
270
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
271
+
272
+ def forward(self, x):
273
+ x = rearrange(x, "b c n -> b n c")
274
+ x = checkpoint(self.project_in, x)
275
+ x = self.layers(x)
276
+ x = checkpoint(self.project_out, x)
277
+ x = rearrange(x, "b n c -> b c n")
278
+ return x
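The downsample/upsample blocks above compress time into channels (and back) with a single einops rearrange; a toy round trip showing the two patterns are exact inverses, here with a ratio of 2:

```python
import torch
from einops import rearrange

x = torch.randn(1, 8, 16)                            # Batch x Seq x Channels
down = rearrange(x, "b (n r) c -> b n (c r)", r=2)   # (1, 4, 32): shorter, wider
up = rearrange(down, "b n (c r) -> b (n r) c", r=2)  # back to (1, 8, 16)
assert torch.equal(x, up)
```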
stable_audio_tools/models/pqmf.py ADDED
@@ -0,0 +1,393 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from scipy.optimize import fmin
7
+ from scipy.signal import firwin, kaiserord  # note: `kaiser`/`kaiser_beta` are not used in this module
8
+
9
+ class PQMF(nn.Module):
10
+ """
11
+ Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
12
+ Uses polyphase representation which is computationally more efficient for real-time.
13
+
14
+ Parameters:
15
+ - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
16
+ - num_bands (int): Number of desired frequency bands. It must be a power of 2.
17
+ """
18
+
19
+ def __init__(self, attenuation, num_bands):
20
+ super(PQMF, self).__init__()
21
+
22
+ # Ensure num_bands is a power of 2
23
+ is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
24
+ assert is_power_of_2, "'num_bands' must be a power of 2."
25
+
26
+ # Create the prototype filter
27
+ prototype_filter = design_prototype_filter(attenuation, num_bands)
28
+ filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
29
+ padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
30
+
31
+ # Register filters and settings
32
+ self.register_buffer("filter_bank", padded_filter_bank)
33
+ self.register_buffer("prototype", prototype_filter)
34
+ self.num_bands = num_bands
35
+
36
+ def forward(self, signal):
37
+ """Decompose the signal into multiple frequency bands."""
38
+ # If signal is not a pytorch tensor of Batch x Channels x Length, convert it
39
+ signal = prepare_signal_dimensions(signal)
40
+ # The signal length must be a multiple of num_bands. Pad it with zeros.
41
+ signal = pad_signal(signal, self.num_bands)
42
+ # run it
43
+ signal = polyphase_analysis(signal, self.filter_bank)
44
+ return apply_alias_cancellation(signal)
45
+
46
+ def inverse(self, bands):
47
+ """Reconstruct the original signal from the frequency bands."""
48
+ bands = apply_alias_cancellation(bands)
49
+ return polyphase_synthesis(bands, self.filter_bank)
50
+
51
+
52
+ def prepare_signal_dimensions(signal):
53
+ """
54
+ Rearrange signal into Batch x Channels x Length.
55
+
56
+ Parameters
57
+ ----------
58
+ signal : torch.Tensor or numpy.ndarray
59
+ The input signal.
60
+
61
+ Returns
62
+ -------
63
+ torch.Tensor
64
+ Preprocessed signal tensor.
65
+ """
66
+ # Convert numpy to torch tensor
67
+ if isinstance(signal, np.ndarray):
68
+ signal = torch.from_numpy(signal)
69
+
70
+ # Ensure tensor
71
+ if not isinstance(signal, torch.Tensor):
72
+ raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
73
+
74
+ # Modify dimension of signal to Batch x Channels x Length
75
+ if signal.dim() == 1:
76
+ # This is just a mono signal. Unsqueeze to 1 x 1 x Length
77
+ signal = signal.unsqueeze(0).unsqueeze(0)
78
+ elif signal.dim() == 2:
79
+ # This is a multi-channel signal (e.g. stereo)
80
+ # Rearrange so that larger dimension (Length) is last
81
+ if signal.shape[0] > signal.shape[1]:
82
+ signal = signal.T
83
+ # Unsqueeze to 1 x Channels x Length
84
+ signal = signal.unsqueeze(0)
85
+ return signal
86
+
87
+ def pad_signal(signal, num_bands):
88
+ """
89
+ Pads the signal to make its length divisible by the given number of bands.
90
+
91
+ Parameters
92
+ ----------
93
+ signal : torch.Tensor
94
+ The input signal tensor, where the last dimension represents the signal length.
95
+
96
+ num_bands : int
97
+ The number of bands by which the signal length should be divisible.
98
+
99
+ Returns
100
+ -------
101
+ torch.Tensor
102
+ The padded signal tensor. If the original signal length was already divisible
103
+ by num_bands, returns the original signal unchanged.
104
+ """
105
+ remainder = signal.shape[-1] % num_bands
106
+ if remainder > 0:
107
+ padding_size = num_bands - remainder
108
+ signal = nn.functional.pad(signal, (0, padding_size))
109
+ return signal
110
+
111
+ def generate_modulated_filter_bank(prototype_filter, num_bands):
112
+ """
113
+ Generate a QMF bank of cosine modulated filters based on a given prototype filter.
114
+
115
+ Parameters
116
+ ----------
117
+ prototype_filter : torch.Tensor
118
+ The prototype filter used as the basis for modulation.
119
+ num_bands : int
120
+ The number of desired subbands or filters.
121
+
122
+ Returns
123
+ -------
124
+ torch.Tensor
125
+ A bank of cosine modulated filters.
126
+ """
127
+
128
+ # Initialize indices for modulation.
129
+ subband_indices = torch.arange(num_bands).reshape(-1, 1)
130
+
131
+ # Calculate the length of the prototype filter.
132
+ filter_length = prototype_filter.shape[-1]
133
+
134
+ # Generate symmetric time indices centered around zero.
135
+ time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
136
+
137
+ # Calculate phase offsets to ensure orthogonality between subbands.
138
+ phase_offsets = (-1)**subband_indices * np.pi / 4
139
+
140
+ # Compute the cosine modulation function.
141
+ modulation = torch.cos(
142
+ (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
143
+ )
144
+
145
+ # Apply modulation to the prototype filter.
146
+ modulated_filters = 2 * prototype_filter * modulation
147
+
148
+ return modulated_filters
149
+
150
+
151
+ def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
152
+ """
153
+ Design a lowpass filter using the Kaiser window.
154
+
155
+ Parameters
156
+ ----------
157
+ angular_cutoff : float
158
+ The angular frequency cutoff of the filter.
159
+ attenuation : float
160
+ The desired stopband attenuation in decibels (dB).
161
+ filter_length : int, optional
162
+ Desired length of the filter. If not provided, it's computed based on the given specs.
163
+
164
+ Returns
165
+ -------
166
+ ndarray
167
+ The designed lowpass filter coefficients.
168
+ """
169
+
170
+ estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
171
+
172
+ # Ensure the estimated length is odd.
173
+ estimated_length = 2 * (estimated_length // 2) + 1
174
+
175
+ if filter_length is None:
176
+ filter_length = estimated_length
177
+
178
+ return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi)
179
+
180
+
181
+ def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
182
+ """
183
+ Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
184
+
185
+ Parameters
186
+ ----------
187
+ angular_cutoff : float
188
+ Angular frequency cutoff of the filter.
189
+ attenuation : float
190
+ Desired stopband attenuation in dB.
191
+ num_bands : int
192
+ Number of bands for the multiband filter system.
193
+ filter_length : int, optional
194
+ Desired length of the filter.
195
+
196
+ Returns
197
+ -------
198
+ float
199
+ The computed objective (loss) value for the given filter specs.
200
+ """
201
+
202
+ filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
203
+ convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
204
+
205
+ return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
206
+
207
+
208
+ def design_prototype_filter(attenuation, num_bands, filter_length=None):
209
+ """
210
+ Design the optimal prototype filter for a multiband system given the desired specs.
211
+
212
+ Parameters
213
+ ----------
214
+ attenuation : float
215
+ The desired stopband attenuation in dB.
216
+ num_bands : int
217
+ Number of bands for the multiband filter system.
218
+ filter_length : int, optional
219
+ Desired length of the filter. If not provided, it's computed based on the given specs.
220
+
221
+ Returns
222
+ -------
223
+ ndarray
224
+ The optimal prototype filter coefficients.
225
+ """
226
+
227
+ optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
228
+ 1 / num_bands, disp=0)[0]
229
+
230
+ prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
231
+ return torch.tensor(prototype_filter, dtype=torch.float32)
232
+
233
+ def pad_to_nearest_power_of_two(x):
234
+ """
235
+ Pads the input tensor 'x' on both sides such that its last dimension
236
+ becomes the nearest larger power of two.
237
+
238
+ Parameters:
239
+ -----------
240
+ x : torch.Tensor
241
+ The input tensor to be padded.
242
+
243
+ Returns:
244
+ --------
245
+ torch.Tensor
246
+ The padded tensor.
247
+ """
248
+ current_length = x.shape[-1]
249
+ target_length = 2**math.ceil(math.log2(current_length))
250
+
251
+ total_padding = target_length - current_length
252
+ left_padding = total_padding // 2
253
+ right_padding = total_padding - left_padding
254
+
255
+ return nn.functional.pad(x, (left_padding, right_padding))
256
+
257
+ def apply_alias_cancellation(x):
258
+ """
259
+ Applies alias cancellation by inverting the sign of every
260
+ second element of every second row, starting from the second
261
+ row's first element in a tensor.
262
+
263
+ This operation helps ensure that the aliasing introduced in
264
+ each band during the decomposition will be counteracted during
265
+ the reconstruction.
266
+
267
+ Parameters:
268
+ -----------
269
+ x : torch.Tensor
270
+ The input tensor.
271
+
272
+ Returns:
273
+ --------
274
+ torch.Tensor
275
+ Tensor with specific elements' sign inverted for alias cancellation.
276
+ """
277
+
278
+ # Create a mask of the same shape as 'x', initialized with all ones
279
+ mask = torch.ones_like(x)
280
+
281
+ # Update specific elements in the mask to -1 to perform inversion
282
+ mask[..., 1::2, ::2] = -1
283
+
284
+ # Apply the mask to the input tensor 'x'
285
+ return x * mask
286
+
287
+ def ensure_odd_length(tensor):
288
+ """
289
+ Pads the last dimension of a tensor to ensure its size is odd.
290
+
291
+ Parameters:
292
+ -----------
293
+ tensor : torch.Tensor
294
+ Input tensor whose last dimension might need padding.
295
+
296
+ Returns:
297
+ --------
298
+ torch.Tensor
299
+ The original tensor if its last dimension was already odd,
300
+ or the padded tensor with an odd-sized last dimension.
301
+ """
302
+
303
+ last_dim_size = tensor.shape[-1]
304
+
305
+ if last_dim_size % 2 == 0:
306
+ tensor = nn.functional.pad(tensor, (0, 1))
307
+
308
+ return tensor
309
+
310
+ def polyphase_analysis(signal, filter_bank):
311
+ """
312
+ Applies the polyphase method to efficiently analyze the signal using a filter bank.
313
+
314
+ Parameters:
315
+ -----------
316
+ signal : torch.Tensor
317
+ Input signal tensor with shape (Batch x Channels x Length).
318
+
319
+ filter_bank : torch.Tensor
320
+ Filter bank tensor with shape (Bands x Length).
321
+
322
+ Returns:
323
+ --------
324
+ torch.Tensor
325
+ Signal split into sub-bands. (Batch x Channels x Bands x Length)
326
+ """
327
+
328
+ num_bands = filter_bank.shape[0]
329
+ num_channels = signal.shape[1]
330
+
331
+ # Rearrange signal for polyphase processing.
332
+ # Also combine Batch x Channel into one dimension for now.
333
+ #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
334
+ signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
335
+
336
+ # Rearrange the filter bank for matching signal shape
337
+ filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
338
+
339
+ # Apply convolution with appropriate padding to maintain spatial dimensions
340
+ padding = filter_bank.shape[-1] // 2
341
+ filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
342
+
343
+ # Truncate the last dimension post-convolution to adjust the output shape
344
+ filtered_signal = filtered_signal[..., :-1]
345
+ # Rearrange the first dimension back into Batch x Channels
346
+ filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
347
+
348
+ return filtered_signal
349
+
350
+ def polyphase_synthesis(signal, filter_bank):
351
+ """
352
+ Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
353
+
354
+ Parameters
355
+ ----------
356
+ signal : torch.Tensor
357
+ Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
358
+
359
+ filter_bank : torch.Tensor
360
+ Analysis filter bank (shape: Bands x Length).
361
+
362
+ should_rearrange : bool, optional
363
+ Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True.
364
+
365
+ Returns
366
+ -------
367
+ torch.Tensor
368
+ Reconstructed signal (shape: Batch x Channels X Length)
369
+ """
370
+
371
+ num_bands = filter_bank.shape[0]
372
+ num_channels = signal.shape[1]
373
+
374
+ # Rearrange the filter bank
375
+ filter_bank = filter_bank.flip(-1)
376
+ filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
377
+
378
+ # Combine Batch x Channels into one dimension for now.
379
+ signal = rearrange(signal, "b c n t -> (b c) n t")
380
+
381
+ # Apply convolution with appropriate padding
382
+ padding_amount = filter_bank.shape[-1] // 2 + 1
383
+ reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
384
+
385
+ # Scale the result
386
+ reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
387
+
388
+ # Reorganize the output and truncate
389
+ reconstructed_signal = reconstructed_signal.flip(1)
390
+ reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
391
+ reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
392
+
393
+ return reconstructed_signal
stable_audio_tools/models/pretrained.py ADDED
@@ -0,0 +1,33 @@
1
+ import json
2
+
3
+ from .factory import create_model_from_config
4
+ from .utils import load_ckpt_state_dict
5
+
6
+ from huggingface_hub import hf_hub_download
7
+ import torch
8
+
9
+ def get_pretrained_model(name: str):
10
+
11
+ model_config_path = hf_hub_download(name, filename="config.json", repo_type='model')
12
+
13
+ with open(model_config_path) as f:
14
+ model_config = json.load(f)
15
+
16
+ model = create_model_from_config(model_config)
17
+
18
+ # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
19
+ try:
20
+ model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
21
+ except Exception as e:
22
+ model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
23
+
24
+ # Load state dict with strict=False to ignore missing keys
25
+ state_dict = load_ckpt_state_dict(model_ckpt_path)
26
+ model.load_state_dict(state_dict, strict=False)
27
+
28
+ # Initialize missing position_ids if needed
29
+ if hasattr(model.conditioner.conditioners.video_prompt.visual_encoder_model.vision_model.embeddings, 'num_positions'):
30
+ num_positions = model.conditioner.conditioners.video_prompt.visual_encoder_model.vision_model.embeddings.num_positions
31
+ model.conditioner.conditioners.video_prompt.visual_encoder_model.vision_model.embeddings.position_ids = torch.arange(0, num_positions, dtype=torch.long)
32
+
33
+ return model, model_config
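A hypothetical usage sketch; the repository id below is illustrative, not a real checkpoint name:

```python
# Download config + weights from the Hub and build the model.
model, model_config = get_pretrained_model("some-org/audiox-checkpoint")

sample_rate = model_config["sample_rate"]
model = model.eval().requires_grad_(False)
```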
stable_audio_tools/models/pretransforms.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+
5
+ class Pretransform(nn.Module):
6
+ def __init__(self, enable_grad, io_channels, is_discrete):
7
+ super().__init__()
8
+
9
+ self.is_discrete = is_discrete
10
+ self.io_channels = io_channels
11
+ self.encoded_channels = None
12
+ self.downsampling_ratio = None
13
+
14
+ self.enable_grad = enable_grad
15
+
16
+ def encode(self, x):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, z):
20
+ raise NotImplementedError
21
+
22
+ def tokenize(self, x):
23
+ raise NotImplementedError
24
+
25
+ def decode_tokens(self, tokens):
26
+ raise NotImplementedError
27
+
28
+ class AutoencoderPretransform(Pretransform):
29
+ def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
30
+ super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
31
+ self.model = model
32
+ self.model.requires_grad_(False).eval()
33
+ self.scale=scale
34
+ self.downsampling_ratio = model.downsampling_ratio
35
+ self.io_channels = model.io_channels
36
+ self.sample_rate = model.sample_rate
37
+
38
+ self.model_half = model_half
39
+ self.iterate_batch = iterate_batch
40
+
41
+ self.encoded_channels = model.latent_dim
42
+
43
+ self.chunked = chunked
44
+ self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
45
+ self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
46
+
47
+ if self.model_half:
48
+ self.model.half()
49
+
50
+ def encode(self, x, **kwargs):
51
+
52
+ if self.model_half:
53
+ x = x.half()
54
+ self.model.to(torch.float16)
55
+
56
+ encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
57
+
58
+ if self.model_half:
59
+ encoded = encoded.float()
60
+
61
+ return encoded / self.scale
62
+
63
+ def decode(self, z, **kwargs):
64
+ z = z * self.scale
65
+
66
+ if self.model_half:
67
+ z = z.half()
68
+ self.model.to(torch.float16)
69
+
70
+ decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
71
+
72
+ if self.model_half:
73
+ decoded = decoded.float()
74
+
75
+ return decoded
76
+
77
+ def tokenize(self, x, **kwargs):
78
+ assert self.model.is_discrete, "Cannot tokenize with a continuous model"
79
+
80
+ _, info = self.model.encode(x, return_info = True, **kwargs)
81
+
82
+ return info[self.model.bottleneck.tokens_id]
83
+
84
+ def decode_tokens(self, tokens, **kwargs):
85
+ assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
86
+
87
+ return self.model.decode_tokens(tokens, **kwargs)
88
+
89
+ def load_state_dict(self, state_dict, strict=True):
90
+ self.model.load_state_dict(state_dict, strict=strict)
91
+
92
+ class WaveletPretransform(Pretransform):
93
+ def __init__(self, channels, levels, wavelet):
94
+ super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
95
+
96
+ from .wavelets import WaveletEncode1d, WaveletDecode1d
97
+
98
+ self.encoder = WaveletEncode1d(channels, levels, wavelet)
99
+ self.decoder = WaveletDecode1d(channels, levels, wavelet)
100
+
101
+ self.downsampling_ratio = 2 ** levels
102
+ self.io_channels = channels
103
+ self.encoded_channels = channels * self.downsampling_ratio
104
+
105
+ def encode(self, x):
106
+ return self.encoder(x)
107
+
108
+ def decode(self, z):
109
+ return self.decoder(z)
110
+
111
+ class PQMFPretransform(Pretransform):
112
+ def __init__(self, attenuation=100, num_bands=16):
113
+ # TODO: Fix PQMF to take in in-channels
114
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
115
+ from .pqmf import PQMF
116
+ self.pqmf = PQMF(attenuation, num_bands)
117
+
118
+
119
+ def encode(self, x):
120
+ # x is (Batch x Channels x Time)
121
+ x = self.pqmf.forward(x)
122
+ # pqmf.forward returns (Batch x Channels x Bands x Time)
123
+ # but Pretransform needs Batch x Channels x Time
124
+ # so concatenate channels and bands into one axis
125
+ return rearrange(x, "b c n t -> b (c n) t")
126
+
127
+ def decode(self, x):
128
+ # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
129
+ x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
130
+ # returns (Batch x Channels x Time)
131
+ return self.pqmf.inverse(x)
132
+
133
+ class PretrainedDACPretransform(Pretransform):
134
+ def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
135
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
136
+
137
+ import dac
138
+
139
+ model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
140
+
141
+ self.model = dac.DAC.load(model_path)
142
+
143
+ self.quantize_on_decode = quantize_on_decode
144
+
145
+ if model_type == "44khz":
146
+ self.downsampling_ratio = 512
147
+ else:
148
+ self.downsampling_ratio = 320
149
+
150
+ self.io_channels = 1
151
+
152
+ self.scale = scale
153
+
154
+ self.chunked = chunked
155
+
156
+ self.encoded_channels = self.model.latent_dim
157
+
158
+ self.num_quantizers = self.model.n_codebooks
159
+
160
+ self.codebook_size = self.model.codebook_size
161
+
162
+ def encode(self, x):
163
+
164
+ latents = self.model.encoder(x)
165
+
166
+ if self.quantize_on_decode:
167
+ output = latents
168
+ else:
169
+ z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
170
+ output = z
171
+
172
+ if self.scale != 1.0:
173
+ output = output / self.scale
174
+
175
+ return output
176
+
177
+ def decode(self, z):
178
+
179
+ if self.scale != 1.0:
180
+ z = z * self.scale
181
+
182
+ if self.quantize_on_decode:
183
+ z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
184
+
185
+ return self.model.decode(z)
186
+
187
+ def tokenize(self, x):
188
+ return self.model.encode(x)[1]
189
+
190
+ def decode_tokens(self, tokens):
191
+ latents = self.model.quantizer.from_codes(tokens)
192
+ return self.model.decode(latents)
193
+
194
+ class AudiocraftCompressionPretransform(Pretransform):
195
+ def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
196
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
197
+
198
+ try:
199
+ from audiocraft.models import CompressionModel
200
+ except ImportError:
201
+ raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
202
+
203
+ self.model = CompressionModel.get_pretrained(model_type)
204
+
205
+ self.quantize_on_decode = quantize_on_decode
206
+
207
+ self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
208
+
209
+ self.sample_rate = self.model.sample_rate
210
+
211
+ self.io_channels = self.model.channels
212
+
213
+ self.scale = scale
214
+
215
+ #self.encoded_channels = self.model.latent_dim
216
+
217
+ self.num_quantizers = self.model.num_codebooks
218
+
219
+ self.codebook_size = self.model.cardinality
220
+
221
+ self.model.to(torch.float16).eval().requires_grad_(False)
222
+
223
+ def encode(self, x):
224
+
225
+ assert False, "Audiocraft compression models do not support continuous encoding"
226
+
227
+ # latents = self.model.encoder(x)
228
+
229
+ # if self.quantize_on_decode:
230
+ # output = latents
231
+ # else:
232
+ # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
233
+ # output = z
234
+
235
+ # if self.scale != 1.0:
236
+ # output = output / self.scale
237
+
238
+ # return output
239
+
240
+ def decode(self, z):
241
+
242
+ assert False, "Audiocraft compression models do not support continuous decoding"
243
+
244
+ # if self.scale != 1.0:
245
+ # z = z * self.scale
246
+
247
+ # if self.quantize_on_decode:
248
+ # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
249
+
250
+ # return self.model.decode(z)
251
+
252
+ def tokenize(self, x):
253
+ with torch.cuda.amp.autocast(enabled=False):
254
+ return self.model.encode(x.to(torch.float16))[0]
255
+
256
+ def decode_tokens(self, tokens):
257
+ with torch.cuda.amp.autocast(enabled=False):
258
+ return self.model.decode(tokens)
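All of the classes above follow the same base contract. A minimal, hypothetical subclass showing the latent-scaling convention that `AutoencoderPretransform` uses, where latents are divided by `scale` on encode and multiplied back on decode so downstream models see roughly unit-variance inputs:

```python
import torch

class ScaledIdentityPretransform(Pretransform):
    """Toy pretransform: no compression, only the divide-by-scale convention."""

    def __init__(self, io_channels=2, scale=1.0):
        super().__init__(enable_grad=False, io_channels=io_channels, is_discrete=False)
        self.scale = scale
        self.downsampling_ratio = 1
        self.encoded_channels = io_channels

    def encode(self, x):
        return x / self.scale

    def decode(self, z):
        return z * self.scale
```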
stable_audio_tools/models/temptransformer.py ADDED
@@ -0,0 +1,190 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+
8
+ class Residual(nn.Module):
9
+ def __init__(self, fn):
10
+ super().__init__()
11
+ self.fn = fn
12
+ def forward(self, x, **kwargs):
13
+ return self.fn(x, **kwargs) + x
14
+
15
+ class SA_PreNorm(nn.Module):
16
+ def __init__(self, dim, fn):
17
+ super().__init__()
18
+ self.norm = nn.LayerNorm(dim)
19
+ self.fn = fn
20
+ def forward(self, x, **kwargs):
21
+ return self.fn(self.norm(x), **kwargs)
22
+
23
+ class SA_FeedForward(nn.Module):
24
+ def __init__(self, dim, hidden_dim, dropout = 0.):
25
+ super().__init__()
26
+ self.net = nn.Sequential(
27
+ nn.Linear(dim, hidden_dim),
28
+ nn.GELU(),
29
+ nn.Dropout(dropout),
30
+ nn.Linear(hidden_dim, dim),
31
+ nn.Dropout(dropout)
32
+ )
33
+ def forward(self, x):
34
+ return self.net(x)
35
+
36
+ class SA_Attention(nn.Module):
37
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
38
+ super().__init__()
39
+ inner_dim = dim_head * heads
40
+ project_out = not (heads == 1 and dim_head == dim)
41
+
42
+ self.heads = heads
43
+ self.scale = dim_head ** -0.5
44
+
45
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
46
+
47
+ self.to_out = nn.Sequential(
48
+ nn.Linear(inner_dim, dim),
49
+ nn.Dropout(dropout)
50
+ ) if project_out else nn.Identity()
51
+
52
+ def forward(self, x):
53
+ b, n, _, h = *x.shape, self.heads
54
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
55
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
56
+
57
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
58
+
59
+ attn = dots.softmax(dim=-1)
60
+
61
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
62
+ out = rearrange(out, 'b h n d -> b n (h d)')
63
+ out = self.to_out(out)
64
+ return out
65
+
66
+
67
+ class ReAttention(nn.Module):
68
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
69
+ super().__init__()
70
+ inner_dim = dim_head * heads
71
+ self.heads = heads
72
+ self.scale = dim_head ** -0.5
73
+
74
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
75
+
76
+ self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
77
+
78
+ self.reattn_norm = nn.Sequential(
79
+ Rearrange('b h i j -> b i j h'),
80
+ nn.LayerNorm(heads),
81
+ Rearrange('b i j h -> b h i j')
82
+ )
83
+
84
+ self.to_out = nn.Sequential(
85
+ nn.Linear(inner_dim, dim),
86
+ nn.Dropout(dropout)
87
+ )
88
+
89
+ def forward(self, x):
90
+ b, n, _, h = *x.shape, self.heads
91
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
92
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
93
+
94
+ # attention
95
+
96
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
97
+ attn = dots.softmax(dim=-1)
98
+
99
+ # re-attention
100
+
101
+ attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
102
+ attn = self.reattn_norm(attn)
103
+
104
+ # aggregate and out
105
+
106
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
107
+ out = rearrange(out, 'b h n d -> b n (h d)')
108
+ out = self.to_out(out)
109
+ return out
110
+
111
+ class LeFF(nn.Module):
112
+
113
+ def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
114
+ super().__init__()
115
+
116
+ scale_dim = dim*scale
117
+ self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
118
+ Rearrange('b n c -> b c n'),
119
+ nn.BatchNorm1d(scale_dim),
120
+ nn.GELU(),
121
+ Rearrange('b c (h w) -> b c h w', h=14, w=14)
122
+ )
123
+
124
+ self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
125
+ nn.BatchNorm2d(scale_dim),
126
+ nn.GELU(),
127
+ Rearrange('b c h w -> b (h w) c', h=14, w=14)
128
+ )
129
+
130
+ self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
131
+ Rearrange('b n c -> b c n'),
132
+ nn.BatchNorm1d(dim),
133
+ nn.GELU(),
134
+ Rearrange('b c n -> b n c')
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = self.up_proj(x)
139
+ x = self.depth_conv(x)
140
+ x = self.down_proj(x)
141
+ return x
142
+
143
+
144
+ class LCAttention(nn.Module):
145
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
146
+ super().__init__()
147
+ inner_dim = dim_head * heads
148
+ project_out = not (heads == 1 and dim_head == dim)
149
+
150
+ self.heads = heads
151
+ self.scale = dim_head ** -0.5
152
+
153
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
154
+
155
+ self.to_out = nn.Sequential(
156
+ nn.Linear(inner_dim, dim),
157
+ nn.Dropout(dropout)
158
+ ) if project_out else nn.Identity()
159
+
160
+ def forward(self, x):
161
+ b, n, _, h = *x.shape, self.heads
162
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
163
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
164
+ q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query
165
+
166
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
167
+
168
+ attn = dots.softmax(dim=-1)
169
+
170
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
171
+ out = rearrange(out, 'b h n d -> b n (h d)')
172
+ out = self.to_out(out)
173
+ return out
174
+
175
+ class SA_Transformer(nn.Module):
176
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
177
+ super().__init__()
178
+ self.layers = nn.ModuleList([])
179
+ self.norm = nn.LayerNorm(dim)
180
+ for _ in range(depth):
181
+ self.layers.append(nn.ModuleList([
182
+ SA_PreNorm(dim, SA_Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
183
+ SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim, dropout = dropout))
184
+ ]))
185
+
186
+ def forward(self, x):
187
+ for attn, ff in self.layers:
188
+ x = attn(x) + x
189
+ x = ff(x) + x
190
+ return self.norm(x)
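A small usage sketch of the temporal self-attention stack; the dimensions below are illustrative only:

```python
import torch

temporal_transformer = SA_Transformer(
    dim=192, depth=4, heads=3, dim_head=64, mlp_dim=384, dropout=0.0
)
frames = torch.randn(2, 16, 192)      # Batch x Frames x Dim
out = temporal_transformer(frames)    # same shape: (2, 16, 192)
```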
stable_audio_tools/models/transformer.py ADDED
@@ -0,0 +1,812 @@
1
+ from functools import reduce, partial
2
+ from packaging import version
3
+
4
+ from einops import rearrange, repeat
5
+ from einops.layers.torch import Rearrange
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn, einsum
9
+ from torch.cuda.amp import autocast
10
+ from typing import Callable, Literal
11
+ import warnings
12
+ warnings.simplefilter(action='ignore', category=FutureWarning)
13
+
14
+ try:
15
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
16
+ except ImportError as e:
17
+ print(e)
18
+ print('flash_attn not installed, disabling Flash Attention')
19
+ flash_attn_kvpacked_func = None
20
+ flash_attn_func = None
21
+
22
+ try:
23
+ import natten
24
+ except ImportError:
25
+ natten = None
26
+
27
+ def checkpoint(function, *args, **kwargs):
28
+ kwargs.setdefault("use_reentrant", False)
29
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
30
+
31
+
32
+ # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
33
+ # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
34
+
35
+ def create_causal_mask(i, j, device):
36
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
37
+
38
+ def or_reduce(masks):
39
+ head, *body = masks
40
+ for rest in body:
41
+ head = head | rest
42
+ return head
43
+
44
+ # positional embeddings
45
+
46
+ class AbsolutePositionalEmbedding(nn.Module):
47
+ def __init__(self, dim, max_seq_len):
48
+ super().__init__()
49
+ self.scale = dim ** -0.5
50
+ self.max_seq_len = max_seq_len
51
+ self.emb = nn.Embedding(max_seq_len, dim)
52
+
53
+ def forward(self, x, pos = None, seq_start_pos = None):
54
+ seq_len, device = x.shape[1], x.device
55
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
56
+
57
+ if pos is None:
58
+ pos = torch.arange(seq_len, device = device)
59
+
60
+ if seq_start_pos is not None:
61
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
62
+
63
+ pos_emb = self.emb(pos)
64
+ pos_emb = pos_emb * self.scale
65
+ return pos_emb
66
+
67
+ class ScaledSinusoidalEmbedding(nn.Module):
68
+ def __init__(self, dim, theta = 10000):
69
+ super().__init__()
70
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
71
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
72
+
73
+ half_dim = dim // 2
74
+ freq_seq = torch.arange(half_dim).float() / half_dim
75
+ inv_freq = theta ** -freq_seq
76
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
77
+
78
+ def forward(self, x, pos = None, seq_start_pos = None):
79
+ seq_len, device = x.shape[1], x.device
80
+
81
+ if pos is None:
82
+ pos = torch.arange(seq_len, device = device)
83
+
84
+ if seq_start_pos is not None:
85
+ pos = pos - seq_start_pos[..., None]
86
+
87
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
88
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
89
+ return emb * self.scale
90
+
91
+ class RotaryEmbedding(nn.Module):
92
+ def __init__(
93
+ self,
94
+ dim,
95
+ use_xpos = False,
96
+ scale_base = 512,
97
+ interpolation_factor = 1.,
98
+ base = 10000,
99
+ base_rescale_factor = 1.
100
+ ):
101
+ super().__init__()
102
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
103
+ # has some connection to NTK literature
104
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
105
+ base *= base_rescale_factor ** (dim / (dim - 2))
106
+
107
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
108
+ self.register_buffer('inv_freq', inv_freq)
109
+
110
+ assert interpolation_factor >= 1.
111
+ self.interpolation_factor = interpolation_factor
112
+
113
+ if not use_xpos:
114
+ self.register_buffer('scale', None)
115
+ return
116
+
117
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
118
+
119
+ self.scale_base = scale_base
120
+ self.register_buffer('scale', scale)
121
+
122
+ def forward_from_seq_len(self, seq_len):
123
+ device = self.inv_freq.device
124
+
125
+ t = torch.arange(seq_len, device = device)
126
+ return self.forward(t)
127
+
128
+ @autocast(enabled = False)
129
+ def forward(self, t):
130
+ device = self.inv_freq.device
131
+
132
+ t = t.to(torch.float32)
133
+
134
+ t = t / self.interpolation_factor
135
+
136
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
137
+ freqs = torch.cat((freqs, freqs), dim = -1)
138
+
139
+ if self.scale is None:
140
+ return freqs, 1.
141
+
142
+ seq_len = t.shape[-1] # seq_len was previously undefined on this xpos-only branch; infer it from t
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
143
+ scale = self.scale ** rearrange(power, 'n -> n 1')
144
+ scale = torch.cat((scale, scale), dim = -1)
145
+
146
+ return freqs, scale
147
+
148
+ def rotate_half(x):
149
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
150
+ x1, x2 = x.unbind(dim = -2)
151
+ return torch.cat((-x2, x1), dim = -1)
152
+
153
+ @autocast(enabled = False)
154
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
155
+ out_dtype = t.dtype
156
+
157
+ # cast to float32 if necessary for numerical stability
158
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
159
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
160
+ freqs, t = freqs.to(dtype), t.to(dtype)
161
+ freqs = freqs[-seq_len:, :]
162
+
163
+ if t.ndim == 4 and freqs.ndim == 3:
164
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
165
+
166
+ # partial rotary embeddings, Wang et al. GPT-J
167
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
168
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
169
+
170
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
171
+
172
+ return torch.cat((t, t_unrotated), dim = -1)
173
+
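+ # Illustrative sketch (hypothetical helper, never called): partial rotary embeddings as
+ # ContinuousTransformer below applies them, with RotaryEmbedding(max(dim_heads // 2, 32)).
+ # Only the first rot_dim channels of each head are rotated; the rest pass through untouched.
+ def _rotary_usage_sketch():
+     rope = RotaryEmbedding(32)
+     freqs, _ = rope.forward_from_seq_len(8)       # (8, 32) rotation angles
+     q = torch.randn(1, 4, 8, 64)                  # (batch, heads, seq, dim_head)
+     q_rot = apply_rotary_pos_emb(q, freqs)        # (1, 4, 8, 64); first 32 dims rotated
+     return q_rot.shape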
174
+ # norms
175
+ class LayerNorm(nn.Module):
176
+ def __init__(self, dim, bias=False, fix_scale=False):
177
+ """
178
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
179
+ """
180
+ super().__init__()
181
+
182
+ if fix_scale:
183
+ self.register_buffer("gamma", torch.ones(dim))
184
+ else:
185
+ self.gamma = nn.Parameter(torch.ones(dim))
186
+
187
+ if bias:
188
+ self.beta = nn.Parameter(torch.zeros(dim))
189
+ else:
190
+ self.register_buffer("beta", torch.zeros(dim))
191
+
192
+
193
+ def forward(self, x):
194
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
195
+
196
+ # feedforward
197
+
198
+ class GLU(nn.Module):
199
+ def __init__(
200
+ self,
201
+ dim_in,
202
+ dim_out,
203
+ activation: Callable,
204
+ use_conv = False,
205
+ conv_kernel_size = 3,
206
+ ):
207
+ super().__init__()
208
+ self.act = activation
209
+ self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
210
+ self.use_conv = use_conv
211
+
212
+ def forward(self, x):
213
+ if self.use_conv:
214
+ x = rearrange(x, 'b n d -> b d n')
215
+ x = self.proj(x)
216
+ x = rearrange(x, 'b d n -> b n d')
217
+ else:
218
+ x = self.proj(x)
219
+
220
+ x, gate = x.chunk(2, dim = -1)
221
+ return x * self.act(gate)
222
+
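+ # Illustrative sketch (hypothetical helper, never called): the GLU above is the gated half
+ # of a SwiGLU-style feedforward; the projection doubles the width, then one half gates the
+ # other through the activation.
+ def _glu_usage_sketch():
+     glu = GLU(dim_in = 512, dim_out = 1024, activation = nn.SiLU())
+     x = torch.randn(2, 16, 512)                   # (batch, seq, dim_in)
+     y = glu(x)                                    # (2, 16, 1024): proj(x).chunk(2) -> x * SiLU(gate)
+     return y.shape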
223
+ class FeedForward(nn.Module):
224
+ def __init__(
225
+ self,
226
+ dim,
227
+ dim_out = None,
228
+ mult = 4,
229
+ no_bias = False,
230
+ glu = True,
231
+ use_conv = False,
232
+ conv_kernel_size = 3,
233
+ zero_init_output = True,
234
+ ):
235
+ super().__init__()
236
+ inner_dim = int(dim * mult)
237
+
238
+ # Default to SwiGLU
239
+
240
+ activation = nn.SiLU()
241
+
242
+ dim_out = dim if dim_out is None else dim_out
243
+
244
+ if glu:
245
+ linear_in = GLU(dim, inner_dim, activation)
246
+ else:
247
+ linear_in = nn.Sequential(
248
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
249
+ nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
250
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
251
+ activation
252
+ )
253
+
254
+ linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
255
+
256
+ # init last linear layer to 0
257
+ if zero_init_output:
258
+ nn.init.zeros_(linear_out.weight)
259
+ if not no_bias:
260
+ nn.init.zeros_(linear_out.bias)
261
+
262
+
263
+ self.ff = nn.Sequential(
264
+ linear_in,
265
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
266
+ linear_out,
267
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
268
+ )
269
+
270
+ def forward(self, x):
271
+ return self.ff(x)
272
+
273
+ class Attention(nn.Module):
274
+ def __init__(
275
+ self,
276
+ dim,
277
+ dim_heads = 64,
278
+ dim_context = None,
279
+ causal = False,
280
+ zero_init_output=True,
281
+ qk_norm: Literal['l2', 'ln', 'none'] = 'none',
282
+ natten_kernel_size = None
283
+ ):
284
+ super().__init__()
285
+ self.dim = dim
286
+ self.dim_heads = dim_heads
287
+ self.causal = causal
288
+
289
+ dim_kv = dim_context if dim_context is not None else dim
290
+
291
+ self.num_heads = dim // dim_heads
292
+ self.kv_heads = dim_kv // dim_heads
293
+
294
+ if dim_context is not None:
295
+ self.to_q = nn.Linear(dim, dim, bias=False)
296
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
297
+ else:
298
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
299
+
300
+ self.to_out = nn.Linear(dim, dim, bias=False)
301
+
302
+ if zero_init_output:
303
+ nn.init.zeros_(self.to_out.weight)
304
+
305
+ self.qk_norm = qk_norm
306
+
307
+ if self.qk_norm == "ln":
308
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
309
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
310
+
311
+ # Using 1d neighborhood attention
312
+ self.natten_kernel_size = natten_kernel_size
313
+ if natten_kernel_size is not None:
314
+ return
315
+
316
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
317
+
318
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
319
+
320
+ self.sdp_kwargs = dict(
321
+ enable_flash = True,
322
+ enable_math = True,
323
+ enable_mem_efficient = True
324
+ )
325
+
326
+ def flash_attn(
327
+ self,
328
+ q,
329
+ k,
330
+ v,
331
+ mask = None,
332
+ causal = None
333
+ ):
334
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
335
+ kv_heads = k.shape[1]
336
+ # Recommended for multi-query single-key-value attention by Tri Dao
337
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
338
+
339
+ if heads != kv_heads:
340
+ # Repeat interleave kv_heads to match q_heads
341
+ heads_per_kv_head = heads // kv_heads
342
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
343
+
344
+ if k.ndim == 3:
345
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
346
+
347
+ if v.ndim == 3:
348
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
349
+
350
+ causal = self.causal if causal is None else causal
351
+
352
+ if q_len == 1 and causal:
353
+ causal = False
354
+
355
+ if mask is not None:
356
+ assert mask.ndim == 4
357
+ mask = mask.expand(batch, heads, q_len, k_len)
358
+
359
+ # handle kv cache - this should be bypassable in updated flash attention 2
360
+
361
+ if k_len > q_len and causal:
362
+ causal_mask = create_causal_mask(q_len, k_len, device = device) # module-level helper; Attention defines no create_causal_mask attribute
363
+ if mask is None:
364
+ mask = ~causal_mask
365
+ else:
366
+ mask = mask & ~causal_mask
367
+ causal = False
368
+
369
+ # manually handle causal mask, if another mask was given
370
+
371
+ row_is_entirely_masked = None
372
+
373
+ if mask is not None and causal:
374
+ causal_mask = create_causal_mask(q_len, k_len, device = device)
375
+ mask = mask & ~causal_mask
376
+
377
+ # protect against an entire row being masked out
378
+
379
+ row_is_entirely_masked = ~mask.any(dim = -1)
380
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
381
+
382
+ causal = False
383
+
384
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
385
+ out = F.scaled_dot_product_attention(
386
+ q, k, v,
387
+ attn_mask = mask,
388
+ is_causal = causal
389
+ )
390
+
391
+ # for a row that is entirely masked out, should zero out the output of that row token
392
+
393
+ if row_is_entirely_masked is not None:
394
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
395
+
396
+ return out
397
+
398
+ def forward(
399
+ self,
400
+ x,
401
+ context = None,
402
+ mask = None,
403
+ context_mask = None,
404
+ rotary_pos_emb = None,
405
+ causal = None
406
+ ):
407
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
408
+
409
+ kv_input = context if has_context else x
410
+
411
+ if hasattr(self, 'to_q'):
412
+ # Use separate linear projections for q and k/v
413
+ q = self.to_q(x)
414
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h) # [B, 24, 1025, 64]
415
+
416
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
417
+
418
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
419
+ else:
420
+ # Use fused linear projection
421
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
422
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
423
+
424
+ # Normalize q and k for cosine sim attention
425
+ if self.qk_norm == "l2":
426
+ q = F.normalize(q, dim=-1)
427
+ k = F.normalize(k, dim=-1)
428
+ elif self.qk_norm == "ln":
429
+ q = self.q_norm(q)
430
+ k = self.k_norm(k)
431
+
432
+ if rotary_pos_emb is not None and not has_context:
433
+ freqs, _ = rotary_pos_emb
434
+
435
+ q_dtype = q.dtype
436
+ k_dtype = k.dtype
437
+
438
+ q = q.to(torch.float32)
439
+ k = k.to(torch.float32)
440
+ freqs = freqs.to(torch.float32)
441
+
442
+ q = apply_rotary_pos_emb(q, freqs)
443
+ k = apply_rotary_pos_emb(k, freqs)
444
+
445
+ q = q.to(q_dtype)
446
+ k = k.to(k_dtype)
447
+
448
+ input_mask = context_mask
449
+
450
+ if input_mask is None and not has_context:
451
+ input_mask = mask
452
+
453
+ # determine masking
454
+ masks = []
455
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
456
+
457
+ if input_mask is not None:
458
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
459
+ masks.append(~input_mask)
460
+
461
+ # Other masks will be added here later
462
+
463
+ if len(masks) > 0:
464
+ final_attn_mask = ~or_reduce(masks)
465
+
466
+ n, device = q.shape[-2], q.device
467
+
468
+ causal = self.causal if causal is None else causal
469
+
470
+ if n == 1 and causal:
471
+ causal = False
472
+
473
+ if self.natten_kernel_size is not None:
474
+ if natten is None:
475
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
476
+
477
+ dtype_in = q.dtype
478
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
479
+
480
+ attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
481
+
482
+ if final_attn_mask is not None:
483
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
484
+
485
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
486
+
487
+ out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
488
+
489
+ # Prioritize Flash Attention 2
490
+ elif self.use_fa_flash:
491
+ assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
492
+ # Flash Attention 2 requires FP16 inputs
493
+ fa_dtype_in = q.dtype
494
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
495
+
496
+ out = flash_attn_func(q, k, v, causal = causal)
497
+
498
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
499
+
500
+ # Fall back to PyTorch implementation
501
+ elif self.use_pt_flash:
502
+ out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
503
+
504
+ else:
505
+ # Fall back to custom implementation
506
+
507
+ if h != kv_h:
508
+ # Repeat interleave kv_heads to match q_heads
509
+ heads_per_kv_head = h // kv_h
510
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
511
+
512
+ scale = 1. / (q.shape[-1] ** 0.5)
513
+
514
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
515
+
516
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
517
+
518
+ i, j, dtype = *dots.shape[-2:], dots.dtype
519
+
520
+ mask_value = -torch.finfo(dots.dtype).max
521
+
522
+ if final_attn_mask is not None:
523
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
524
+
525
+ if causal:
526
+ causal_mask = create_causal_mask(i, j, device = device)
527
+ dots = dots.masked_fill(causal_mask, mask_value)
528
+
529
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
530
+ attn = attn.type(dtype)
531
+
532
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
533
+
534
+ # merge heads
535
+ out = rearrange(out, ' b h n d -> b n (h d)')
536
+
537
+ # Communicate between heads
538
+ out = self.to_out(out)
539
+
540
+ if mask is not None:
541
+ mask = rearrange(mask, 'b n -> b n 1')
542
+ out = out.masked_fill(~mask, 0.)
543
+
544
+ return out
545
+
546
+
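+ # Illustrative sketch (hypothetical dimensions, never called): cross-attention between a
+ # 1024-dim query sequence and a 512-dim conditioning sequence, the same wiring that
+ # TransformerBlock below sets up when cross_attend=True.
+ def _cross_attention_usage_sketch():
+     attn = Attention(dim = 1024, dim_heads = 64, dim_context = 512)   # 16 query heads, 8 kv heads
+     x = torch.randn(2, 100, 1024)                  # (batch, seq, dim)
+     context = torch.randn(2, 12, 512)              # (batch, cond_seq, dim_context)
+     context_mask = torch.ones(2, 12, dtype = torch.bool)
+     out = attn(x, context = context, context_mask = context_mask)     # (2, 100, 1024)
+     return out.shape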
547
+ class ConformerModule(nn.Module):
548
+ def __init__(
549
+ self,
550
+ dim,
551
+ norm_kwargs = {},
552
+ ):
553
+
554
+ super().__init__()
555
+
556
+ self.dim = dim
557
+
558
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
559
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
560
+ self.glu = GLU(dim, dim, nn.SiLU())
561
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
562
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
563
+ self.swish = nn.SiLU()
564
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
565
+
566
+ def forward(self, x):
567
+ x = self.in_norm(x)
568
+ x = rearrange(x, 'b n d -> b d n')
569
+ x = self.pointwise_conv(x)
570
+ x = rearrange(x, 'b d n -> b n d')
571
+ x = self.glu(x)
572
+ x = rearrange(x, 'b n d -> b d n')
573
+ x = self.depthwise_conv(x)
574
+ x = rearrange(x, 'b d n -> b n d')
575
+ x = self.mid_norm(x)
576
+ x = self.swish(x)
577
+ x = rearrange(x, 'b n d -> b d n')
578
+ x = self.pointwise_conv_2(x)
579
+ x = rearrange(x, 'b d n -> b n d')
580
+
581
+ return x
582
+
583
+ class TransformerBlock(nn.Module):
584
+ def __init__(
585
+ self,
586
+ dim,
587
+ dim_heads = 64,
588
+ cross_attend = False,
589
+ dim_context = None,
590
+ global_cond_dim = None,
591
+ causal = False,
592
+ zero_init_branch_outputs = True,
593
+ conformer = False,
594
+ layer_ix = -1,
595
+ remove_norms = False,
596
+ attn_kwargs = {},
597
+ ff_kwargs = {},
598
+ norm_kwargs = {}
599
+ ):
600
+
601
+ super().__init__()
602
+ self.dim = dim
603
+ self.dim_heads = dim_heads
604
+ self.cross_attend = cross_attend
605
+ self.dim_context = dim_context
606
+ self.causal = causal
607
+
608
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
609
+
610
+ self.self_attn = Attention(
611
+ dim,
612
+ dim_heads = dim_heads,
613
+ causal = causal,
614
+ zero_init_output=zero_init_branch_outputs,
615
+ **attn_kwargs
616
+ )
617
+
618
+ if cross_attend:
619
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
620
+ self.cross_attn = Attention(
621
+ dim,
622
+ dim_heads = dim_heads,
623
+ dim_context=dim_context,
624
+ causal = causal,
625
+ zero_init_output=zero_init_branch_outputs,
626
+ **attn_kwargs
627
+ )
628
+
629
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
630
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
631
+
632
+ self.layer_ix = layer_ix
633
+
634
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
635
+
636
+ self.global_cond_dim = global_cond_dim
637
+
638
+ if global_cond_dim is not None:
639
+ self.to_scale_shift_gate = nn.Sequential(
640
+ nn.SiLU(),
641
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
642
+ )
643
+
644
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
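+ # Zero-initializing this projection makes the conditioning start out close to a no-op:
+ # scale and shift begin at 0, so the modulation is the identity, and the sigmoid(1 - gate)
+ # factor in forward() starts at sigmoid(1), a constant per-branch scale, before training adjusts it.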
645
+
646
+ def forward(
647
+ self,
648
+ x,
649
+ context = None,
650
+ global_cond=None,
651
+ mask = None,
652
+ context_mask = None,
653
+ rotary_pos_emb = None,
654
+ adapter=None
655
+ ):
656
+
657
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: # False
658
+
659
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
660
+
661
+ # self-attention with adaLN
662
+ residual = x
663
+ x = self.pre_norm(x)
664
+ x = x * (1 + scale_self) + shift_self
665
+
666
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
667
+ x = x * torch.sigmoid(1 - gate_self)
668
+ x = x + residual
669
+
670
+ if context is not None:
671
+
672
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
673
+
674
+ if self.conformer is not None:
675
+ x = x + self.conformer(x)
676
+
677
+ # feedforward with adaLN
678
+ residual = x
679
+ x = self.ff_norm(x)
680
+ x = x * (1 + scale_ff) + shift_ff
681
+ x = self.ff(x)
682
+ x = x * torch.sigmoid(1 - gate_ff)
683
+ x = x + residual
684
+
685
+ else:
686
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
687
+
688
+ if context is not None:
689
+
690
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
691
+
692
+ if self.conformer is not None:
693
+ x = x + self.conformer(x)
694
+
695
+ x = x + self.ff(self.ff_norm(x))
696
+
697
+ return x
698
+
699
+ class ContinuousTransformer(nn.Module):
700
+ def __init__(
701
+ self,
702
+ dim,
703
+ depth,
704
+ *,
705
+ dim_in = None,
706
+ dim_out = None,
707
+ dim_heads = 64,
708
+ cross_attend=False,
709
+ cond_token_dim=None,
710
+ global_cond_dim=None,
711
+ causal=False,
712
+ rotary_pos_emb=True,
713
+ zero_init_branch_outputs=True,
714
+ conformer=False,
715
+ use_sinusoidal_emb=False,
716
+ use_abs_pos_emb=False,
717
+ abs_pos_emb_max_length=10000,
718
+ **kwargs
719
+ ):
720
+
721
+ super().__init__()
722
+
723
+ self.dim = dim
724
+ self.depth = depth
725
+ self.causal = causal
726
+ self.layers = nn.ModuleList([])
727
+
728
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
729
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
730
+
731
+ if rotary_pos_emb:
732
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
733
+ else:
734
+ self.rotary_pos_emb = None
735
+
736
+ self.use_sinusoidal_emb = use_sinusoidal_emb
737
+ if use_sinusoidal_emb:
738
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
739
+
740
+ self.use_abs_pos_emb = use_abs_pos_emb
741
+ if use_abs_pos_emb:
742
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
743
+
744
+ for i in range(depth):
745
+ self.layers.append(
746
+ TransformerBlock(
747
+ dim,
748
+ dim_heads = dim_heads,
749
+ cross_attend = cross_attend,
750
+ dim_context = cond_token_dim,
751
+ global_cond_dim = global_cond_dim,
752
+ causal = causal,
753
+ zero_init_branch_outputs = zero_init_branch_outputs,
754
+ conformer=conformer,
755
+ layer_ix=i,
756
+ **kwargs
757
+ )
758
+ )
759
+
760
+ def forward(
761
+ self,
762
+ x,
763
+ mask = None,
764
+ prepend_embeds = None,
765
+ prepend_mask = None,
766
+ global_cond = None,
767
+ return_info = False,
768
+ **kwargs
769
+ ):
770
+ batch, seq, device = *x.shape[:2], x.device
771
+
772
+ info = {
773
+ "hidden_states": [],
774
+ }
775
+
776
+ x = self.project_in(x)
777
+
778
+ if prepend_embeds is not None:
779
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
780
+
781
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
782
+
783
+ x = torch.cat((prepend_embeds, x), dim = -2)
784
+
785
+ if prepend_mask is not None or mask is not None:
786
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
787
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
788
+
789
+ mask = torch.cat((prepend_mask, mask), dim = -1)
790
+
791
+ # Attention layers
792
+ if self.rotary_pos_emb is not None:
793
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
794
+ else:
795
+ rotary_pos_emb = None
796
+
797
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
798
+ x = x + self.pos_emb(x)
799
+
800
+ # Iterate over the transformer layers
801
+ for index, layer in enumerate(self.layers):
802
+ x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
803
+
804
+ if return_info:
805
+ info["hidden_states"].append(x)
806
+
807
+ x = self.project_out(x)
808
+
809
+ if return_info:
810
+ return x, info
811
+
812
+ return x
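+ # Illustrative sketch (hypothetical dimensions, never called): a minimal forward pass
+ # through ContinuousTransformer; the values below are chosen only for the example.
+ def _continuous_transformer_usage_sketch():
+     model = ContinuousTransformer(dim = 256, depth = 2, dim_in = 64, dim_out = 64, dim_heads = 64)
+     x = torch.randn(1, 128, 64)                   # (batch, seq, dim_in)
+     y = model(x)                                  # (1, 128, 64): project_in -> blocks -> project_out
+     return y.shape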
stable_audio_tools/models/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from safetensors.torch import load_file
3
+
4
+ from torch.nn.utils import remove_weight_norm
5
+ import warnings
6
+ warnings.simplefilter(action='ignore', category=FutureWarning)
7
+
8
+
9
+ def load_ckpt_state_dict(ckpt_path):
10
+ if ckpt_path.endswith(".safetensors"):
11
+ state_dict = load_file(ckpt_path)
12
+ else:
13
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
14
+
15
+ return state_dict
16
+
17
+ def remove_weight_norm_from_model(model):
18
+ for module in model.modules():
19
+ if hasattr(module, "weight"):
20
+ print(f"Removing weight norm from {module}")
21
+ remove_weight_norm(module)
22
+
23
+ return model
24
+
25
+ # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
26
+ # License can be found in LICENSES/LICENSE_META.txt
27
+
28
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
29
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
30
+
31
+ Args:
32
+ input (torch.Tensor): The input tensor containing probabilities.
33
+ num_samples (int): Number of samples to draw.
34
+ replacement (bool): Whether to draw with replacement or not.
35
+ Keywords args:
36
+ generator (torch.Generator): A pseudorandom number generator for sampling.
37
+ Returns:
38
+ torch.Tensor: Last dimension contains num_samples indices
39
+ sampled from the multinomial probability distribution
40
+ located in the last dimension of tensor input.
41
+ """
42
+
43
+ if num_samples == 1:
44
+ q = torch.empty_like(input).exponential_(1, generator=generator)
45
+ return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
46
+
47
+ input_ = input.reshape(-1, input.shape[-1])
48
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
49
+ output = output_.reshape(*list(input.shape[:-1]), -1)
50
+ return output
51
+
52
+
53
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
54
+ """Sample next token from top K values along the last dimension of the input probs tensor.
55
+
56
+ Args:
57
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
58
+ k (int): The k in “top-k”.
59
+ Returns:
60
+ torch.Tensor: Sampled tokens.
61
+ """
62
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
63
+ min_value_top_k = top_k_value[..., [-1]]
64
+ probs *= (probs >= min_value_top_k).float()
65
+ probs.div_(probs.sum(dim=-1, keepdim=True))
66
+ next_token = multinomial(probs, num_samples=1)
67
+ return next_token
68
+
69
+
70
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
71
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
72
+
73
+ Args:
74
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
75
+ p (float): The p in “top-p”.
76
+ Returns:
77
+ torch.Tensor: Sampled tokens.
78
+ """
79
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
80
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
81
+ mask = probs_sum - probs_sort > p
82
+ probs_sort *= (~mask).float()
83
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
84
+ next_token = multinomial(probs_sort, num_samples=1)
85
+ next_token = torch.gather(probs_idx, -1, next_token)
86
+ return next_token
87
+
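+ # Illustrative sketch (hypothetical helper, never called): top-k / top-p sampling over a
+ # toy categorical distribution. Both helpers renormalize `probs` in place, hence the clones.
+ def _sampling_usage_sketch():
+     probs = torch.softmax(torch.randn(2, 4, 100), dim=-1)   # (batch, seq, vocab)
+     tok_k = sample_top_k(probs.clone(), k=50)                # (2, 4, 1) sampled indices
+     tok_p = sample_top_p(probs.clone(), p=0.9)               # (2, 4, 1) sampled indices
+     return tok_k, tok_p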
88
+ def next_power_of_two(n):
89
+ return 2 ** (n - 1).bit_length()
90
+
91
+ def next_multiple_of_64(n):
92
+ return ((n + 63) // 64) * 64
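+ # Worked examples for the padding helpers above:
+ # next_power_of_two(300) == 512 (since (300 - 1).bit_length() == 9)
+ # next_multiple_of_64(300) == 320 (since (300 + 63) // 64 == 5)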
stable_audio_tools/models/wavelets.py ADDED
@@ -0,0 +1,82 @@
1
+ """The 1D discrete wavelet transform for PyTorch."""
2
+
3
+ from einops import rearrange
4
+ import pywt
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from typing import Literal
9
+
10
+
11
+ def get_filter_bank(wavelet):
12
+ filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
13
+ if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0):
14
+ filt = filt[:, 1:]
15
+ return filt
16
+
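+ # Note (descriptive, based on pywt conventions): filter_bank returns
+ # (dec_lo, dec_hi, rec_lo, rec_hi), so `filt` has shape (4, taps). The encoder below uses
+ # rows 0-1 (analysis filters) and the decoder uses rows 2-3 (synthesis filters); for the
+ # bior wavelets the leading zero tap is dropped so the tap count is odd, as the asserts require.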
17
+ class WaveletEncode1d(nn.Module):
18
+ def __init__(self,
19
+ channels,
20
+ levels,
21
+ wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
22
+ super().__init__()
23
+ self.wavelet = wavelet
24
+ self.channels = channels
25
+ self.levels = levels
26
+ filt = get_filter_bank(wavelet)
27
+ assert filt.shape[-1] % 2 == 1
28
+ kernel = filt[:2, None]
29
+ kernel = torch.flip(kernel, dims=(-1,))
30
+ index_i = torch.repeat_interleave(torch.arange(2), channels)
31
+ index_j = torch.tile(torch.arange(channels), (2,))
32
+ kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
33
+ kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
34
+ self.register_buffer("kernel", kernel_final)
35
+
36
+ def forward(self, x):
37
+ for i in range(self.levels):
38
+ low, rest = x[:, : self.channels], x[:, self.channels :]
39
+ pad = self.kernel.shape[-1] // 2
40
+ low = F.pad(low, (pad, pad), "reflect")
41
+ low = F.conv1d(low, self.kernel, stride=2)
42
+ rest = rearrange(
43
+ rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
44
+ )
45
+ x = torch.cat([low, rest], dim=1)
46
+ return x
47
+
48
+
49
+ class WaveletDecode1d(nn.Module):
50
+ def __init__(self,
51
+ channels,
52
+ levels,
53
+ wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
54
+ super().__init__()
55
+ self.wavelet = wavelet
56
+ self.channels = channels
57
+ self.levels = levels
58
+ filt = get_filter_bank(wavelet)
59
+ assert filt.shape[-1] % 2 == 1
60
+ kernel = filt[2:, None]
61
+ index_i = torch.repeat_interleave(torch.arange(2), channels)
62
+ index_j = torch.tile(torch.arange(channels), (2,))
63
+ kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
64
+ kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
65
+ self.register_buffer("kernel", kernel_final)
66
+
67
+ def forward(self, x):
68
+ for i in range(self.levels):
69
+ low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :]
70
+ pad = self.kernel.shape[-1] // 2 + 2
71
+ low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2)
72
+ low = F.pad(low, (pad, pad), "reflect")
73
+ low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2)
74
+ low = F.conv_transpose1d(
75
+ low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2
76
+ )
77
+ low = low[..., pad - 1 : -pad]
78
+ rest = rearrange(
79
+ rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
80
+ )
81
+ x = torch.cat([low, rest], dim=1)
82
+ return x
stable_audio_tools/training/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_training_wrapper_from_config, create_demo_callback_from_config
stable_audio_tools/training/autoencoders.py ADDED
@@ -0,0 +1,476 @@
1
+ import torch
2
+ import torchaudio
3
+ import wandb
4
+ from einops import rearrange
5
+ from safetensors.torch import save_file, save_model
6
+ from ema_pytorch import EMA
7
+ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss
8
+ import pytorch_lightning as pl
9
+ from ..models.autoencoders import AudioAutoencoder
10
+ from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
11
+ from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
12
+ from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
13
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
14
+
15
+
16
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
17
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
18
+
19
+ class AutoencoderTrainingWrapper(pl.LightningModule):
20
+ def __init__(
21
+ self,
22
+ autoencoder: AudioAutoencoder,
23
+ lr: float = 1e-4,
24
+ warmup_steps: int = 0,
25
+ encoder_freeze_on_warmup: bool = False,
26
+ sample_rate=48000,
27
+ loss_config: dict = None,
28
+ optimizer_configs: dict = None,
29
+ use_ema: bool = True,
30
+ ema_copy = None,
31
+ force_input_mono = False,
32
+ latent_mask_ratio = 0.0,
33
+ teacher_model: AudioAutoencoder = None
34
+ ):
35
+ super().__init__()
36
+
37
+ self.automatic_optimization = False
38
+
39
+ self.autoencoder = autoencoder
40
+
41
+ self.warmed_up = False
42
+ self.warmup_steps = warmup_steps
43
+ self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
44
+ self.lr = lr
45
+
46
+ self.force_input_mono = force_input_mono
47
+
48
+ self.teacher_model = teacher_model
49
+
50
+ if optimizer_configs is None:
51
+ optimizer_configs ={
52
+ "autoencoder": {
53
+ "optimizer": {
54
+ "type": "AdamW",
55
+ "config": {
56
+ "lr": lr,
57
+ "betas": (.8, .99)
58
+ }
59
+ }
60
+ },
61
+ "discriminator": {
62
+ "optimizer": {
63
+ "type": "AdamW",
64
+ "config": {
65
+ "lr": lr,
66
+ "betas": (.8, .99)
67
+ }
68
+ }
69
+ }
70
+
71
+ }
72
+
73
+ self.optimizer_configs = optimizer_configs
74
+
75
+ if loss_config is None:
76
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
77
+ hop_sizes = []
78
+ win_lengths = []
79
+ overlap = 0.75
80
+ for s in scales:
81
+ hop_sizes.append(int(s * (1 - overlap)))
82
+ win_lengths.append(s)
83
+
84
+ loss_config = {
85
+ "discriminator": {
86
+ "type": "encodec",
87
+ "config": {
88
+ "n_ffts": scales,
89
+ "hop_lengths": hop_sizes,
90
+ "win_lengths": win_lengths,
91
+ "filters": 32
92
+ },
93
+ "weights": {
94
+ "adversarial": 0.1,
95
+ "feature_matching": 5.0,
96
+ }
97
+ },
98
+ "spectral": {
99
+ "type": "mrstft",
100
+ "config": {
101
+ "fft_sizes": scales,
102
+ "hop_sizes": hop_sizes,
103
+ "win_lengths": win_lengths,
104
+ "perceptual_weighting": True
105
+ },
106
+ "weights": {
107
+ "mrstft": 1.0,
108
+ }
109
+ },
110
+ "time": {
111
+ "type": "l1",
112
+ "config": {},
113
+ "weights": {
114
+ "l1": 0.0,
115
+ }
116
+ }
117
+ }
118
+
119
+ self.loss_config = loss_config
120
+
121
+ # Spectral reconstruction loss
122
+
123
+ stft_loss_args = loss_config['spectral']['config']
124
+
125
+ if self.autoencoder.out_channels == 2:
126
+ self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
127
+ self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
128
+ else:
129
+ self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
130
+
131
+ # Discriminator
132
+
133
+ if loss_config['discriminator']['type'] == 'oobleck':
134
+ self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config'])
135
+ elif loss_config['discriminator']['type'] == 'encodec':
136
+ self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config'])
137
+ elif loss_config['discriminator']['type'] == 'dac':
138
+ self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config'])
139
+
140
+ self.gen_loss_modules = []
141
+
142
+ # Adversarial and feature matching losses
143
+ self.gen_loss_modules += [
144
+ ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'),
145
+ ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'),
146
+ ]
147
+
148
+ if self.teacher_model is not None:
149
+ # Distillation losses
150
+
151
+ stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25
152
+ self.gen_loss_modules += [
153
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss
154
+ AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder
155
+ AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder
156
+ AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder
157
+ ]
158
+
159
+ else:
160
+
161
+ # Reconstruction loss
162
+ self.gen_loss_modules += [
163
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
164
+ ]
165
+
166
+ if self.autoencoder.out_channels == 2:
167
+
168
+ # Add left and right channel reconstruction losses in addition to the sum and difference
169
+ self.gen_loss_modules += [
170
+ AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2),
171
+ AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2),
172
+ ]
173
+
174
+ self.gen_loss_modules += [
175
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
176
+ ]
177
+
178
+ if self.loss_config['time']['weights']['l1'] > 0.0:
179
+ self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss'))
180
+
181
+ if self.autoencoder.bottleneck is not None:
182
+ self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)
183
+
184
+ self.losses_gen = MultiLoss(self.gen_loss_modules)
185
+
186
+ self.disc_loss_modules = [
187
+ ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
188
+ ]
189
+
190
+ self.losses_disc = MultiLoss(self.disc_loss_modules)
191
+
192
+ # Set up EMA for model weights
193
+ self.autoencoder_ema = None
194
+
195
+ self.use_ema = use_ema
196
+
197
+ if self.use_ema:
198
+ self.autoencoder_ema = EMA(
199
+ self.autoencoder,
200
+ ema_model=ema_copy,
201
+ beta=0.9999,
202
+ power=3/4,
203
+ update_every=1,
204
+ update_after_step=1
205
+ )
206
+
207
+ self.latent_mask_ratio = latent_mask_ratio
208
+
209
+ def configure_optimizers(self):
210
+
211
+ opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
212
+ opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())
213
+
214
+ if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
215
+ sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
216
+ sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
217
+ return [opt_gen, opt_disc], [sched_gen, sched_disc]
218
+
219
+ return [opt_gen, opt_disc]
220
+
221
+ def training_step(self, batch, batch_idx):
222
+ reals, _ = batch
223
+
224
+ # Remove extra dimension added by WebDataset
225
+ if reals.ndim == 4 and reals.shape[0] == 1:
226
+ reals = reals[0]
227
+
228
+ if self.global_step >= self.warmup_steps:
229
+ self.warmed_up = True
230
+
231
+ loss_info = {}
232
+
233
+ loss_info["reals"] = reals
234
+
235
+ encoder_input = reals
236
+
237
+ if self.force_input_mono and encoder_input.shape[1] > 1:
238
+ encoder_input = encoder_input.mean(dim=1, keepdim=True)
239
+
240
+ loss_info["encoder_input"] = encoder_input
241
+
242
+ data_std = encoder_input.std()
243
+
244
+ if self.warmed_up and self.encoder_freeze_on_warmup:
245
+ with torch.no_grad():
246
+ latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
247
+ else:
248
+ latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
249
+
250
+ loss_info["latents"] = latents
251
+
252
+ loss_info.update(encoder_info)
253
+
254
+ # Encode with teacher model for distillation
255
+ if self.teacher_model is not None:
256
+ with torch.no_grad():
257
+ teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
258
+ loss_info['teacher_latents'] = teacher_latents
259
+
260
+ if self.latent_mask_ratio > 0.0:
261
+ mask = torch.rand_like(latents) < self.latent_mask_ratio
262
+ latents = torch.where(mask, torch.zeros_like(latents), latents)
263
+
264
+ decoded = self.autoencoder.decode(latents)
265
+
266
+ loss_info["decoded"] = decoded
267
+
268
+ if self.autoencoder.out_channels == 2:
269
+ loss_info["decoded_left"] = decoded[:, 0:1, :]
270
+ loss_info["decoded_right"] = decoded[:, 1:2, :]
271
+ loss_info["reals_left"] = reals[:, 0:1, :]
272
+ loss_info["reals_right"] = reals[:, 1:2, :]
273
+
274
+ # Distillation
275
+ if self.teacher_model is not None:
276
+ with torch.no_grad():
277
+ teacher_decoded = self.teacher_model.decode(teacher_latents)
278
+ own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher
279
+ teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model
280
+
281
+ loss_info['teacher_decoded'] = teacher_decoded
282
+ loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
283
+ loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded
284
+
285
+
286
+ if self.warmed_up:
287
+ loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
288
+ else:
289
+ loss_dis = torch.tensor(0.).to(reals)
290
+ loss_adv = torch.tensor(0.).to(reals)
291
+ feature_matching_distance = torch.tensor(0.).to(reals)
292
+
293
+ loss_info["loss_dis"] = loss_dis
294
+ loss_info["loss_adv"] = loss_adv
295
+ loss_info["feature_matching_distance"] = feature_matching_distance
296
+
297
+ opt_gen, opt_disc = self.optimizers()
298
+
299
+ lr_schedulers = self.lr_schedulers()
300
+
301
+ sched_gen = None
302
+ sched_disc = None
303
+
304
+ if lr_schedulers is not None:
305
+ sched_gen, sched_disc = lr_schedulers
306
+
307
+ # Train the discriminator
308
+ if self.global_step % 2 and self.warmed_up:
309
+ loss, losses = self.losses_disc(loss_info)
310
+
311
+ log_dict = {
312
+ 'train/disc_lr': opt_disc.param_groups[0]['lr']
313
+ }
314
+
315
+ opt_disc.zero_grad()
316
+ self.manual_backward(loss)
317
+ opt_disc.step()
318
+
319
+ if sched_disc is not None:
320
+ # sched step every step
321
+ sched_disc.step()
322
+
323
+ # Train the generator
324
+ else:
325
+
326
+ loss, losses = self.losses_gen(loss_info)
327
+
328
+ if self.use_ema:
329
+ self.autoencoder_ema.update()
330
+
331
+ opt_gen.zero_grad()
332
+ self.manual_backward(loss)
333
+ opt_gen.step()
334
+
335
+ if sched_gen is not None:
336
+ # scheduler step every step
337
+ sched_gen.step()
338
+
339
+ log_dict = {
340
+ 'train/loss': loss.detach(),
341
+ 'train/latent_std': latents.std().detach(),
342
+ 'train/data_std': data_std.detach(),
343
+ 'train/gen_lr': opt_gen.param_groups[0]['lr']
344
+ }
345
+
346
+ for loss_name, loss_value in losses.items():
347
+ log_dict[f'train/{loss_name}'] = loss_value.detach()
348
+
349
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
350
+
351
+ return loss
352
+
353
+ def export_model(self, path, use_safetensors=False):
354
+ if self.autoencoder_ema is not None:
355
+ model = self.autoencoder_ema.ema_model
356
+ else:
357
+ model = self.autoencoder
358
+
359
+ if use_safetensors:
360
+ save_model(model, path)
361
+ else:
362
+ torch.save({"state_dict": model.state_dict()}, path)
363
+
364
+
365
+ class AutoencoderDemoCallback(pl.Callback):
366
+ def __init__(
367
+ self,
368
+ demo_dl,
369
+ demo_every=2000,
370
+ sample_size=65536,
371
+ sample_rate=48000
372
+ ):
373
+ super().__init__()
374
+ self.demo_every = demo_every
375
+ self.demo_samples = sample_size
376
+ self.demo_dl = iter(demo_dl)
377
+ self.sample_rate = sample_rate
378
+ self.last_demo_step = -1
379
+
380
+ @rank_zero_only
381
+ @torch.no_grad()
382
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
383
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
384
+ return
385
+
386
+ self.last_demo_step = trainer.global_step
387
+
388
+ module.eval()
389
+
390
+ try:
391
+ demo_reals, _ = next(self.demo_dl)
392
+
393
+ # Remove extra dimension added by WebDataset
394
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
395
+ demo_reals = demo_reals[0]
396
+
397
+ encoder_input = demo_reals
398
+
399
+ encoder_input = encoder_input.to(module.device)
400
+
401
+ if module.force_input_mono:
402
+ encoder_input = encoder_input.mean(dim=1, keepdim=True)
403
+
404
+ demo_reals = demo_reals.to(module.device)
405
+
406
+ with torch.no_grad():
407
+ if module.use_ema:
408
+
409
+ latents = module.autoencoder_ema.ema_model.encode(encoder_input)
410
+
411
+ fakes = module.autoencoder_ema.ema_model.decode(latents)
412
+ else:
413
+ latents = module.autoencoder.encode(encoder_input)
414
+
415
+ fakes = module.autoencoder.decode(latents)
416
+
417
+ #Interleave reals and fakes
418
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
419
+
420
+ # Put the demos together
421
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
422
+
423
+ log_dict = {}
424
+
425
+ filename = f'recon_{trainer.global_step:08}.wav'
426
+ reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
427
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
428
+
429
+ log_dict[f'recon'] = wandb.Audio(filename,
430
+ sample_rate=self.sample_rate,
431
+ caption=f'Reconstructed')
432
+
433
+ log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
434
+ log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
435
+
436
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
437
+
438
+ trainer.logger.experiment.log(log_dict)
439
+ except Exception as e:
440
+ print(f'{type(e).__name__}: {e}')
441
+ raise e
442
+ finally:
443
+ module.train()
444
+
445
+ def create_loss_modules_from_bottleneck(bottleneck, loss_config):
446
+ losses = []
447
+
448
+ if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
449
+ try:
450
+ kl_weight = loss_config['bottleneck']['weights']['kl']
451
+ except:
452
+ kl_weight = 1e-6
453
+
454
+ kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss')
455
+ losses.append(kl_loss)
456
+
457
+ if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
458
+ quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss')
459
+ losses.append(quantizer_loss)
460
+
461
+ if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck):
462
+ codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss')
463
+ commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss')
464
+ losses.append(codebook_loss)
465
+ losses.append(commitment_loss)
466
+
467
+ if isinstance(bottleneck, WassersteinBottleneck):
468
+ try:
469
+ mmd_weight = loss_config['bottleneck']['weights']['mmd']
470
+ except:
471
+ mmd_weight = 100
472
+
473
+ mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss')
474
+ losses.append(mmd_loss)
475
+
476
+ return losses
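+ # Illustrative config sketch (hypothetical values): the bottleneck weights read above can be
+ # supplied through loss_config, e.g.
+ #     loss_config["bottleneck"] = {"weights": {"kl": 1e-4}}      # VAE-style bottlenecks
+ #     loss_config["bottleneck"] = {"weights": {"mmd": 100}}      # Wasserstein bottleneck
+ # otherwise the except branches fall back to kl=1e-6 and mmd=100.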
stable_audio_tools/training/diffusion.py ADDED
@@ -0,0 +1,1656 @@
1
+ import pytorch_lightning as pl
2
+ import sys, gc
3
+ import random
4
+ import torch
5
+ import torchaudio
6
+ import typing as tp
7
+ import wandb
8
+
9
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
10
+ import auraloss
11
+ from ema_pytorch import EMA
12
+ from einops import rearrange
13
+ from safetensors.torch import save_file
14
+ from torch import optim
15
+ from torch.nn import functional as F
16
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
17
+
18
+ from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
19
+ from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
20
+ from ..models.autoencoders import DiffusionAutoencoder
21
+ from ..models.diffusion_prior import PriorType
22
+ from .autoencoders import create_loss_modules_from_bottleneck
23
+ from .losses import AuralossLoss, MSELoss, MultiLoss
24
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
25
+
26
+ from time import time
27
+
28
+
29
+ class Profiler:
30
+
31
+ def __init__(self):
32
+ self.ticks = [[time(), None]]
33
+
34
+ def tick(self, msg):
35
+ self.ticks.append([time(), msg])
36
+
37
+ def __repr__(self):
38
+ rep = 80 * "=" + "\n"
39
+ for i in range(1, len(self.ticks)):
40
+ msg = self.ticks[i][1]
41
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
42
+ rep += msg + f": {elapsed*1000:.2f}ms\n"
43
+ rep += 80 * "=" + "\n\n\n"
44
+ return rep
45
+
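+ # Illustrative usage (not executed anywhere here): time sections of a training step and
+ # print a per-section report.
+ #     prof = Profiler()
+ #     ...encode...
+ #     prof.tick("encode")
+ #     ...diffusion forward...
+ #     prof.tick("diffusion forward")
+ #     print(prof)    # one "<msg>: <elapsed>ms" line per tick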
46
+ class DiffusionUncondTrainingWrapper(pl.LightningModule):
47
+ '''
48
+ Wrapper for training an unconditional audio diffusion model (like Dance Diffusion).
49
+ '''
50
+ def __init__(
51
+ self,
52
+ model: DiffusionModelWrapper,
53
+ lr: float = 1e-4,
54
+ pre_encoded: bool = False
55
+ ):
56
+ super().__init__()
57
+
58
+ self.diffusion = model
59
+
60
+ self.diffusion_ema = EMA(
61
+ self.diffusion.model,
62
+ beta=0.9999,
63
+ power=3/4,
64
+ update_every=1,
65
+ update_after_step=1
66
+ )
67
+
68
+ self.lr = lr
69
+
70
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
71
+
72
+ loss_modules = [
73
+ MSELoss("v",
74
+ "targets",
75
+ weight=1.0,
76
+ name="mse_loss"
77
+ )
78
+ ]
79
+
80
+ self.losses = MultiLoss(loss_modules)
81
+
82
+ self.pre_encoded = pre_encoded
83
+
84
+ def configure_optimizers(self):
85
+ return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
86
+
87
+ def training_step(self, batch, batch_idx):
88
+ reals = batch[0]
89
+
90
+ if reals.ndim == 4 and reals.shape[0] == 1:
91
+ reals = reals[0]
92
+
93
+ diffusion_input = reals
94
+
95
+ loss_info = {}
96
+
97
+ if not self.pre_encoded:
98
+ loss_info["audio_reals"] = diffusion_input
99
+
100
+ if self.diffusion.pretransform is not None:
101
+ if not self.pre_encoded:
102
+ with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
103
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
104
+ else:
105
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
106
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
107
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
108
+
109
+ loss_info["reals"] = diffusion_input
110
+
111
+ # Draw uniformly distributed continuous timesteps
112
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
113
+
114
+ # Calculate the noise schedule parameters for those timesteps
115
+ alphas, sigmas = get_alphas_sigmas(t)
116
+
117
+ # Combine the ground truth data and the noise
118
+ alphas = alphas[:, None, None]
119
+ sigmas = sigmas[:, None, None]
120
+ noise = torch.randn_like(diffusion_input)
121
+ noised_inputs = diffusion_input * alphas + noise * sigmas
122
+ targets = noise * alphas - diffusion_input * sigmas
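+ # This is the v-objective: with the trig schedule used by get_alphas_sigmas
+ # (alpha = cos(t * pi / 2), sigma = sin(t * pi / 2)), the model is trained to predict
+ # v = alpha * noise - sigma * x, from which both the clean signal and the noise can be recovered.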
123
+
124
+ with torch.cuda.amp.autocast():
125
+ v = self.diffusion(noised_inputs, t)
126
+
127
+ loss_info.update({
128
+ "v": v,
129
+ "targets": targets
130
+ })
131
+
132
+ loss, losses = self.losses(loss_info)
133
+
134
+ log_dict = {
135
+ 'train/loss': loss.detach(),
136
+ 'train/std_data': diffusion_input.std(),
137
+ }
138
+
139
+ for loss_name, loss_value in losses.items():
140
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
141
+
142
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
143
+ return loss
144
+
145
+ def on_before_zero_grad(self, *args, **kwargs):
146
+ self.diffusion_ema.update()
147
+
148
+ def export_model(self, path, use_safetensors=False):
149
+
150
+ self.diffusion.model = self.diffusion_ema.ema_model
151
+
152
+ if use_safetensors:
153
+ save_file(self.diffusion.state_dict(), path)
154
+ else:
155
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
156
+
157
+ class DiffusionUncondDemoCallback(pl.Callback):
158
+ def __init__(self,
159
+ demo_every=2000,
160
+ num_demos=8,
161
+ demo_steps=250,
162
+ sample_rate=48000
163
+ ):
164
+ super().__init__()
165
+
166
+ self.demo_every = demo_every
167
+ self.num_demos = num_demos
168
+ self.demo_steps = demo_steps
169
+ self.sample_rate = sample_rate
170
+ self.last_demo_step = -1
171
+
172
+ @rank_zero_only
173
+ @torch.no_grad()
174
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
175
+
176
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
177
+ return
178
+
179
+ self.last_demo_step = trainer.global_step
180
+
181
+ demo_samples = module.diffusion.sample_size
182
+
183
+ if module.diffusion.pretransform is not None:
184
+ demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
185
+
186
+ noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
187
+
188
+ try:
189
+ with torch.cuda.amp.autocast():
190
+ fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
191
+
192
+ if module.diffusion.pretransform is not None:
193
+ fakes = module.diffusion.pretransform.decode(fakes)
194
+
195
+ # Put the demos together
196
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
197
+
198
+ log_dict = {}
199
+
200
+ filename = f'demo_{trainer.global_step:08}.wav'
201
+ fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
202
+ torchaudio.save(filename, fakes, self.sample_rate)
203
+
204
+ log_dict[f'demo'] = wandb.Audio(filename,
205
+ sample_rate=self.sample_rate,
206
+ caption=f'Reconstructed')
207
+
208
+ log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
209
+
210
+ trainer.logger.experiment.log(log_dict)
211
+
212
+ del fakes
213
+
214
+ except Exception as e:
215
+ print(f'{type(e).__name__}: {e}')
216
+ finally:
217
+ gc.collect()
218
+ torch.cuda.empty_cache()
219
+
220
+ class DiffusionCondTrainingWrapper(pl.LightningModule):
221
+ '''
222
+ Wrapper for training a conditional audio diffusion model.
223
+ '''
224
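+ # Typically constructed by create_training_wrapper_from_config in
+ # stable_audio_tools/training/factory.py (added later in this commit), which reads
+ # "learning_rate", "use_ema", "cfg_dropout_prob", "timestep_sampler", "optimizer_configs", etc.
+ # from the "training" section of the model config. A minimal hand-rolled sketch (illustrative only):
+ #
+ #   wrapper = DiffusionCondTrainingWrapper(model, lr=1e-4, use_ema=True)
+ #   trainer = pl.Trainer(max_steps=100_000)
+ #   trainer.fit(wrapper, train_dl)   # train_dl yields (audio, metadata) batches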
+ def __init__(
225
+ self,
226
+ model: ConditionedDiffusionModelWrapper,
227
+ lr: float = None,
228
+ mask_padding: bool = False,
229
+ mask_padding_dropout: float = 0.0,
230
+ use_ema: bool = True,
231
+ log_loss_info: bool = True,
232
+ optimizer_configs: dict = None,
233
+ pre_encoded: bool = False,
234
+ cfg_dropout_prob = 0.1,
235
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
236
+ ):
237
+ super().__init__()
238
+
239
+ self.diffusion = model
240
+
241
+ if use_ema:
242
+ self.diffusion_ema = EMA(
243
+ self.diffusion.model,
244
+ beta=0.9999,
245
+ power=3/4,
246
+ update_every=1,
247
+ update_after_step=1,
248
+ include_online_model=False
249
+ )
250
+ else:
251
+ self.diffusion_ema = None
252
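+ # When EMA is enabled, the smoothed copy (beta=0.9999) is what the demo callback samples from
+ # and what export_model writes out; otherwise the online model is used directly.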
+
253
+ self.mask_padding = mask_padding
254
+ self.mask_padding_dropout = mask_padding_dropout
255
+
256
+ self.cfg_dropout_prob = cfg_dropout_prob
257
+
258
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
259
+
260
+ self.timestep_sampler = timestep_sampler
261
+
262
+ self.diffusion_objective = model.diffusion_objective
263
+
264
+ if optimizer_configs is not None and 'av_loss' in optimizer_configs and optimizer_configs['av_loss']['if_add_av_loss']:
265
+ av_align_weight = optimizer_configs['av_loss']['config']['weight']
266
+ self.loss_modules = [
267
+ MSELoss("output",
268
+ "targets",
269
+ weight=1.0 - av_align_weight,
270
+ mask_key="padding_mask" if self.mask_padding else None,
271
+ name="mse_loss"
272
+ )
273
+ ]
274
+ else:
275
+ self.loss_modules = [
276
+ MSELoss("output",
277
+ "targets",
278
+ weight=1.0,
279
+ mask_key="padding_mask" if self.mask_padding else None,
280
+ name="mse_loss"
281
+ )
282
+ ]
283
+
284
+
285
+ self.losses = MultiLoss(self.loss_modules)
286
+
287
+ self.log_loss_info = log_loss_info
288
+
289
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
290
+
291
+ if optimizer_configs is None:
292
+ optimizer_configs = {
293
+ "diffusion": {
294
+ "optimizer": {
295
+ "type": "Adam",
296
+ "config": {
297
+ "lr": lr
298
+ }
299
+ }
300
+ }
301
+ }
302
+ else:
303
+ if lr is not None:
304
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
305
+
306
+ self.optimizer_configs = optimizer_configs
307
+
308
+ self.pre_encoded = pre_encoded
309
+
310
+ def configure_optimizers(self):
311
+ diffusion_opt_config = self.optimizer_configs['diffusion']
312
+ opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
313
+
314
+ if "scheduler" in diffusion_opt_config:
315
+ sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
316
+ sched_diff_config = {
317
+ "scheduler": sched_diff,
318
+ "interval": "step"
319
+ }
320
+ return [opt_diff], [sched_diff_config]
321
+
322
+ return [opt_diff]
323
+
324
+ def training_step(self, batch, batch_idx):
325
+
326
+
327
+ reals, metadata = batch
328
+
329
+ p = Profiler()
330
+
331
+ if reals.ndim == 4 and reals.shape[0] == 1:
332
+ reals = reals[0]
333
+
334
+ loss_info = {}
335
+
336
+ diffusion_input = reals
337
+ if not self.pre_encoded:
338
+ loss_info["audio_reals"] = diffusion_input
339
+
340
+ p.tick("setup")
341
+
342
+ with torch.cuda.amp.autocast():
343
+ conditioning = self.diffusion.conditioner(metadata, self.device)
344
+
345
+ use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout
346
+
347
+ # Create batch tensor of attention masks from the "mask" field of the metadata array
348
+ if use_padding_mask:
349
+ padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device)
350
+
351
+ p.tick("conditioning")
352
+
353
+ if self.diffusion.pretransform is not None:
354
+ self.diffusion.pretransform.to(self.device)
355
+
356
+ if not self.pre_encoded:
357
+ with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
358
+ self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
359
+
360
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
361
+ p.tick("pretransform")
362
+
363
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
364
+ if use_padding_mask:
365
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
366
+ else:
367
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
368
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
369
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
370
+
371
+ if self.timestep_sampler == "uniform":
372
+ # Draw uniformly distributed continuous timesteps
373
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
374
+ elif self.timestep_sampler == "logit_normal":
375
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
376
+
377
+ # Calculate the noise schedule parameters for those timesteps
378
+ if self.diffusion_objective == "v":
379
+ alphas, sigmas = get_alphas_sigmas(t)
380
+ elif self.diffusion_objective == "rectified_flow":
381
+ alphas, sigmas = 1-t, t
382
+
383
+ # Combine the ground truth data and the noise
384
+ alphas = alphas[:, None, None]
385
+ sigmas = sigmas[:, None, None]
386
+ noise = torch.randn_like(diffusion_input)
387
+ noised_inputs = diffusion_input * alphas + noise * sigmas
388
+
389
+ if self.diffusion_objective == "v":
390
+ targets = noise * alphas - diffusion_input * sigmas
391
+ elif self.diffusion_objective == "rectified_flow":
392
+ targets = noise - diffusion_input
393
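+ # Worked form of the two objectives above:
+ #   v:              x_t = alpha_t * x_0 + sigma_t * eps,  target = alpha_t * eps - sigma_t * x_0
+ #   rectified_flow: x_t = (1 - t) * x_0 + t * eps,        target = eps - x_0 (constant velocity from data to noise)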
+
394
+ p.tick("noise")
395
+
396
+ extra_args = {}
397
+
398
+ if use_padding_mask:
399
+ extra_args["mask"] = padding_masks
400
+
401
+ with torch.cuda.amp.autocast():
402
+ p.tick("amp")
403
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
404
+ p.tick("diffusion")
405
+
406
+ loss_info.update({
407
+ "output": output,
408
+ "targets": targets,
409
+ "padding_mask": padding_masks if use_padding_mask else None,
410
+ })
411
+
412
+ loss, losses = self.losses(loss_info)
413
+
414
+ p.tick("loss")
415
+
416
+ if self.log_loss_info:
417
+ # Loss debugging logs
418
+ num_loss_buckets = 10
419
+ bucket_size = 1 / num_loss_buckets
420
+ loss_all = F.mse_loss(output, targets, reduction="none")
421
+
422
+ sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze()
423
+
424
+ # gather loss_all across all GPUs
425
+ loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n")
426
+
427
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
428
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
429
+
430
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
431
+ debug_log_dict = {
432
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
433
+ }
434
+
435
+ self.log_dict(debug_log_dict)
436
+
437
+
438
+ log_dict = {
439
+ 'train/loss': loss.detach(),
440
+ 'train/std_data': diffusion_input.std(),
441
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
442
+ }
443
+
444
+ for loss_name, loss_value in losses.items():
445
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
446
+
447
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
448
+ p.tick("log")
449
+ #print(f"Profiler: {p}")
450
+ return loss
451
+
452
+ def validation_step(self, batch, batch_idx):
453
+ reals, metadata = batch
454
+
455
+ p = Profiler()
456
+
457
+ if reals.ndim == 4 and reals.shape[0] == 1:
458
+ reals = reals[0]
459
+
460
+ loss_info = {}
461
+
462
+ diffusion_input = reals
463
+
464
+ if not self.pre_encoded:
465
+ loss_info["audio_reals"] = diffusion_input
466
+
467
+ p.tick("setup")
468
+ with torch.cuda.amp.autocast():
469
+ conditioning = self.diffusion.conditioner(metadata, self.device)
470
+
471
+ # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding
472
+ use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout
473
+
474
+ # Create batch tensor of attention masks from the "mask" field of the metadata array
475
+ if use_padding_mask:
476
+ padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length)
477
+
478
+ p.tick("conditioning")
479
+
480
+ if self.diffusion.pretransform is not None:
481
+ self.diffusion.pretransform.to(self.device)
482
+
483
+ if not self.pre_encoded:
484
+ with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
485
+ self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
486
+
487
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
488
+ p.tick("pretransform")
489
+
490
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
491
+ if use_padding_mask:
492
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
493
+ else:
494
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
495
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
496
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
497
+
498
+ if self.timestep_sampler == "uniform":
499
+ # Draw uniformly distributed continuous timesteps
500
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
501
+ elif self.timestep_sampler == "logit_normal":
502
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
503
+
504
+ # Calculate the noise schedule parameters for those timesteps
505
+ if self.diffusion_objective == "v":
506
+ alphas, sigmas = get_alphas_sigmas(t)
507
+ elif self.diffusion_objective == "rectified_flow":
508
+ alphas, sigmas = 1-t, t
509
+
510
+ # Combine the ground truth data and the noise
511
+ alphas = alphas[:, None, None]
512
+ sigmas = sigmas[:, None, None]
513
+ noise = torch.randn_like(diffusion_input)
514
+ noised_inputs = diffusion_input * alphas + noise * sigmas
515
+
516
+ if self.diffusion_objective == "v":
517
+ targets = noise * alphas - diffusion_input * sigmas
518
+ elif self.diffusion_objective == "rectified_flow":
519
+ targets = noise - diffusion_input
520
+
521
+ p.tick("noise")
522
+
523
+ extra_args = {}
524
+
525
+ if use_padding_mask:
526
+ extra_args["mask"] = padding_masks
527
+
528
+ with torch.cuda.amp.autocast():
529
+ p.tick("amp")
530
+
531
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
532
+ p.tick("diffusion")
533
+
534
+ loss_info.update({
535
+ "output": output,
536
+ "targets": targets,
537
+ "padding_mask": padding_masks if use_padding_mask else None,
538
+ })
539
+
540
+ loss, losses = self.losses(loss_info)
541
+
542
+ p.tick("loss")
543
+
544
+ if self.log_loss_info:
545
+ # Loss debugging logs
546
+ num_loss_buckets = 10
547
+ bucket_size = 1 / num_loss_buckets
548
+ loss_all = F.mse_loss(output, targets, reduction="none")
549
+ # loss_all = F.binary_cross_entropy_with_logits(output, targets, reduction="none")
550
+
551
+
552
+ sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze()
553
+ # sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
554
+
555
+ # gather loss_all across all GPUs
556
+ loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n")
557
+ # loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
558
+
559
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
560
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
561
+
562
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
563
+ debug_log_dict = {
564
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
565
+ }
566
+
567
+ self.log_dict(debug_log_dict)
568
+
569
+
570
+ log_dict = {
571
+ 'valid/loss': loss.detach(),
572
+ 'valid/std_data': diffusion_input.std(),
573
+ 'valid/lr': self.trainer.optimizers[0].param_groups[0]['lr']
574
+ }
575
+
576
+
577
+ for loss_name, loss_value in losses.items():
578
+ log_dict[f"valid/{loss_name}"] = loss_value.detach()
579
+
580
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
581
+ # self.log('val_loss', val_loss, on_epoch=True, on_step=True)
582
+
583
+ p.tick("log")
584
+ #print(f"Profiler: {p}")
585
+ return loss
586
+
587
+ def on_before_zero_grad(self, *args, **kwargs):
588
+ if self.diffusion_ema is not None:
589
+ self.diffusion_ema.update()
590
+
591
+ def export_model(self, path, use_safetensors=False):
592
+ if self.diffusion_ema is not None:
593
+ self.diffusion.model = self.diffusion_ema.ema_model
594
+
595
+ if use_safetensors:
596
+ save_file(self.diffusion.state_dict(), path)
597
+ else:
598
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
599
+
600
+ class DiffusionCondDemoCallback(pl.Callback):
601
+ def __init__(self,
602
+ demo_every=2000,
603
+ num_demos=8,
604
+ sample_size=65536,
605
+ demo_steps=250,
606
+ sample_rate=48000,
607
+ demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {},
608
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
609
+ demo_cond_from_batch: bool = False,
610
+ display_audio_cond: bool = False
611
+ ):
612
+ super().__init__()
613
+
614
+ self.demo_every = demo_every
615
+ self.num_demos = num_demos
616
+ self.demo_samples = sample_size
617
+ self.demo_steps = demo_steps
618
+ self.sample_rate = sample_rate
619
+ self.last_demo_step = -1
620
+ self.demo_conditioning = demo_conditioning
621
+ self.demo_cfg_scales = demo_cfg_scales
622
+
623
+ # If true, the callback will use the metadata from the batch to generate the demo conditioning
624
+ self.demo_cond_from_batch = demo_cond_from_batch
625
+
626
+ # If true, the callback will display the audio conditioning
627
+ self.display_audio_cond = display_audio_cond
628
+
629
+ @rank_zero_only
630
+ @torch.no_grad()
631
+ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
632
+
633
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
634
+ return
635
+
636
+ module.eval()
637
+
638
+ print(f"Generating demo")
639
+ self.last_demo_step = trainer.global_step
640
+
641
+ demo_samples = self.demo_samples
642
+
643
+ demo_cond = self.demo_conditioning
644
+
645
+ if self.demo_cond_from_batch:
646
+ # Get metadata from the batch
647
+ demo_cond = batch[1][:self.num_demos]
648
+
649
+ if module.diffusion.pretransform is not None:
650
+ demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
651
+
652
+ noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
653
+
654
+ try:
655
+ print("Getting conditioning")
656
+ with torch.cuda.amp.autocast():
657
+ conditioning = module.diffusion.conditioner(demo_cond, module.device)
658
+
659
+
660
+ cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
661
+
662
+ log_dict = {}
663
+
664
+ if self.display_audio_cond:
665
+ audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0)
666
+ audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)')
667
+
668
+ filename = f'demo_audio_cond_{trainer.global_step:08}.wav'
669
+ audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu()
670
+ torchaudio.save(filename, audio_inputs, self.sample_rate)
671
+ log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning")
672
+ log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs))
673
+ trainer.logger.experiment.log(log_dict)
674
+
675
+ for cfg_scale in self.demo_cfg_scales:
676
+
677
+ print(f"Generating demo for cfg scale {cfg_scale}")
678
+
679
+ with torch.cuda.amp.autocast():
680
+ model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
681
+
682
+ if module.diffusion_objective == "v":
683
+ fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
684
+ elif module.diffusion_objective == "rectified_flow":
685
+ fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
686
+
687
+ if module.diffusion.pretransform is not None:
688
+ fakes = module.diffusion.pretransform.decode(fakes)
689
+
690
+ # Put the demos together
691
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
692
+
693
+ log_dict = {}
694
+
695
+ filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
696
+ fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
697
+ torchaudio.save(filename, fakes, self.sample_rate)
698
+
699
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
700
+ sample_rate=self.sample_rate,
701
+ caption=f'Reconstructed')
702
+
703
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
704
+
705
+ trainer.logger.experiment.log(log_dict)
706
+
707
+ del fakes
708
+
709
+ except Exception as e:
710
+ raise e
711
+ finally:
712
+ gc.collect()
713
+ torch.cuda.empty_cache()
714
+ module.train()
715
+
716
+ class DiffusionCondInpaintTrainingWrapper(pl.LightningModule):
717
+ '''
718
+ Wrapper for training a conditional audio diffusion model with random inpainting masks.
719
+ '''
720
+ def __init__(
721
+ self,
722
+ model: ConditionedDiffusionModelWrapper,
723
+ lr: float = 1e-4,
724
+ max_mask_segments = 10,
725
+ log_loss_info: bool = False,
726
+ optimizer_configs: dict = None,
727
+ use_ema: bool = True,
728
+ pre_encoded: bool = False,
729
+ cfg_dropout_prob = 0.1,
730
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
731
+ ):
732
+ super().__init__()
733
+
734
+ self.diffusion = model
735
+
736
+ self.use_ema = use_ema
737
+
738
+ if self.use_ema:
739
+ self.diffusion_ema = EMA(
740
+ self.diffusion.model,
741
+ beta=0.9999,
742
+ power=3/4,
743
+ update_every=1,
744
+ update_after_step=1,
745
+ include_online_model=False
746
+ )
747
+ else:
748
+ self.diffusion_ema = None
749
+
750
+ self.cfg_dropout_prob = cfg_dropout_prob
751
+
752
+ self.lr = lr
753
+ self.max_mask_segments = max_mask_segments
754
+
755
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
756
+
757
+ self.timestep_sampler = timestep_sampler
758
+
759
+ self.diffusion_objective = model.diffusion_objective
760
+
761
+ self.loss_modules = [
762
+ MSELoss("output",
763
+ "targets",
764
+ weight=1.0,
765
+ name="mse_loss"
766
+ )
767
+ ]
768
+
769
+ self.losses = MultiLoss(self.loss_modules)
770
+
771
+ self.log_loss_info = log_loss_info
772
+
773
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
774
+
775
+ if optimizer_configs is None:
776
+ optimizer_configs = {
777
+ "diffusion": {
778
+ "optimizer": {
779
+ "type": "Adam",
780
+ "config": {
781
+ "lr": lr
782
+ }
783
+ }
784
+ }
785
+ }
786
+ else:
787
+ if lr is not None:
788
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
789
+
790
+ self.optimizer_configs = optimizer_configs
791
+
792
+ self.pre_encoded = pre_encoded
793
+
794
+ def configure_optimizers(self):
795
+ diffusion_opt_config = self.optimizer_configs['diffusion']
796
+ opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
797
+
798
+ if "scheduler" in diffusion_opt_config:
799
+ sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
800
+ sched_diff_config = {
801
+ "scheduler": sched_diff,
802
+ "interval": "step"
803
+ }
804
+ return [opt_diff], [sched_diff_config]
805
+
806
+ return [opt_diff]
807
+
808
+ def random_mask(self, sequence, max_mask_length):
809
+ b, _, sequence_length = sequence.size()
810
+
811
+ # Create a mask tensor for each batch element
812
+ masks = []
813
+
814
+ for i in range(b):
815
+ mask_type = random.randint(0, 2)
816
+
817
+ if mask_type == 0: # Random mask with multiple segments
818
+ num_segments = random.randint(1, self.max_mask_segments)
819
+ max_segment_length = max_mask_length // num_segments
820
+
821
+ segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
822
+
823
+ mask = torch.ones((1, 1, sequence_length))
824
+ for length in segment_lengths:
825
+ mask_start = random.randint(0, sequence_length - length)
826
+ mask[:, :, mask_start:mask_start + length] = 0
827
+
828
+ elif mask_type == 1: # Full mask
829
+ mask = torch.zeros((1, 1, sequence_length))
830
+
831
+ elif mask_type == 2: # Causal mask
832
+ mask = torch.ones((1, 1, sequence_length))
833
+ mask_length = random.randint(1, max_mask_length)
834
+ mask[:, :, -mask_length:] = 0
835
+
836
+ mask = mask.to(sequence.device)
837
+ masks.append(mask)
838
+
839
+ # Concatenate the mask tensors into a single tensor
840
+ mask = torch.cat(masks, dim=0).to(sequence.device)
841
+
842
+ # Apply the mask to the sequence tensor for each batch element
843
+ masked_sequence = sequence * mask
844
+
845
+ return masked_sequence, mask
846
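+ # random_mask returns (masked_sequence, mask): mask has shape (batch, 1, sequence_length) with
+ # 1 = keep and 0 = region to inpaint, and broadcasts over channels when multiplied with the (b, c, t) input.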
+
847
+ def training_step(self, batch, batch_idx):
848
+ reals, metadata = batch
849
+
850
+ p = Profiler()
851
+
852
+ if reals.ndim == 4 and reals.shape[0] == 1:
853
+ reals = reals[0]
854
+
855
+ loss_info = {}
856
+
857
+ diffusion_input = reals
858
+
859
+ if not self.pre_encoded:
860
+ loss_info["audio_reals"] = diffusion_input
861
+
862
+ p.tick("setup")
863
+
864
+ with torch.cuda.amp.autocast():
865
+ conditioning = self.diffusion.conditioner(metadata, self.device)
866
+
867
+ p.tick("conditioning")
868
+
869
+ if self.diffusion.pretransform is not None:
870
+ self.diffusion.pretransform.to(self.device)
871
+
872
+ if not self.pre_encoded:
873
+ with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
874
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
875
+ p.tick("pretransform")
876
+
877
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
878
+ # if use_padding_mask:
879
+ # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
880
+ else:
881
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
882
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
883
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
884
+
885
+ # Max mask size is the full sequence length
886
+ max_mask_length = diffusion_input.shape[2]
887
+
888
+ # Create a mask of random length for a random slice of the input
889
+ masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
890
+
891
+ conditioning['inpaint_mask'] = [mask]
892
+ conditioning['inpaint_masked_input'] = [masked_input]
893
+
894
+ if self.timestep_sampler == "uniform":
895
+ # Draw uniformly distributed continuous timesteps
896
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
897
+ elif self.timestep_sampler == "logit_normal":
898
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
899
+
900
+ # Calculate the noise schedule parameters for those timesteps
901
+ if self.diffusion_objective == "v":
902
+ alphas, sigmas = get_alphas_sigmas(t)
903
+ elif self.diffusion_objective == "rectified_flow":
904
+ alphas, sigmas = 1-t, t
905
+
906
+ # Combine the ground truth data and the noise
907
+ alphas = alphas[:, None, None]
908
+ sigmas = sigmas[:, None, None]
909
+ noise = torch.randn_like(diffusion_input)
910
+ noised_inputs = diffusion_input * alphas + noise * sigmas
911
+
912
+ if self.diffusion_objective == "v":
913
+ targets = noise * alphas - diffusion_input * sigmas
914
+ elif self.diffusion_objective == "rectified_flow":
915
+ targets = noise - diffusion_input
916
+
917
+ p.tick("noise")
918
+
919
+ extra_args = {}
920
+
921
+ with torch.cuda.amp.autocast():
922
+ p.tick("amp")
923
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
924
+ p.tick("diffusion")
925
+
926
+ loss_info.update({
927
+ "output": output,
928
+ "targets": targets,
929
+ })
930
+
931
+ loss, losses = self.losses(loss_info)
932
+
933
+ if self.log_loss_info:
934
+ # Loss debugging logs
935
+ num_loss_buckets = 10
936
+ bucket_size = 1 / num_loss_buckets
937
+ loss_all = F.mse_loss(output, targets, reduction="none")
938
+
939
+ sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
940
+
941
+ # gather loss_all across all GPUs
942
+ loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
943
+
944
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
945
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
946
+
947
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
948
+ debug_log_dict = {
949
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
950
+ }
951
+
952
+ self.log_dict(debug_log_dict)
953
+
954
+ log_dict = {
955
+ 'train/loss': loss.detach(),
956
+ 'train/std_data': diffusion_input.std(),
957
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
958
+ }
959
+
960
+ for loss_name, loss_value in losses.items():
961
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
962
+
963
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
964
+ p.tick("log")
965
+ #print(f"Profiler: {p}")
966
+ return loss
967
+
968
+ def on_before_zero_grad(self, *args, **kwargs):
969
+ if self.diffusion_ema is not None:
970
+ self.diffusion_ema.update()
971
+
972
+ def export_model(self, path, use_safetensors=False):
973
+ if self.diffusion_ema is not None:
974
+ self.diffusion.model = self.diffusion_ema.ema_model
975
+
976
+ if use_safetensors:
977
+ save_file(self.diffusion.state_dict(), path)
978
+ else:
979
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
980
+
981
+ class DiffusionCondInpaintDemoCallback(pl.Callback):
982
+ def __init__(
983
+ self,
984
+ demo_dl,
985
+ demo_every=2000,
986
+ demo_steps=250,
987
+ sample_size=65536,
988
+ sample_rate=48000,
989
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7]
990
+ ):
991
+ super().__init__()
992
+ self.demo_every = demo_every
993
+ self.demo_steps = demo_steps
994
+ self.demo_samples = sample_size
995
+ self.demo_dl = iter(demo_dl)
996
+ self.sample_rate = sample_rate
997
+ self.demo_cfg_scales = demo_cfg_scales
998
+ self.last_demo_step = -1
999
+
1000
+ @rank_zero_only
1001
+ @torch.no_grad()
1002
+ def on_train_batch_end(self, trainer, module: DiffusionCondInpaintTrainingWrapper, outputs, batch, batch_idx):
1003
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1004
+ return
1005
+
1006
+ self.last_demo_step = trainer.global_step
1007
+
1008
+ try:
1009
+ log_dict = {}
1010
+
1011
+ demo_reals, metadata = next(self.demo_dl)
1012
+
1013
+ # Remove extra dimension added by WebDataset
1014
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1015
+ demo_reals = demo_reals[0]
1016
+
1017
+ demo_reals = demo_reals.to(module.device)
1018
+
1019
+ if not module.pre_encoded:
1020
+ # Log the real audio
1021
+ log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
1022
+ # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
1023
+
1024
+ if module.diffusion.pretransform is not None:
1025
+ module.diffusion.pretransform.to(module.device)
1026
+ with torch.cuda.amp.autocast():
1027
+ demo_reals = module.diffusion.pretransform.encode(demo_reals)
1028
+
1029
+ demo_samples = demo_reals.shape[2]
1030
+
1031
+ # Get conditioning
1032
+ conditioning = module.diffusion.conditioner(metadata, module.device)
1033
+
1034
+ masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2])
1035
+
1036
+ conditioning['inpaint_mask'] = [mask]
1037
+ conditioning['inpaint_masked_input'] = [masked_input]
1038
+
1039
+ if module.diffusion.pretransform is not None:
1040
+ log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu()))
1041
+ else:
1042
+ log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
1043
+
1044
+ cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
1045
+
1046
+ noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
1047
+
1048
+ trainer.logger.experiment.log(log_dict)
1049
+
1050
+ for cfg_scale in self.demo_cfg_scales:
1051
+ model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
1052
+ print(f"Generating demo for cfg scale {cfg_scale}")
1053
+
1054
+ if module.diffusion_objective == "v":
1055
+ fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1056
+ elif module.diffusion_objective == "rectified_flow":
1057
+ fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1058
+
1059
+ if module.diffusion.pretransform is not None:
1060
+ with torch.cuda.amp.autocast():
1061
+ fakes = module.diffusion.pretransform.decode(fakes)
1062
+
1063
+ # Put the demos together
1064
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
1065
+
1066
+ log_dict = {}
1067
+
1068
+ filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
1069
+ fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
1070
+ torchaudio.save(filename, fakes, self.sample_rate)
1071
+
1072
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
1073
+ sample_rate=self.sample_rate,
1074
+ caption=f'Reconstructed')
1075
+
1076
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
1077
+
1078
+ trainer.logger.experiment.log(log_dict)
1079
+ except Exception as e:
1080
+ print(f'{type(e).__name__}: {e}')
1081
+ raise e
1082
+
1083
+ class DiffusionAutoencoderTrainingWrapper(pl.LightningModule):
1084
+ '''
1085
+ Wrapper for training a diffusion autoencoder
1086
+ '''
1087
+ def __init__(
1088
+ self,
1089
+ model: DiffusionAutoencoder,
1090
+ lr: float = 1e-4,
1091
+ ema_copy = None,
1092
+ use_reconstruction_loss: bool = False
1093
+ ):
1094
+ super().__init__()
1095
+
1096
+ self.diffae = model
1097
+
1098
+ self.diffae_ema = EMA(
1099
+ self.diffae,
1100
+ ema_model=ema_copy,
1101
+ beta=0.9999,
1102
+ power=3/4,
1103
+ update_every=1,
1104
+ update_after_step=1,
1105
+ include_online_model=False
1106
+ )
1107
+
1108
+ self.lr = lr
1109
+
1110
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1111
+
1112
+ loss_modules = [
1113
+ MSELoss("v",
1114
+ "targets",
1115
+ weight=1.0,
1116
+ name="mse_loss"
1117
+ )
1118
+ ]
1119
+
1120
+ if model.bottleneck is not None:
1121
+ # TODO: Use loss config for configurable bottleneck weights and reconstruction losses
1122
+ loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {})
1123
+
1124
+ self.use_reconstruction_loss = use_reconstruction_loss
1125
+
1126
+ if use_reconstruction_loss:
1127
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
1128
+ hop_sizes = []
1129
+ win_lengths = []
1130
+ overlap = 0.75
1131
+ for s in scales:
1132
+ hop_sizes.append(int(s * (1 - overlap)))
1133
+ win_lengths.append(s)
1134
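+ # e.g. with overlap = 0.75: fft size 2048 -> hop 512, 1024 -> 256, ..., 32 -> 8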
+
1135
+ sample_rate = model.sample_rate
1136
+
1137
+ stft_loss_args = {
1138
+ "fft_sizes": scales,
1139
+ "hop_sizes": hop_sizes,
1140
+ "win_lengths": win_lengths,
1141
+ "perceptual_weighting": True
1142
+ }
1143
+
1144
+ out_channels = model.out_channels
1145
+
1146
+ if model.pretransform is not None:
1147
+ out_channels = model.pretransform.io_channels
1148
+
1149
+ if out_channels == 2:
1150
+ self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1151
+ else:
1152
+ self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1153
+
1154
+ loss_modules.append(
1155
+ AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1156
+ )
1157
+
1158
+ self.losses = MultiLoss(loss_modules)
1159
+
1160
+ def configure_optimizers(self):
1161
+ return optim.Adam([*self.diffae.parameters()], lr=self.lr)
1162
+
1163
+ def training_step(self, batch, batch_idx):
1164
+ reals = batch[0]
1165
+
1166
+ if reals.ndim == 4 and reals.shape[0] == 1:
1167
+ reals = reals[0]
1168
+
1169
+ loss_info = {}
1170
+
1171
+ loss_info["audio_reals"] = reals
1172
+
1173
+ if self.diffae.pretransform is not None:
1174
+ with torch.no_grad():
1175
+ reals = self.diffae.pretransform.encode(reals)
1176
+
1177
+ loss_info["reals"] = reals
1178
+
1179
+ #Encode reals, skipping the pretransform since it was already applied
1180
+ latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True)
1181
+
1182
+ loss_info["latents"] = latents
1183
+ loss_info.update(encoder_info)
1184
+
1185
+ if self.diffae.decoder is not None:
1186
+ latents = self.diffae.decoder(latents)
1187
+
1188
+ # Upsample latents to match diffusion length
1189
+ if latents.shape[2] != reals.shape[2]:
1190
+ latents = F.interpolate(latents, size=reals.shape[2], mode='nearest')
1191
+
1192
+ loss_info["latents_upsampled"] = latents
1193
+
1194
+ # Draw uniformly distributed continuous timesteps
1195
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1196
+
1197
+ # Calculate the noise schedule parameters for those timesteps
1198
+ alphas, sigmas = get_alphas_sigmas(t)
1199
+
1200
+ # Combine the ground truth data and the noise
1201
+ alphas = alphas[:, None, None]
1202
+ sigmas = sigmas[:, None, None]
1203
+ noise = torch.randn_like(reals)
1204
+ noised_reals = reals * alphas + noise * sigmas
1205
+ targets = noise * alphas - reals * sigmas
1206
+
1207
+ with torch.cuda.amp.autocast():
1208
+ v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
1209
+
1210
+ loss_info.update({
1211
+ "v": v,
1212
+ "targets": targets
1213
+ })
1214
+
1215
+ if self.use_reconstruction_loss:
1216
+ pred = noised_reals * alphas - v * sigmas
1217
+
1218
+ loss_info["pred"] = pred
1219
+
1220
+ if self.diffae.pretransform is not None:
1221
+ pred = self.diffae.pretransform.decode(pred)
1222
+ loss_info["audio_pred"] = pred
1223
+
1224
+ loss, losses = self.losses(loss_info)
1225
+
1226
+ log_dict = {
1227
+ 'train/loss': loss.detach(),
1228
+ 'train/std_data': reals.std(),
1229
+ 'train/latent_std': latents.std(),
1230
+ }
1231
+
1232
+ for loss_name, loss_value in losses.items():
1233
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
1234
+
1235
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
1236
+ return loss
1237
+
1238
+ def on_before_zero_grad(self, *args, **kwargs):
1239
+ self.diffae_ema.update()
1240
+
1241
+ def export_model(self, path, use_safetensors=False):
1242
+
1243
+ model = self.diffae_ema.ema_model
1244
+
1245
+ if use_safetensors:
1246
+ save_file(model.state_dict(), path)
1247
+ else:
1248
+ torch.save({"state_dict": model.state_dict()}, path)
1249
+
1250
+ class DiffusionAutoencoderDemoCallback(pl.Callback):
1251
+ def __init__(
1252
+ self,
1253
+ demo_dl,
1254
+ demo_every=2000,
1255
+ demo_steps=250,
1256
+ sample_size=65536,
1257
+ sample_rate=48000
1258
+ ):
1259
+ super().__init__()
1260
+ self.demo_every = demo_every
1261
+ self.demo_steps = demo_steps
1262
+ self.demo_samples = sample_size
1263
+ self.demo_dl = iter(demo_dl)
1264
+ self.sample_rate = sample_rate
1265
+ self.last_demo_step = -1
1266
+
1267
+ @rank_zero_only
1268
+ @torch.no_grad()
1269
+ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
1270
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1271
+ return
1272
+
1273
+ self.last_demo_step = trainer.global_step
1274
+
1275
+ demo_reals, _ = next(self.demo_dl)
1276
+
1277
+ # Remove extra dimension added by WebDataset
1278
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1279
+ demo_reals = demo_reals[0]
1280
+
1281
+ encoder_input = demo_reals
1282
+
1283
+ encoder_input = encoder_input.to(module.device)
1284
+
1285
+ demo_reals = demo_reals.to(module.device)
1286
+
1287
+ with torch.no_grad(), torch.cuda.amp.autocast():
1288
+ latents = module.diffae_ema.ema_model.encode(encoder_input).float()
1289
+ fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
1290
+
1291
+ #Interleave reals and fakes
1292
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1293
+
1294
+ # Put the demos together
1295
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1296
+
1297
+ log_dict = {}
1298
+
1299
+ filename = f'recon_{trainer.global_step:08}.wav'
1300
+ reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1301
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
1302
+
1303
+ log_dict[f'recon'] = wandb.Audio(filename,
1304
+ sample_rate=self.sample_rate,
1305
+ caption=f'Reconstructed')
1306
+
1307
+ log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
1308
+ log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
1309
+
1310
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1311
+
1312
+ if module.diffae_ema.ema_model.pretransform is not None:
1313
+ with torch.no_grad(), torch.cuda.amp.autocast():
1314
+ initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
1315
+ first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
1316
+ first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
1317
+ first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
1318
+ first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
1319
+ torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
1320
+
1321
+ log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
1322
+
1323
+ log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
1324
+ sample_rate=self.sample_rate,
1325
+ caption=f'First Stage Reconstructed')
1326
+
1327
+ log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
1328
+
1329
+
1330
+ trainer.logger.experiment.log(log_dict)
1331
+
1332
+ def create_source_mixture(reals, num_sources=2):
1333
+ # Create a fake mixture source by mixing elements from the training batch together with random offsets
1334
+ source = torch.zeros_like(reals)
1335
+ for i in range(reals.shape[0]):
1336
+ sources_added = 0
1337
+
1338
+ js = list(range(reals.shape[0]))
1339
+ random.shuffle(js)
1340
+ for j in js:
1341
+ if i == j or (i != j and sources_added < num_sources):
1342
+ # Randomly offset the mixed element between 0 and the length of the source
1343
+ seq_len = reals.shape[2]
1344
+ offset = random.randint(0, seq_len-1)
1345
+ source[i, :, offset:] += reals[j, :, :seq_len - offset]  # works when offset == 0 (':-offset' would be an empty slice)
1346
+ if i == j:
1347
+ # If this is the real one, shift the reals as well to ensure alignment
1348
+ new_reals = torch.zeros_like(reals[i])
1349
+ new_reals[:, offset:] = reals[i, :, :seq_len - offset]
1350
+ reals[i] = new_reals
1351
+ sources_added += 1
1352
+
1353
+ return source
1354
+
1355
+ class DiffusionPriorTrainingWrapper(pl.LightningModule):
1356
+ '''
1357
+ Wrapper for training a diffusion prior for inverse problems
1358
+ Prior types:
1359
+ mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
1360
+ '''
1361
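+ # The "mono_stereo" string in a training config is mapped to PriorType.MonoToStereo by
+ # create_training_wrapper_from_config in stable_audio_tools/training/factory.py (added later in this commit).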
+ def __init__(
1362
+ self,
1363
+ model: ConditionedDiffusionModelWrapper,
1364
+ lr: float = 1e-4,
1365
+ ema_copy = None,
1366
+ prior_type: PriorType = PriorType.MonoToStereo,
1367
+ use_reconstruction_loss: bool = False,
1368
+ log_loss_info: bool = False,
1369
+ ):
1370
+ super().__init__()
1371
+
1372
+ self.diffusion = model
1373
+
1374
+ self.diffusion_ema = EMA(
1375
+ self.diffusion,
1376
+ ema_model=ema_copy,
1377
+ beta=0.9999,
1378
+ power=3/4,
1379
+ update_every=1,
1380
+ update_after_step=1,
1381
+ include_online_model=False
1382
+ )
1383
+
1384
+ self.lr = lr
1385
+
1386
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1387
+
1388
+ self.log_loss_info = log_loss_info
1389
+
1390
+ loss_modules = [
1391
+ MSELoss("v",
1392
+ "targets",
1393
+ weight=1.0,
1394
+ name="mse_loss"
1395
+ )
1396
+ ]
1397
+
1398
+ self.use_reconstruction_loss = use_reconstruction_loss
1399
+
1400
+ if use_reconstruction_loss:
1401
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
1402
+ hop_sizes = []
1403
+ win_lengths = []
1404
+ overlap = 0.75
1405
+ for s in scales:
1406
+ hop_sizes.append(int(s * (1 - overlap)))
1407
+ win_lengths.append(s)
1408
+
1409
+ sample_rate = model.sample_rate
1410
+
1411
+ stft_loss_args = {
1412
+ "fft_sizes": scales,
1413
+ "hop_sizes": hop_sizes,
1414
+ "win_lengths": win_lengths,
1415
+ "perceptual_weighting": True
1416
+ }
1417
+
1418
+ out_channels = model.io_channels
1419
+
1420
+ self.audio_out_channels = out_channels
1421
+
1422
+ if model.pretransform is not None:
1423
+ out_channels = model.pretransform.io_channels
1424
+
1425
+ if self.audio_out_channels == 2:
1426
+ self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1427
+ self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1428
+
1429
+ # Add left and right channel reconstruction losses in addition to the sum and difference
1430
+ loss_modules += [
1431
+ AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
1432
+ AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
1433
+ ]
1434
+
1435
+ else:
1436
+ self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1437
+
1438
+ loss_modules.append(
1439
+ AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1440
+ )
1441
+
1442
+ self.losses = MultiLoss(loss_modules)
1443
+
1444
+ self.prior_type = prior_type
1445
+
1446
+ def configure_optimizers(self):
1447
+ return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
1448
+
1449
+ def training_step(self, batch, batch_idx):
1450
+ reals, metadata = batch
1451
+
1452
+ if reals.ndim == 4 and reals.shape[0] == 1:
1453
+ reals = reals[0]
1454
+
1455
+ loss_info = {}
1456
+
1457
+ loss_info["audio_reals"] = reals
1458
+
1459
+ if self.prior_type == PriorType.MonoToStereo:
1460
+ source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
1461
+ loss_info["audio_reals_mono"] = source
1462
+ else:
1463
+ raise ValueError(f"Unknown prior type {self.prior_type}")
1464
+
1465
+ if self.diffusion.pretransform is not None:
1466
+ with torch.no_grad():
1467
+ reals = self.diffusion.pretransform.encode(reals)
1468
+
1469
+ if self.prior_type in [PriorType.MonoToStereo]:
1470
+ source = self.diffusion.pretransform.encode(source)
1471
+
1472
+ if self.diffusion.conditioner is not None:
1473
+ with torch.cuda.amp.autocast():
1474
+ conditioning = self.diffusion.conditioner(metadata, self.device)
1475
+ else:
1476
+ conditioning = {}
1477
+
1478
+ loss_info["reals"] = reals
1479
+
1480
+ # Draw uniformly distributed continuous timesteps
1481
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1482
+
1483
+ # Calculate the noise schedule parameters for those timesteps
1484
+ alphas, sigmas = get_alphas_sigmas(t)
1485
+
1486
+ # Combine the ground truth data and the noise
1487
+ alphas = alphas[:, None, None]
1488
+ sigmas = sigmas[:, None, None]
1489
+ noise = torch.randn_like(reals)
1490
+ noised_reals = reals * alphas + noise * sigmas
1491
+ targets = noise * alphas - reals * sigmas
1492
+
1493
+ with torch.cuda.amp.autocast():
1494
+
1495
+ conditioning['source'] = [source]
1496
+
1497
+ v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
1498
+
1499
+ loss_info.update({
1500
+ "v": v,
1501
+ "targets": targets
1502
+ })
1503
+
1504
+ if self.use_reconstruction_loss:
1505
+ pred = noised_reals * alphas - v * sigmas
1506
+
1507
+ loss_info["pred"] = pred
1508
+
1509
+ if self.diffusion.pretransform is not None:
1510
+ pred = self.diffusion.pretransform.decode(pred)
1511
+ loss_info["audio_pred"] = pred
1512
+
1513
+ if self.audio_out_channels == 2:
1514
+ loss_info["pred_left"] = pred[:, 0:1, :]
1515
+ loss_info["pred_right"] = pred[:, 1:2, :]
1516
+ loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
1517
+ loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
1518
+
1519
+ loss, losses = self.losses(loss_info)
1520
+
1521
+ if self.log_loss_info:
1522
+ # Loss debugging logs
1523
+ num_loss_buckets = 10
1524
+ bucket_size = 1 / num_loss_buckets
1525
+ loss_all = F.mse_loss(v, targets, reduction="none")
1526
+
1527
+ sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
1528
+
1529
+ # gather loss_all across all GPUs
1530
+ loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
1531
+
1532
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
1533
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
1534
+
1535
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
1536
+ debug_log_dict = {
1537
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
1538
+ }
1539
+
1540
+ self.log_dict(debug_log_dict)
1541
+
1542
+ log_dict = {
1543
+ 'train/loss': loss.detach(),
1544
+ 'train/std_data': reals.std()
1545
+ }
1546
+
1547
+ for loss_name, loss_value in losses.items():
1548
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
1549
+
1550
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
1551
+ return loss
1552
+
1553
+ def on_before_zero_grad(self, *args, **kwargs):
1554
+ self.diffusion_ema.update()
1555
+
1556
+ def export_model(self, path, use_safetensors=False):
1557
+
1558
+ #model = self.diffusion_ema.ema_model
1559
+ model = self.diffusion
1560
+
1561
+ if use_safetensors:
1562
+ save_file(model.state_dict(), path)
1563
+ else:
1564
+ torch.save({"state_dict": model.state_dict()}, path)
1565
+
1566
+ class DiffusionPriorDemoCallback(pl.Callback):
1567
+ def __init__(
1568
+ self,
1569
+ demo_dl,
1570
+ demo_every=2000,
1571
+ demo_steps=250,
1572
+ sample_size=65536,
1573
+ sample_rate=48000
1574
+ ):
1575
+ super().__init__()
1576
+ self.demo_every = demo_every
1577
+ self.demo_steps = demo_steps
1578
+ self.demo_samples = sample_size
1579
+ self.demo_dl = iter(demo_dl)
1580
+ self.sample_rate = sample_rate
1581
+ self.last_demo_step = -1
1582
+
1583
+ @rank_zero_only
1584
+ @torch.no_grad()
1585
+ def on_train_batch_end(self, trainer, module: DiffusionPriorTrainingWrapper, outputs, batch, batch_idx):
1586
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1587
+ return
1588
+
1589
+ self.last_demo_step = trainer.global_step
1590
+
1591
+ demo_reals, metadata = next(self.demo_dl)
1592
+
1593
+ # Remove extra dimension added by WebDataset
1594
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1595
+ demo_reals = demo_reals[0]
1596
+
1597
+ demo_reals = demo_reals.to(module.device)
1598
+
1599
+ encoder_input = demo_reals
1600
+
1601
+ if module.diffusion.conditioner is not None:
1602
+ with torch.cuda.amp.autocast():
1603
+ conditioning_tensors = module.diffusion.conditioner(metadata, module.device)
1604
+
1605
+ else:
1606
+ conditioning_tensors = {}
1607
+
1608
+
1609
+ with torch.no_grad(), torch.cuda.amp.autocast():
1610
+ if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1:
1611
+ source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)
1612
+
1613
+ if module.diffusion.pretransform is not None:
1614
+ encoder_input = module.diffusion.pretransform.encode(encoder_input)
1615
+ source_input = module.diffusion.pretransform.encode(source)
1616
+ else:
1617
+ source_input = source
1618
+
1619
+ conditioning_tensors['source'] = [source_input]
1620
+
1621
+ fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors)
1622
+
1623
+ if module.diffusion.pretransform is not None:
1624
+ fakes = module.diffusion.pretransform.decode(fakes)
1625
+
1626
+ #Interleave reals and fakes
1627
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1628
+
1629
+ # Put the demos together
1630
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1631
+
1632
+ log_dict = {}
1633
+
1634
+ filename = f'recon_{trainer.global_step:08}.wav'
1635
+ reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1636
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
1637
+
1638
+ log_dict[f'recon'] = wandb.Audio(filename,
1639
+ sample_rate=self.sample_rate,
1640
+ caption=f'Reconstructed')
1641
+
1642
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1643
+
1644
+ #Log the source
1645
+ filename = f'source_{trainer.global_step:08}.wav'
1646
+ source = rearrange(source, 'b d n -> d (b n)')
1647
+ source = source.to(torch.float32).mul(32767).to(torch.int16).cpu()
1648
+ torchaudio.save(filename, source, self.sample_rate)
1649
+
1650
+ log_dict[f'source'] = wandb.Audio(filename,
1651
+ sample_rate=self.sample_rate,
1652
+ caption=f'Source')
1653
+
1654
+ log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source))
1655
+
1656
+ trainer.logger.experiment.log(log_dict)
stable_audio_tools/training/factory.py ADDED
@@ -0,0 +1,240 @@
1
+ import torch
2
+ from torch.nn import Parameter
3
+ from ..models.factory import create_model_from_config
4
+
5
+ def create_training_wrapper_from_config(model_config, model):
6
+ model_type = model_config.get('model_type', None)
7
+ assert model_type is not None, 'model_type must be specified in model config'
8
+
9
+ training_config = model_config.get('training', None)
10
+ assert training_config is not None, 'training config must be specified in model config'
11
+
12
+ if model_type == 'autoencoder':
13
+ from .autoencoders import AutoencoderTrainingWrapper
14
+
15
+ ema_copy = None
16
+
17
+ if training_config.get("use_ema", False):
18
+ ema_copy = create_model_from_config(model_config)
19
+ ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once
20
+ # Copy each weight to the ema copy
21
+ for name, param in model.state_dict().items():
22
+ if isinstance(param, Parameter):
23
+ # backwards compatibility for serialized parameters
24
+ param = param.data
25
+ ema_copy.state_dict()[name].copy_(param)
26
+
27
+ use_ema = training_config.get("use_ema", False)
28
+
29
+ latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)
30
+
31
+ teacher_model = training_config.get("teacher_model", None)
32
+ if teacher_model is not None:
33
+ teacher_model = create_model_from_config(teacher_model)
34
+ teacher_model = teacher_model.eval().requires_grad_(False)
35
+
36
+ teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
37
+ if teacher_model_ckpt is not None:
38
+ teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
39
+ else:
40
+ raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")
41
+
42
+ return AutoencoderTrainingWrapper(
43
+ model,
44
+ lr=training_config["learning_rate"],
45
+ warmup_steps=training_config.get("warmup_steps", 0),
46
+ encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
47
+ sample_rate=model_config["sample_rate"],
48
+ loss_config=training_config.get("loss_configs", None),
49
+ optimizer_configs=training_config.get("optimizer_configs", None),
50
+ use_ema=use_ema,
51
+ ema_copy=ema_copy if use_ema else None,
52
+ force_input_mono=training_config.get("force_input_mono", False),
53
+ latent_mask_ratio=latent_mask_ratio,
54
+ teacher_model=teacher_model
55
+ )
56
+ elif model_type == 'diffusion_uncond':
57
+ from .diffusion import DiffusionUncondTrainingWrapper
58
+ return DiffusionUncondTrainingWrapper(
59
+ model,
60
+ lr=training_config["learning_rate"],
61
+ pre_encoded=training_config.get("pre_encoded", False),
62
+ )
63
+ elif model_type == 'diffusion_cond':
64
+ from .diffusion import DiffusionCondTrainingWrapper
65
+ return DiffusionCondTrainingWrapper(
66
+ model,
67
+ lr=training_config.get("learning_rate", None),
68
+ mask_padding=training_config.get("mask_padding", False),
69
+ mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
70
+ use_ema = training_config.get("use_ema", True),
71
+ log_loss_info=training_config.get("log_loss_info", False),
72
+ optimizer_configs=training_config.get("optimizer_configs", None),
73
+ pre_encoded=training_config.get("pre_encoded", False),
74
+ cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
75
+ timestep_sampler = training_config.get("timestep_sampler", "uniform")
76
+ )
77
+ elif model_type == 'diffusion_prior':
78
+ from .diffusion import DiffusionPriorTrainingWrapper
79
+ from ..models.diffusion_prior import PriorType
80
+
81
+ ema_copy = create_model_from_config(model_config)
82
+
83
+ # Copy each weight to the ema copy
84
+ for name, param in model.state_dict().items():
85
+ if isinstance(param, Parameter):
86
+ # backwards compatibility for serialized parameters
87
+ param = param.data
88
+ ema_copy.state_dict()[name].copy_(param)
89
+
90
+ prior_type = training_config.get("prior_type", "mono_stereo")
91
+
92
+ if prior_type == "mono_stereo":
93
+ prior_type_enum = PriorType.MonoToStereo
94
+ else:
95
+ raise ValueError(f"Unknown prior type: {prior_type}")
96
+
97
+ return DiffusionPriorTrainingWrapper(
98
+ model,
99
+ lr=training_config["learning_rate"],
100
+ ema_copy=ema_copy,
101
+ prior_type=prior_type_enum,
102
+ log_loss_info=training_config.get("log_loss_info", False),
103
+ use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
104
+ )
105
+ elif model_type == 'diffusion_cond_inpaint':
106
+ from .diffusion import DiffusionCondInpaintTrainingWrapper
107
+ return DiffusionCondInpaintTrainingWrapper(
108
+ model,
109
+ lr=training_config.get("learning_rate", None),
110
+ max_mask_segments = training_config.get("max_mask_segments", 10),
111
+ log_loss_info=training_config.get("log_loss_info", False),
112
+ optimizer_configs=training_config.get("optimizer_configs", None),
113
+ use_ema=training_config.get("use_ema", True),
114
+ pre_encoded=training_config.get("pre_encoded", False),
115
+ cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
116
+ timestep_sampler = training_config.get("timestep_sampler", "uniform")
117
+ )
118
+ elif model_type == 'diffusion_autoencoder':
119
+ from .diffusion import DiffusionAutoencoderTrainingWrapper
120
+
121
+ ema_copy = create_model_from_config(model_config)
122
+
123
+ # Copy each weight to the ema copy
124
+ for name, param in model.state_dict().items():
125
+ if isinstance(param, Parameter):
126
+ # backwards compatibility for serialized parameters
127
+ param = param.data
128
+ ema_copy.state_dict()[name].copy_(param)
129
+
130
+ return DiffusionAutoencoderTrainingWrapper(
131
+ model,
132
+ ema_copy=ema_copy,
133
+ lr=training_config["learning_rate"],
134
+ use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
135
+ )
136
+ elif model_type == 'lm':
137
+ from .lm import AudioLanguageModelTrainingWrapper
138
+
139
+ ema_copy = create_model_from_config(model_config)
140
+
141
+ for name, param in model.state_dict().items():
142
+ if isinstance(param, Parameter):
143
+ # backwards compatibility for serialized parameters
144
+ param = param.data
145
+ ema_copy.state_dict()[name].copy_(param)
146
+
147
+ return AudioLanguageModelTrainingWrapper(
148
+ model,
149
+ ema_copy=ema_copy,
150
+ lr=training_config.get("learning_rate", None),
151
+ use_ema=training_config.get("use_ema", False),
152
+ optimizer_configs=training_config.get("optimizer_configs", None),
153
+ pre_encoded=training_config.get("pre_encoded", False),
154
+ )
155
+
156
+ else:
157
+ raise NotImplementedError(f'Unknown model type: {model_type}')
158
+
159
+ def create_demo_callback_from_config(model_config, **kwargs):
160
+ model_type = model_config.get('model_type', None)
161
+ assert model_type is not None, 'model_type must be specified in model config'
162
+
163
+ training_config = model_config.get('training', None)
164
+ assert training_config is not None, 'training config must be specified in model config'
165
+
166
+ demo_config = training_config.get("demo", {})
167
+
168
+ if model_type == 'autoencoder':
169
+ from .autoencoders import AutoencoderDemoCallback
170
+ return AutoencoderDemoCallback(
171
+ demo_every=demo_config.get("demo_every", 2000),
172
+ sample_size=model_config["sample_size"],
173
+ sample_rate=model_config["sample_rate"],
174
+ **kwargs
175
+ )
176
+ elif model_type == 'diffusion_uncond':
177
+ from .diffusion import DiffusionUncondDemoCallback
178
+ return DiffusionUncondDemoCallback(
179
+ demo_every=demo_config.get("demo_every", 2000),
180
+ demo_steps=demo_config.get("demo_steps", 250),
181
+ sample_rate=model_config["sample_rate"]
182
+ )
183
+ elif model_type == "diffusion_autoencoder":
184
+ from .diffusion import DiffusionAutoencoderDemoCallback
185
+ return DiffusionAutoencoderDemoCallback(
186
+ demo_every=demo_config.get("demo_every", 2000),
187
+ demo_steps=demo_config.get("demo_steps", 250),
188
+ sample_size=model_config["sample_size"],
189
+ sample_rate=model_config["sample_rate"],
190
+ **kwargs
191
+ )
192
+ elif model_type == "diffusion_prior":
193
+ from .diffusion import DiffusionPriorDemoCallback
194
+ return DiffusionPriorDemoCallback(
195
+ demo_every=demo_config.get("demo_every", 2000),
196
+ demo_steps=demo_config.get("demo_steps", 250),
197
+ sample_size=model_config["sample_size"],
198
+ sample_rate=model_config["sample_rate"],
199
+ **kwargs
200
+ )
201
+ elif model_type == "diffusion_cond":
202
+ from .diffusion import DiffusionCondDemoCallback
203
+
204
+ return DiffusionCondDemoCallback(
205
+ demo_every=demo_config.get("demo_every", 2000),
206
+ sample_size=model_config["sample_size"],
207
+ sample_rate=model_config["sample_rate"],
208
+ demo_steps=demo_config.get("demo_steps", 250),
209
+ num_demos=demo_config["num_demos"],
210
+ demo_cfg_scales=demo_config["demo_cfg_scales"],
211
+ demo_conditioning=demo_config.get("demo_cond", {}),
212
+ demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
213
+ display_audio_cond=demo_config.get("display_audio_cond", False),
214
+ )
215
+ elif model_type == "diffusion_cond_inpaint":
216
+ from .diffusion import DiffusionCondInpaintDemoCallback
217
+
218
+ return DiffusionCondInpaintDemoCallback(
219
+ demo_every=demo_config.get("demo_every", 2000),
220
+ sample_size=model_config["sample_size"],
221
+ sample_rate=model_config["sample_rate"],
222
+ demo_steps=demo_config.get("demo_steps", 250),
223
+ demo_cfg_scales=demo_config["demo_cfg_scales"],
224
+ **kwargs
225
+ )
226
+
227
+ elif model_type == "lm":
228
+ from .lm import AudioLanguageModelDemoCallback
229
+
230
+ return AudioLanguageModelDemoCallback(
231
+ demo_every=demo_config.get("demo_every", 2000),
232
+ sample_size=model_config["sample_size"],
233
+ sample_rate=model_config["sample_rate"],
234
+ demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
235
+ demo_conditioning=demo_config.get("demo_cond", None),
236
+ num_demos=demo_config.get("num_demos", 8),
237
+ **kwargs
238
+ )
239
+ else:
240
+ raise NotImplementedError(f'Unknown model type: {model_type}')
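For orientation, here is a minimal sketch of the config shape these factory helpers read and how they are typically wired together. The key names mirror the .get(...) lookups above; the placeholder values, the empty "model" block, the import paths, and the create_training_wrapper_from_config name and (model_config, model) call signature are assumptions for illustration, not content of this commit.

from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.training.factory import (
    create_demo_callback_from_config,
    create_training_wrapper_from_config,
)

# Illustrative config: key names follow the lookups in the functions above,
# values are placeholders, and the architecture block is omitted.
model_config = {
    "model_type": "diffusion_cond",
    "sample_rate": 44100,
    "sample_size": 441000,
    "model": {},  # architecture config goes here (omitted in this sketch)
    "training": {
        "learning_rate": 5e-5,
        "use_ema": True,
        "cfg_dropout_prob": 0.1,
        "timestep_sampler": "uniform",
        "demo": {
            "demo_every": 2000,
            "demo_steps": 250,
            "num_demos": 4,
            "demo_cfg_scales": [3, 6],
        },
    },
}

model = create_model_from_config(model_config)                      # build the model
wrapper = create_training_wrapper_from_config(model_config, model)  # picks the LightningModule above
demo_cb = create_demo_callback_from_config(model_config)            # picks the demo callback above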
stable_audio_tools/training/lm.py ADDED
@@ -0,0 +1,267 @@
1
+ import pytorch_lightning as pl
2
+ import sys, gc
3
+ import random
4
+ import torch
5
+ import torchaudio
6
+ import typing as tp
7
+ import wandb
8
+
9
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
10
+ from ema_pytorch import EMA
11
+ from einops import rearrange
12
+ from safetensors.torch import save_file
13
+ from torch import optim
14
+ from torch.nn import functional as F
15
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
16
+
17
+ from ..models.lm import AudioLanguageModelWrapper
18
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
19
+
20
+ class AudioLanguageModelTrainingWrapper(pl.LightningModule):
21
+ def __init__(
22
+ self,
23
+ model: AudioLanguageModelWrapper,
24
+ lr = 1e-4,
25
+ use_ema=False,
26
+ ema_copy=None,
27
+ optimizer_configs: dict = None,
28
+ pre_encoded=False
29
+ ):
30
+ super().__init__()
31
+
32
+ self.model = model
33
+
34
+ self.model.pretransform.requires_grad_(False)
35
+
36
+ self.model_ema = None
37
+ if use_ema:
38
+ self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)
39
+
40
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
41
+
42
+ if optimizer_configs is None:
43
+ optimizer_configs = {
44
+ "lm": {
45
+ "optimizer": {
46
+ "type": "AdamW",
47
+ "config": {
48
+ "lr": lr,
49
+ "betas": (0.9, 0.95),
50
+ "weight_decay": 0.1
51
+ }
52
+ }
53
+ }
54
+ }
55
+ else:
56
+ if lr is not None:
57
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
58
+
59
+ self.optimizer_configs = optimizer_configs
60
+
61
+ self.pre_encoded = pre_encoded
62
+
63
+ def configure_optimizers(self):
64
+ lm_opt_config = self.optimizer_configs['lm']
65
+ opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())
66
+
67
+ if "scheduler" in lm_opt_config:
68
+ sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
69
+ sched_lm_config = {
70
+ "scheduler": sched_lm,
71
+ "interval": "step"
72
+ }
73
+ return [opt_lm], [sched_lm_config]
74
+
75
+ return [opt_lm]
76
+
77
+ # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license
78
+ # License can be found in LICENSES/LICENSE_META.txt
79
+
80
+ def _compute_cross_entropy(
81
+ self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
82
+ ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
83
+ """Compute cross entropy between multi-codebook targets and model's logits.
84
+ The cross entropy is computed per codebook to provide codebook-level cross entropy.
85
+ Valid timesteps for each of the codebooks are pulled from the mask, where invalid
86
+ timesteps are set to 0.
87
+
88
+ Args:
89
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
90
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
91
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
92
+ Returns:
93
+ ce (torch.Tensor): Cross entropy averaged over the codebooks
94
+ ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
95
+ """
96
+ B, K, T = targets.shape
97
+ assert logits.shape[:-1] == targets.shape
98
+ assert mask.shape == targets.shape
99
+ ce = torch.zeros([], device=targets.device)
100
+ ce_per_codebook: tp.List[torch.Tensor] = []
101
+ for k in range(K):
102
+ logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
103
+ targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
104
+ mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
105
+ ce_targets = targets_k[mask_k]
106
+ ce_logits = logits_k[mask_k]
107
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
108
+ ce += q_ce
109
+ ce_per_codebook.append(q_ce.detach())
110
+ # average cross entropy across codebooks
111
+ ce = ce / K
112
+ return ce, ce_per_codebook
113
+
114
+ def training_step(self, batch, batch_idx):
115
+ reals, metadata = batch
116
+
117
+ if reals.ndim == 4 and reals.shape[0] == 1:
118
+ reals = reals[0]
119
+
120
+ if not self.pre_encoded:
121
+ codes = self.model.pretransform.tokenize(reals)
122
+ else:
123
+ codes = reals
124
+
125
+ padding_masks = []
126
+ for md in metadata:
127
+ if md["padding_mask"].ndim == 1:
128
+ padding_masks.append(md["padding_mask"])
129
+ else:
130
+ padding_masks.append(md["padding_mask"][0])
131
+
132
+ padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length)
133
+
134
+ # Interpolate padding masks to the same length as the codes
135
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool()
136
+
137
+ condition_tensors = None
138
+
139
+ # If the model is conditioned, get the conditioning tensors
140
+ if self.model.conditioner is not None:
141
+ condition_tensors = self.model.conditioner(metadata, self.device)
142
+
143
+ lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)
144
+
145
+ logits = lm_output.logits # [b, k, t, c]
146
+ logits_mask = lm_output.mask # [b, k, t]
147
+
148
+ logits_mask = logits_mask & padding_masks
149
+
150
+ cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask)
151
+
152
+ loss = cross_entropy
153
+
154
+ log_dict = {
155
+ 'train/loss': loss.detach(),
156
+ 'train/cross_entropy': cross_entropy.detach(),
157
+ 'train/perplexity': torch.exp(cross_entropy).detach(),
158
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
159
+ }
160
+
161
+ for k, ce_q in enumerate(cross_entropy_per_codebook):
162
+ log_dict[f'cross_entropy_q{k + 1}'] = ce_q
163
+ log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q)
164
+
165
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
166
+ return loss
167
+
168
+ def on_before_zero_grad(self, *args, **kwargs):
169
+ if self.model_ema is not None:
170
+ self.model_ema.update()
171
+
172
+ def export_model(self, path, use_safetensors=False):
173
+
174
+ model = self.model_ema.ema_model if self.model_ema is not None else self.model
175
+
176
+ if use_safetensors:
177
+ save_file(model.state_dict(), path)
178
+ else:
179
+ torch.save({"state_dict": model.state_dict()}, path)
180
+
181
+
182
+ class AudioLanguageModelDemoCallback(pl.Callback):
183
+ def __init__(self,
184
+ demo_every=2000,
185
+ num_demos=8,
186
+ sample_size=65536,
187
+ sample_rate=48000,
188
+ demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
189
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
190
+ **kwargs
191
+ ):
192
+ super().__init__()
193
+
194
+ self.demo_every = demo_every
195
+ self.num_demos = num_demos
196
+ self.demo_samples = sample_size
197
+ self.sample_rate = sample_rate
198
+ self.last_demo_step = -1
199
+ self.demo_conditioning = demo_conditioning
200
+ self.demo_cfg_scales = demo_cfg_scales
201
+
202
+ @rank_zero_only
203
+ @torch.no_grad()
204
+ def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx):
205
+
206
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
207
+ return
208
+
209
+ module.eval()
210
+
211
+ print(f"Generating demo")
212
+ self.last_demo_step = trainer.global_step
213
+
214
+ demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio
215
+
216
+ #demo_reals = batch[0][:self.num_demos]
217
+
218
+ # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
219
+ # demo_reals = demo_reals[0]
220
+
221
+ #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals)
222
+
223
+ ##Limit to first 50 tokens
224
+ #demo_reals_tokens = demo_reals_tokens[:, :, :50]
225
+
226
+ try:
227
+ print("Getting conditioning")
228
+
229
+ for cfg_scale in self.demo_cfg_scales:
230
+
231
+ model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model
232
+
233
+ print(f"Generating demo for cfg scale {cfg_scale}")
234
+ fakes = model.generate_audio(
235
+ batch_size=self.num_demos,
236
+ max_gen_len=demo_length_tokens,
237
+ conditioning=self.demo_conditioning,
238
+ #init_data = demo_reals_tokens,
239
+ cfg_scale=cfg_scale,
240
+ temp=1.0,
241
+ top_p=0.95
242
+ )
243
+
244
+ # Put the demos together
245
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
246
+
247
+ log_dict = {}
248
+
249
+ filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
250
+ fakes = fakes / fakes.abs().max()
251
+ fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
252
+ torchaudio.save(filename, fakes, self.sample_rate)
253
+
254
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
255
+ sample_rate=self.sample_rate,
256
+ caption=f'Reconstructed')
257
+
258
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
259
+
260
+ trainer.logger.experiment.log(log_dict)
261
+
262
+ except Exception as e:
263
+ raise e
264
+ finally:
265
+ gc.collect()
266
+ torch.cuda.empty_cache()
267
+ module.train()
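As a quick illustration of the per-codebook cross entropy computed in _compute_cross_entropy above, the following self-contained sketch reproduces the masking and averaging on random tensors; the shapes follow the docstring ([B, K, T, card] logits, [B, K, T] targets and mask) and the numbers are arbitrary.

import torch
import torch.nn.functional as F

B, K, T, card = 2, 4, 10, 1024               # batch, codebooks, timesteps, vocabulary size
logits = torch.randn(B, K, T, card)          # stand-in for lm_output.logits
targets = torch.randint(0, card, (B, K, T))  # stand-in for the tokenized codes
mask = torch.rand(B, K, T) > 0.1             # True where a timestep is valid

ce = torch.zeros([])
for k in range(K):
    logits_k = logits[:, k].reshape(-1, card)   # [B*T, card]
    targets_k = targets[:, k].reshape(-1)       # [B*T]
    mask_k = mask[:, k].reshape(-1)             # [B*T]
    ce = ce + F.cross_entropy(logits_k[mask_k], targets_k[mask_k])
ce = ce / K                                     # average over codebooks
print(float(ce), float(torch.exp(ce)))          # cross entropy and perplexity, as logged above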
stable_audio_tools/training/losses/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .losses import *
stable_audio_tools/training/losses/auraloss.py ADDED
@@ -0,0 +1,607 @@
1
+ # Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0
2
+ # You can find the license at LICENSES/LICENSE_AURALOSS.txt
3
+
4
+ import torch
5
+ import numpy as np
6
+ from typing import List, Any
7
+ import scipy.signal
8
+
9
+ def apply_reduction(losses, reduction="none"):
10
+ """Apply reduction to collection of losses."""
11
+ if reduction == "mean":
12
+ losses = losses.mean()
13
+ elif reduction == "sum":
14
+ losses = losses.sum()
15
+ return losses
16
+
17
+ def get_window(win_type: str, win_length: int):
18
+ """Return a window function.
19
+
20
+ Args:
21
+ win_type (str): Window type. Can either be one of the window functions provided in PyTorch
22
+ ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
23
+ or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
24
+ win_length (int): Window length
25
+
26
+ Returns:
27
+ win: The window as a 1D torch tensor
28
+ """
29
+
30
+ try:
31
+ win = getattr(torch, win_type)(win_length)
32
+ except:
33
+ win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
34
+
35
+ return win
36
+
37
+ class SumAndDifference(torch.nn.Module):
38
+ """Sum and difference signal extraction module."""
39
+
40
+ def __init__(self):
41
+ """Initialize sum and difference extraction module."""
42
+ super(SumAndDifference, self).__init__()
43
+
44
+ def forward(self, x):
45
+ """Calculate forward propagation.
46
+
47
+ Args:
48
+ x (Tensor): Predicted signal (B, #channels, #samples).
49
+ Returns:
50
+ Tensor: Sum signal.
51
+ Tensor: Difference signal.
52
+ """
53
+ if not (x.size(1) == 2): # inputs must be stereo
54
+ raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")
55
+
56
+ sum_sig = self.sum(x).unsqueeze(1)
57
+ diff_sig = self.diff(x).unsqueeze(1)
58
+
59
+ return sum_sig, diff_sig
60
+
61
+ @staticmethod
62
+ def sum(x):
63
+ return x[:, 0, :] + x[:, 1, :]
64
+
65
+ @staticmethod
66
+ def diff(x):
67
+ return x[:, 0, :] - x[:, 1, :]
68
+
69
+
70
+ class FIRFilter(torch.nn.Module):
71
+ """FIR pre-emphasis filtering module.
72
+
73
+ Args:
74
+ filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
75
+ coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
76
+ ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
77
+ plot (bool): Plot the magnitude response of the filter. Default: False
78
+
79
+ Based upon the perceptual loss pre-emphasis filters proposed by
80
+ [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
81
+
82
+ A-weighting filter - "aw"
83
+ First-order highpass - "hp"
84
+ Folded differentiator - "fd"
85
+
86
+ Note that the default coefficient value of 0.85 is optimized for
87
+ a sampling rate of 44.1 kHz; consider adjusting this value for different sampling rates.
88
+ """
89
+
90
+ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
91
+ """Initialize FIR pre-emphasis filtering module."""
92
+ super(FIRFilter, self).__init__()
93
+ self.filter_type = filter_type
94
+ self.coef = coef
95
+ self.fs = fs
96
+ self.ntaps = ntaps
97
+ self.plot = plot
98
+
99
+ import scipy.signal
100
+
101
+ if ntaps % 2 == 0:
102
+ raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
103
+
104
+ if filter_type == "hp":
105
+ self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
106
+ self.fir.weight.requires_grad = False
107
+ self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
108
+ elif filter_type == "fd":
109
+ self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
110
+ self.fir.weight.requires_grad = False
111
+ self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
112
+ elif filter_type == "aw":
113
+ # Definition of analog A-weighting filter according to IEC/CD 1672.
114
+ f1 = 20.598997
115
+ f2 = 107.65265
116
+ f3 = 737.86223
117
+ f4 = 12194.217
118
+ A1000 = 1.9997
119
+
120
+ NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
121
+ DENs = np.polymul(
122
+ [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
123
+ [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
124
+ )
125
+ DENs = np.polymul(
126
+ np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
127
+ )
128
+
129
+ # convert analog filter to digital filter
130
+ b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
131
+
132
+ # compute the digital filter frequency response
133
+ w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
134
+
135
+ # then we fit to 101 tap FIR filter with least squares
136
+ taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
137
+
138
+ # now implement this digital FIR filter as a Conv1d layer
139
+ self.fir = torch.nn.Conv1d(
140
+ 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
141
+ )
142
+ self.fir.weight.requires_grad = False
143
+ self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
144
+
145
+ if plot:
146
+ from .plotting import compare_filters
147
+ compare_filters(b, a, taps, fs=fs)
148
+
149
+ def forward(self, input, target):
150
+ """Calculate forward propagation.
151
+ Args:
152
+ input (Tensor): Predicted signal (B, #channels, #samples).
153
+ target (Tensor): Groundtruth signal (B, #channels, #samples).
154
+ Returns:
155
+ Tensor: Filtered signal.
156
+ """
157
+ input = torch.nn.functional.conv1d(
158
+ input, self.fir.weight.data, padding=self.ntaps // 2
159
+ )
160
+ target = torch.nn.functional.conv1d(
161
+ target, self.fir.weight.data, padding=self.ntaps // 2
162
+ )
163
+ return input, target
164
+
165
+ class SpectralConvergenceLoss(torch.nn.Module):
166
+ """Spectral convergence loss module.
167
+
168
+ See [Arik et al., 2018](https://arxiv.org/abs/1808.06719).
169
+ """
170
+
171
+ def __init__(self):
172
+ super(SpectralConvergenceLoss, self).__init__()
173
+
174
+ def forward(self, x_mag, y_mag):
175
+ return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean()
176
+
177
+ class STFTMagnitudeLoss(torch.nn.Module):
178
+ """STFT magnitude loss module.
179
+
180
+ See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
181
+ and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
182
+
183
+ Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
184
+ compression strength (larger value results in more compression), and `log_eps` can be used
185
+ to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
186
+ output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.
187
+
188
+ Args:
189
+ log (bool, optional): Log-scale the STFT magnitudes,
190
+ or use linear scale. Default: True
191
+ log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
192
+ Default: 0.0
193
+ log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
194
+ Default: 1.0
195
+ distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
196
+ reduction (str, optional): Reduction of the loss elements. Default: "mean"
197
+ """
198
+
199
+ def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
200
+ super(STFTMagnitudeLoss, self).__init__()
201
+
202
+ self.log = log
203
+ self.log_eps = log_eps
204
+ self.log_fac = log_fac
205
+
206
+ if distance == "L1":
207
+ self.distance = torch.nn.L1Loss(reduction=reduction)
208
+ elif distance == "L2":
209
+ self.distance = torch.nn.MSELoss(reduction=reduction)
210
+ else:
211
+ raise ValueError(f"Invalid distance: '{distance}'.")
212
+
213
+ def forward(self, x_mag, y_mag):
214
+ if self.log:
215
+ x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
216
+ y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
217
+ return self.distance(x_mag, y_mag)
218
+
219
+
220
+ class STFTLoss(torch.nn.Module):
221
+ """STFT loss module.
222
+
223
+ See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472).
224
+
225
+ Args:
226
+ fft_size (int, optional): FFT size in samples. Default: 1024
227
+ hop_size (int, optional): Hop size of the FFT in samples. Default: 256
228
+ win_length (int, optional): Length of the FFT analysis window. Default: 1024
229
+ window (str, optional): Window to apply before FFT, can either be one of the window functions provided in PyTorch
230
+ ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
231
+ or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
232
+ Default: 'hann_window'
233
+ w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
234
+ w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
235
+ w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
236
+ w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
237
+ sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
238
+ scale (str, optional): Optional frequency scaling method, options include:
239
+ ['mel', 'chroma']
240
+ Default: None
241
+ n_bins (int, optional): Number of scaling frequency bins. Default: None.
242
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
243
+ scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
244
+ eps (float, optional): Small epsilon value for stability. Default: 1e-8
245
+ output (str, optional): Format of the loss returned.
246
+ 'loss' : Return only the raw, aggregate loss term.
247
+ 'full' : Return the raw loss, plus intermediate loss terms.
248
+ Default: 'loss'
249
+ reduction (str, optional): Specifies the reduction to apply to the output:
250
+ 'none': no reduction will be applied,
251
+ 'mean': the sum of the output will be divided by the number of elements in the output,
252
+ 'sum': the output will be summed.
253
+ Default: 'mean'
254
+ mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms.
255
+ device (str, optional): Place the filterbanks on specified device. Default: None
256
+
257
+ Returns:
258
+ loss:
259
+ Aggregate loss term. Only returned if output='loss' (the default).
260
+ loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss:
261
+ Aggregate and intermediate loss terms. Only returned if output='full'.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ fft_size: int = 1024,
267
+ hop_size: int = 256,
268
+ win_length: int = 1024,
269
+ window: str = "hann_window",
270
+ w_sc: float = 1.0,
271
+ w_log_mag: float = 1.0,
272
+ w_lin_mag: float = 0.0,
273
+ w_phs: float = 0.0,
274
+ sample_rate: float = None,
275
+ scale: str = None,
276
+ n_bins: int = None,
277
+ perceptual_weighting: bool = False,
278
+ scale_invariance: bool = False,
279
+ eps: float = 1e-8,
280
+ output: str = "loss",
281
+ reduction: str = "mean",
282
+ mag_distance: str = "L1",
283
+ device: Any = None,
284
+ **kwargs
285
+ ):
286
+ super().__init__()
287
+ self.fft_size = fft_size
288
+ self.hop_size = hop_size
289
+ self.win_length = win_length
290
+ self.window = get_window(window, win_length)
291
+ self.w_sc = w_sc
292
+ self.w_log_mag = w_log_mag
293
+ self.w_lin_mag = w_lin_mag
294
+ self.w_phs = w_phs
295
+ self.sample_rate = sample_rate
296
+ self.scale = scale
297
+ self.n_bins = n_bins
298
+ self.perceptual_weighting = perceptual_weighting
299
+ self.scale_invariance = scale_invariance
300
+ self.eps = eps
301
+ self.output = output
302
+ self.reduction = reduction
303
+ self.mag_distance = mag_distance
304
+ self.device = device
305
+
306
+ self.phs_used = bool(self.w_phs)
307
+
308
+ self.spectralconv = SpectralConvergenceLoss()
309
+ self.logstft = STFTMagnitudeLoss(
310
+ log=True,
311
+ reduction=reduction,
312
+ distance=mag_distance,
313
+ **kwargs
314
+ )
315
+ self.linstft = STFTMagnitudeLoss(
316
+ log=False,
317
+ reduction=reduction,
318
+ distance=mag_distance,
319
+ **kwargs
320
+ )
321
+
322
+ # setup mel filterbank
323
+ if scale is not None:
324
+ try:
325
+ import librosa.filters
326
+ except Exception as e:
327
+ print(e)
328
+ print("Try `pip install auraloss[all]`.")
329
+
330
+ if self.scale == "mel":
331
+ assert sample_rate != None # Must set sample rate to use mel scale
332
+ assert n_bins <= fft_size # Must be more FFT bins than Mel bins
333
+ fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
334
+ fb = torch.tensor(fb).unsqueeze(0)
335
+
336
+ elif self.scale == "chroma":
337
+ assert sample_rate != None # Must set sample rate to use chroma scale
338
+ assert n_bins <= fft_size # Must be more FFT bins than chroma bins
339
+ fb = librosa.filters.chroma(
340
+ sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
341
+ )
342
+
343
+ else:
344
+ raise ValueError(
345
+ f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
346
+ )
347
+
348
+ self.register_buffer("fb", fb)
349
+
350
+ if scale is not None and device is not None:
351
+ self.fb = self.fb.to(self.device) # move filterbank to device
352
+
353
+ if self.perceptual_weighting:
354
+ if sample_rate is None:
355
+ raise ValueError(
356
+ f"`sample_rate` must be supplied when `perceptual_weighting = True`."
357
+ )
358
+ self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
359
+
360
+ def stft(self, x):
361
+ """Perform STFT.
362
+ Args:
363
+ x (Tensor): Input signal tensor (B, T).
364
+
365
+ Returns:
366
+ Tensor: x_mag, x_phs
367
+ Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
368
+ """
369
+ x_stft = torch.stft(
370
+ x,
371
+ self.fft_size,
372
+ self.hop_size,
373
+ self.win_length,
374
+ self.window,
375
+ return_complex=True,
376
+ )
377
+ x_mag = torch.sqrt(
378
+ torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
379
+ )
380
+
381
+ # torch.angle is expensive, so it is only evaluated if the values are used in the loss
382
+ if self.phs_used:
383
+ x_phs = torch.angle(x_stft)
384
+ else:
385
+ x_phs = None
386
+
387
+ return x_mag, x_phs
388
+
389
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
390
+ bs, chs, seq_len = input.size()
391
+
392
+ if self.perceptual_weighting: # apply optional A-weighting via FIR filter
393
+ # since FIRFilter only support mono audio we will move channels to batch dim
394
+ input = input.view(bs * chs, 1, -1)
395
+ target = target.view(bs * chs, 1, -1)
396
+
397
+ # now apply the filter to both
398
+ self.prefilter.to(input.device)
399
+ input, target = self.prefilter(input, target)
400
+
401
+ # now move the channels back
402
+ input = input.view(bs, chs, -1)
403
+ target = target.view(bs, chs, -1)
404
+
405
+ # compute the magnitude and phase spectra of input and target
406
+ self.window = self.window.to(input.device)
407
+
408
+ x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
409
+ y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))
410
+
411
+ # apply relevant transforms
412
+ if self.scale is not None:
413
+ self.fb = self.fb.to(input.device)
414
+ x_mag = torch.matmul(self.fb, x_mag)
415
+ y_mag = torch.matmul(self.fb, y_mag)
416
+
417
+ # normalize scales
418
+ if self.scale_invariance:
419
+ alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
420
+ y_mag = y_mag * alpha.unsqueeze(-1)
421
+
422
+ # compute loss terms
423
+ sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
424
+ log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
425
+ lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
426
+ phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0
427
+
428
+ # combine loss terms
429
+ loss = (
430
+ (self.w_sc * sc_mag_loss)
431
+ + (self.w_log_mag * log_mag_loss)
432
+ + (self.w_lin_mag * lin_mag_loss)
433
+ + (self.w_phs * phs_loss)
434
+ )
435
+
436
+ loss = apply_reduction(loss, reduction=self.reduction)
437
+
438
+ if self.output == "loss":
439
+ return loss
440
+ elif self.output == "full":
441
+ return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
442
+
443
+ class MultiResolutionSTFTLoss(torch.nn.Module):
444
+ """Multi resolution STFT loss module.
445
+
446
+ See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480)
447
+
448
+ Args:
449
+ fft_sizes (list): List of FFT sizes.
450
+ hop_sizes (list): List of hop sizes.
451
+ win_lengths (list): List of window lengths.
452
+ window (str, optional): Window to apply before FFT, options include:
453
+ ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
454
+ Default: 'hann_window'
455
+ w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
456
+ w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
457
+ w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
458
+ w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
459
+ sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
460
+ scale (str, optional): Optional frequency scaling method, options include:
461
+ ['mel', 'chroma']
462
+ Default: None
463
+ n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None.
464
+ scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
465
+ """
466
+
467
+ def __init__(
468
+ self,
469
+ fft_sizes: List[int] = [1024, 2048, 512],
470
+ hop_sizes: List[int] = [120, 240, 50],
471
+ win_lengths: List[int] = [600, 1200, 240],
472
+ window: str = "hann_window",
473
+ w_sc: float = 1.0,
474
+ w_log_mag: float = 1.0,
475
+ w_lin_mag: float = 0.0,
476
+ w_phs: float = 0.0,
477
+ sample_rate: float = None,
478
+ scale: str = None,
479
+ n_bins: int = None,
480
+ perceptual_weighting: bool = False,
481
+ scale_invariance: bool = False,
482
+ **kwargs,
483
+ ):
484
+ super().__init__()
485
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
486
+ self.fft_sizes = fft_sizes
487
+ self.hop_sizes = hop_sizes
488
+ self.win_lengths = win_lengths
489
+
490
+ self.stft_losses = torch.nn.ModuleList()
491
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
492
+ self.stft_losses += [
493
+ STFTLoss(
494
+ fs,
495
+ ss,
496
+ wl,
497
+ window,
498
+ w_sc,
499
+ w_log_mag,
500
+ w_lin_mag,
501
+ w_phs,
502
+ sample_rate,
503
+ scale,
504
+ n_bins,
505
+ perceptual_weighting,
506
+ scale_invariance,
507
+ **kwargs,
508
+ )
509
+ ]
510
+
511
+ def forward(self, x, y):
512
+ mrstft_loss = 0.0
513
+ sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], []
514
+
515
+ for f in self.stft_losses:
516
+ if f.output == "full": # extract just first term
517
+ tmp_loss = f(x, y)
518
+ mrstft_loss += tmp_loss[0]
519
+ sc_mag_loss.append(tmp_loss[1])
520
+ log_mag_loss.append(tmp_loss[2])
521
+ lin_mag_loss.append(tmp_loss[3])
522
+ phs_loss.append(tmp_loss[4])
523
+ else:
524
+ mrstft_loss += f(x, y)
525
+
526
+ mrstft_loss /= len(self.stft_losses)
527
+
528
+ if f.output == "loss":
529
+ return mrstft_loss
530
+ else:
531
+ return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
532
+
533
+
534
+ class SumAndDifferenceSTFTLoss(torch.nn.Module):
535
+ """Sum and difference stereo STFT loss module.
536
+
537
+ See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
538
+
539
+ Args:
540
+ fft_sizes (List[int]): List of FFT sizes.
541
+ hop_sizes (List[int]): List of hop sizes.
542
+ win_lengths (List[int]): List of window lengths.
543
+ window (str, optional): Window function type.
544
+ w_sum (float, optional): Weight of the sum loss component. Default: 1.0
545
+ w_diff (float, optional): Weight of the difference loss component. Default: 1.0
546
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
547
+ mel_stft (bool, optional): Use multi-resolution mel spectrograms. Default: False
548
+ n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
549
+ sample_rate (float, optional): Audio sample rate. Default: None
550
+ output (str, optional): Format of the loss returned.
551
+ 'loss' : Return only the raw, aggregate loss term.
552
+ 'full' : Return the raw loss, plus intermediate loss terms.
553
+ Default: 'loss'
554
+ """
555
+
556
+ def __init__(
557
+ self,
558
+ fft_sizes: List[int],
559
+ hop_sizes: List[int],
560
+ win_lengths: List[int],
561
+ window: str = "hann_window",
562
+ w_sum: float = 1.0,
563
+ w_diff: float = 1.0,
564
+ output: str = "loss",
565
+ **kwargs,
566
+ ):
567
+ super().__init__()
568
+ self.sd = SumAndDifference()
569
+ self.w_sum = w_sum
570
+ self.w_diff = w_diff
571
+ self.output = output
572
+ self.mrstft = MultiResolutionSTFTLoss(
573
+ fft_sizes,
574
+ hop_sizes,
575
+ win_lengths,
576
+ window,
577
+ **kwargs,
578
+ )
579
+
580
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
581
+ """This loss function assumes batched input of stereo audio in the time domain.
582
+
583
+ Args:
584
+ input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
585
+ target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
586
+
587
+ Returns:
588
+ loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
589
+ loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
590
+ Aggregate and intermediate loss terms. Only returned if output='full'.
591
+ """
592
+ assert input.shape == target.shape # must have same shape
593
+ bs, chs, seq_len = input.size()
594
+
595
+ # compute sum and difference signals for both
596
+ input_sum, input_diff = self.sd(input)
597
+ target_sum, target_diff = self.sd(target)
598
+
599
+ # compute error in STFT domain
600
+ sum_loss = self.mrstft(input_sum, target_sum)
601
+ diff_loss = self.mrstft(input_diff, target_diff)
602
+ loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2
603
+
604
+ if self.output == "loss":
605
+ return loss
606
+ elif self.output == "full":
607
+ return loss, sum_loss, diff_loss
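A hedged usage sketch for the STFT losses above, run on random stereo audio. The import path assumes the package layout of this commit, and the resolution settings are illustrative rather than recommended values.

import torch

from stable_audio_tools.training.losses.auraloss import (
    MultiResolutionSTFTLoss,
    SumAndDifferenceSTFTLoss,
)

pred = torch.randn(4, 2, 44100)    # (batch, channels, samples)
target = torch.randn(4, 2, 44100)

# Multi-resolution loss over three STFT configurations
mrstft = MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 512],
    hop_sizes=[256, 512, 128],
    win_lengths=[1024, 2048, 512],
)
print(float(mrstft(pred, target)))

# Mid/side (sum and difference) variant for stereo signals
sd_stft = SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 512],
    hop_sizes=[256, 512, 128],
    win_lengths=[1024, 2048, 512],
)
print(float(sd_stft(pred, target)))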
stable_audio_tools/training/losses/losses.py ADDED
@@ -0,0 +1,101 @@
1
+ import typing as tp
2
+
3
+ from torch.nn import functional as F
4
+ from torch import nn
5
+ import torch
6
+ class LossModule(nn.Module):
7
+ def __init__(self, name: str, weight: float = 1.0):
8
+ super().__init__()
9
+
10
+ self.name = name
11
+ self.weight = weight
12
+
13
+ def forward(self, info, *args, **kwargs):
14
+ raise NotImplementedError
15
+
16
+ class ValueLoss(LossModule):
17
+ def __init__(self, key: str, name, weight: float = 1.0):
18
+ super().__init__(name=name, weight=weight)
19
+
20
+ self.key = key
21
+
22
+ def forward(self, info):
23
+ return self.weight * info[self.key]
24
+
25
+ class L1Loss(LossModule):
26
+ def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'):
27
+ super().__init__(name=name, weight=weight)
28
+
29
+ self.key_a = key_a
30
+ self.key_b = key_b
31
+
32
+ self.mask_key = mask_key
33
+
34
+ def forward(self, info):
35
+ l1_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none')
36
+
37
+ if self.mask_key is not None and self.mask_key in info:
38
+ l1_loss = l1_loss[info[self.mask_key]]
39
+
40
+ l1_loss = l1_loss.mean()
41
+
42
+ return self.weight * l1_loss
43
+
44
+ class MSELoss(LossModule):
45
+ def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'):
46
+ super().__init__(name=name, weight=weight)
47
+
48
+ self.key_a = key_a
49
+ self.key_b = key_b
50
+
51
+ self.mask_key = mask_key
52
+
53
+ def forward(self, info):
54
+ mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none')
55
+
56
+ if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None:
57
+ mask = info[self.mask_key]
58
+
59
+ if mask.ndim == 2 and mse_loss.ndim == 3:
60
+ mask = mask.unsqueeze(1)
61
+
62
+ if mask.shape[1] != mse_loss.shape[1]:
63
+ mask = mask.repeat(1, mse_loss.shape[1], 1)
64
+
65
+ mse_loss = mse_loss[mask]
66
+
67
+ mse_loss = mse_loss.mean()
68
+
69
+ return self.weight * mse_loss
70
+
71
+ class AuralossLoss(LossModule):
72
+ def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1):
73
+ super().__init__(name, weight)
74
+
75
+ self.auraloss_module = auraloss_module
76
+
77
+ self.input_key = input_key
78
+ self.target_key = target_key
79
+
80
+ def forward(self, info):
81
+ loss = self.auraloss_module(info[self.input_key], info[self.target_key])
82
+
83
+ return self.weight * loss
84
+
85
+ class MultiLoss(nn.Module):
86
+ def __init__(self, losses: tp.List[LossModule]):
87
+ super().__init__()
88
+
89
+ self.losses = nn.ModuleList(losses)
90
+
91
+ def forward(self, info):
92
+ total_loss = 0
93
+
94
+ losses = {}
95
+
96
+ for loss_module in self.losses:
97
+ module_loss = loss_module(info)
98
+ total_loss += module_loss
99
+ losses[loss_module.name] = module_loss
100
+
101
+ return total_loss, losses
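A short sketch of how MultiLoss combines the modules above over an info dictionary; the keys "decoded", "reals" and "kl" are illustrative stand-ins, not names mandated by any particular training wrapper.

import torch

from stable_audio_tools.training.losses.losses import L1Loss, MultiLoss, ValueLoss

info = {
    "decoded": torch.randn(2, 2, 1024),  # e.g. reconstructed audio
    "reals": torch.randn(2, 2, 1024),    # e.g. ground-truth audio
    "kl": torch.tensor(0.25),            # e.g. a KL term computed elsewhere
}

multi_loss = MultiLoss([
    L1Loss("decoded", "reals", weight=1.0, name="l1_time_loss"),
    ValueLoss("kl", name="kl_loss", weight=1e-4),
])

total, per_module = multi_loss(info)
print(float(total), {name: float(value) for name, value in per_module.items()})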
stable_audio_tools/training/utils.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import os
3
+
4
+ def get_rank():
5
+ """Get rank of current process."""
6
+
7
+ print(os.environ.keys())
8
+
9
+ if "SLURM_PROCID" in os.environ:
10
+ return int(os.environ["SLURM_PROCID"])
11
+
12
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
13
+ return 0
14
+
15
+ return torch.distributed.get_rank()
16
+
17
+ class InverseLR(torch.optim.lr_scheduler._LRScheduler):
18
+ """Implements an inverse decay learning rate schedule with an optional exponential
19
+ warmup. When last_epoch=-1, sets initial lr as lr.
20
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
21
+ (1 / 2)**power of its original value.
22
+ Args:
23
+ optimizer (Optimizer): Wrapped optimizer.
24
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
25
+ power (float): Exponential factor of learning rate decay. Default: 1.
26
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
27
+ Default: 0.
28
+ final_lr (float): The final learning rate. Default: 0.
29
+ last_epoch (int): The index of last epoch. Default: -1.
30
+ verbose (bool): If ``True``, prints a message to stdout for
31
+ each update. Default: ``False``.
32
+ """
33
+
34
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
35
+ last_epoch=-1, verbose=False):
36
+ self.inv_gamma = inv_gamma
37
+ self.power = power
38
+ if not 0. <= warmup < 1:
39
+ raise ValueError('Invalid value for warmup')
40
+ self.warmup = warmup
41
+ self.final_lr = final_lr
42
+ super().__init__(optimizer, last_epoch, verbose)
43
+
44
+ def get_lr(self):
45
+ if not self._get_lr_called_within_step:
46
+ import warnings
47
+ warnings.warn("To get the last learning rate computed by the scheduler, "
48
+ "please use `get_last_lr()`.")
49
+
50
+ return self._get_closed_form_lr()
51
+
52
+ def _get_closed_form_lr(self):
53
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
54
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
55
+ return [warmup * max(self.final_lr, base_lr * lr_mult)
56
+ for base_lr in self.base_lrs]
57
+
58
+ def copy_state_dict(model, state_dict):
59
+ """Load state_dict to model, but only for keys that match exactly.
60
+
61
+ Args:
62
+ model (nn.Module): model to load state_dict.
63
+ state_dict (OrderedDict): state_dict to load.
64
+ """
65
+ model_state_dict = model.state_dict()
66
+ for key in state_dict:
67
+ if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
68
+ if isinstance(state_dict[key], torch.nn.Parameter):
69
+ # backwards compatibility for serialized parameters
70
+ state_dict[key] = state_dict[key].data
71
+ model_state_dict[key] = state_dict[key]
72
+
73
+ model.load_state_dict(model_state_dict, strict=False)
74
+
75
+ def create_optimizer_from_config(optimizer_config, parameters):
76
+ """Create optimizer from config.
77
+
78
+ Args:
79
+ parameters (iterable): parameters to optimize.
80
+ optimizer_config (dict): optimizer config.
81
+
82
+ Returns:
83
+ torch.optim.Optimizer: optimizer.
84
+ """
85
+
86
+ optimizer_type = optimizer_config["type"]
87
+
88
+ if optimizer_type == "FusedAdam":
89
+ from deepspeed.ops.adam import FusedAdam
90
+ optimizer = FusedAdam(parameters, **optimizer_config["config"])
91
+ else:
92
+ optimizer_fn = getattr(torch.optim, optimizer_type)
93
+ optimizer = optimizer_fn(parameters, **optimizer_config["config"])
94
+ return optimizer
95
+
96
+ def create_scheduler_from_config(scheduler_config, optimizer):
97
+ """Create scheduler from config.
98
+
99
+ Args:
100
+ scheduler_config (dict): scheduler config.
101
+ optimizer (torch.optim.Optimizer): optimizer.
102
+
103
+ Returns:
104
+ torch.optim.lr_scheduler._LRScheduler: scheduler.
105
+ """
106
+ if scheduler_config["type"] == "InverseLR":
107
+ scheduler_fn = InverseLR
108
+ else:
109
+ scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
110
+ scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
111
+ return scheduler
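Finally, a hedged sketch of the config-driven optimizer/scheduler creation above, using a throwaway module; the config values are illustrative and not taken from any shipped training config. For InverseLR, the closed-form rate at step t is (1 - warmup**(t + 1)) * max(final_lr, base_lr * (1 + t / inv_gamma) ** -power), matching _get_closed_form_lr.

import torch

from stable_audio_tools.training.utils import (
    create_optimizer_from_config,
    create_scheduler_from_config,
)

model = torch.nn.Linear(16, 16)  # throwaway module, just to have parameters

optimizer = create_optimizer_from_config(
    {"type": "AdamW", "config": {"lr": 1e-4, "betas": (0.9, 0.95), "weight_decay": 0.1}},
    model.parameters(),
)
scheduler = create_scheduler_from_config(
    {"type": "InverseLR", "config": {"inv_gamma": 1_000_000, "power": 0.5, "warmup": 0.99}},
    optimizer,
)

for _ in range(3):
    optimizer.step()   # a real training step would compute a loss and backprop first
    scheduler.step()
print(scheduler.get_last_lr())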