Spaces: Running on Zero

Commit · 4aa0f34
Parent(s): 9577cb2

update to faster inference
Browse files

- app.py +17 -31
- dia/audio.py +27 -104
- dia/config.py +17 -26
- dia/layers.py +106 -337
- dia/model.py +314 -257
- dia/state.py +234 -0
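The gist of the change: the CLI/device plumbing is dropped from app.py, the model is loaded with bfloat16 compute, and generation bookkeeping moves into the new dia/state.py. As a rough sketch of the updated call pattern implied by the app.py hunk below — the positional text argument, the example values, and the [S1]/[S2] prompt format are assumptions; only the keyword names are taken from this commit:

from dia.model import Dia

# bfloat16 compute path used by the updated app.py
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")

# Keyword names mirror the updated run_inference() call; the values are placeholders.
audio = model.generate(
    "[S1] Hello there. [S2] Hi!",  # dialogue-style prompt (format assumed)
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=30,           # replaces the removed use_cfg_filter=True flag
    use_torch_compile=False,
    audio_prompt=None,             # optional path to a voice prompt file
)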
app.py
CHANGED
@@ -1,9 +1,7 @@
-import argparse
 import tempfile
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import spaces

 import gradio as gr
 import numpy as np
@@ -12,40 +10,17 @@ import torch

 from dia.model import Dia

-# --- Global Setup ---
-parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
-parser.add_argument(
-    "--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')"
-)
-parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
-
-args = parser.parse_args()
-
-
-# Determine device
-if args.device:
-    device = torch.device(args.device)
-elif torch.cuda.is_available():
-    device = torch.device("cuda")
-# Simplified MPS check for broader compatibility
-elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    # Basic check is usually sufficient, detailed check can be problematic
-    device = torch.device("mps")
-else:
-    device = torch.device("cpu")
-
-print(f"Using device: {device}")

 # Load Nari model and config
 print("Loading Nari model...")
 try:
     # Use the function from inference.py
-    model = Dia.from_pretrained("nari-labs/Dia-1.6B")
+    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
 except Exception as e:
     print(f"Error loading Nari model: {e}")
     raise

-
+
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
@@ -60,7 +35,7 @@ def run_inference(
     Runs Nari inference using the globally loaded model and provided inputs.
     Uses temporary files for text and audio prompt compatibility with inference.generate.
     """
-
+    global model, device  # Access global model, config, device

     if not text_input or text_input.isspace():
         raise gr.Error("Text input cannot be empty.")
@@ -146,10 +121,9 @@ def run_inference(
             cfg_scale=cfg_scale,
             temperature=temperature,
             top_p=top_p,
-            use_cfg_filter=True,
             cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
             use_torch_compile=False,  # Keep False for Gradio stability
-
+            audio_prompt=prompt_path_for_generate,
         )

         end_time = time.time()
@@ -192,6 +166,16 @@ def run_inference(
             f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
         )

+        # Explicitly convert to int16 to prevent Gradio warning
+        if (
+            output_audio[1].dtype == np.float32
+            or output_audio[1].dtype == np.float64
+        ):
+            audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
+            audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
+            output_audio = (output_sr, audio_for_gradio)
+            print("Converted audio to int16 for Gradio output.")
+
         else:
             print("\nGeneration finished, but no valid tokens were produced.")
             # Return default silence
@@ -383,8 +367,10 @@ with gr.Blocks(css=css) as demo:
     else:
         gr.Markdown("_(No examples configured or example prompt file missing)_")

-
 # --- Launch the App ---
 if __name__ == "__main__":
     print("Launching Gradio interface...")
+
+    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
+    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
     demo.launch()
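The int16 conversion added to run_inference above is the usual float-to-PCM mapping. A standalone sketch of the same arithmetic (the sample values are made up):

import numpy as np

# Float audio in [-1.0, 1.0] is clipped and scaled to 16-bit PCM,
# which Gradio's Audio output accepts without a dtype warning.
float_audio = np.array([0.0, 0.5, -1.2, 1.0], dtype=np.float32)  # made-up samples
int16_audio = (np.clip(float_audio, -1.0, 1.0) * 32767).astype(np.int16)
print(int16_audio)  # -> [0, 16383, -32767, 32767]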
dia/audio.py
CHANGED
@@ -2,10 +2,10 @@ import typing as tp

 import torch

-from .config import DataConfig

-
-
+def build_delay_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
     Negative t_idx => BOS; t_idx >= T => PAD.
@@ -69,7 +69,9 @@ def apply_audio_delay(

     # Equivalent of tf.gather_nd using advanced indexing
     # Ensure indices are long type if not already (build_delay_indices should handle this)
-    gathered_flat = audio_BxTxC[
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

     # Create masks on the correct device
@@ -82,65 +84,16 @@ def apply_audio_delay(

     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
     # All tensors should now be on the same device
-    result_BxTxC = torch.where(
-
-    return result_BxTxC
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def audio_to_codebook(
-    model,
-    input_values,
-    data_config: DataConfig,
-    padding_mask=None,
-    sample_rate=44100,
-):
-    """
-    Encodes the input audio waveform into discrete codes.
-
-    Args:
-        model: The model to use for encoding.
-        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Float values of the input audio waveform.
-        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Padding mask used to pad the `input_values`.
-        sample_rate (`int`, *optional*) :
-            Signal sampling_rate
-
-    Returns:
-        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
-        factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
-        `codebook` of shape `[batch_size, num_codebooks, frames]`.
-        Scale is not used here.
-
-    """
-    audio_data = model.preprocess(input_values, sample_rate)
-
-    if padding_mask is None:
-        padding_mask = torch.ones_like(input_values).bool()
-
-    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
-    seq_length = encoded_frame.shape[2]
-
-    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
-        B=1,
-        T=seq_length,
-        C=data_config.channels,
-        delay_pattern=data_config.delay_pattern,
-    )
-
-        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
-        pad_value=data_config.audio_pad_value,
-        bos_value=data_config.audio_bos_value,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-    )
-
-    return encoded_frame
-
-
-def build_revert_indices(
+    result_BxTxC = torch.where(
+        mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
+    )
+
+    return result_BxTxC
+
+
+def build_revert_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute indices for the revert operation using PyTorch.

@@ -162,8 +115,12 @@ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) ->
         t_idx_BT1 + delay_arr.view(1, 1, C),
         torch.tensor(T - 1, device=device),
     )
-    b_idx_BxTxC = torch.broadcast_to(
-
+    b_idx_BxTxC = torch.broadcast_to(
+        torch.arange(B, device=device).view(B, 1, 1), [B, T, C]
+    )
+    c_idx_BxTxC = torch.broadcast_to(
+        torch.arange(C, device=device).view(1, 1, C), [B, T, C]
+    )

     indices_BTCx3 = torch.stack(
         [
@@ -205,15 +162,21 @@ def revert_audio_delay(
     indices_BTCx3 = indices_BTCx3.to(device)

     # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
-    gathered_flat = audio_BxTxC[
-
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
+    gathered_BxTxC = gathered_flat.view(
+        audio_BxTxC.size()
+    )  # Use .size() for robust reshaping

     # Create pad_tensor on the correct device
     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
     # Create T tensor on the correct device for comparison
     T_tensor = torch.tensor(T, device=device)

-    result_BxTxC = torch.where(
+    result_BxTxC = torch.where(
+        t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC
+    )  # Changed np.where to torch.where

     return result_BxTxC

@@ -238,43 +201,3 @@ def decode(
     except Exception as e:
         print(f"Error in decode method: {str(e)}")
         raise
-
-
-def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
-    """Process a single codebook file to generate audio"""
-    # Remove BOS token
-    generated_codes = generated_codes[:, 1:]
-
-    if generated_codes.shape[1] > T:
-        generated_codes = generated_codes[:, :T]
-
-    seq_length = generated_codes.shape[1]
-
-    # Build revert indices
-    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
-
-    # Transpose and add batch dimension
-    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
-    reverted_codebook = revert_audio_delay(
-        audio_BxTxC=audio_BxTxC,
-        pad_value=0,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-        T=seq_length,
-    )
-    reverted_codebook = reverted_codebook[:, :-30, :]
-
-    codebook = reverted_codebook.transpose(1, 2)
-
-    min_valid_index = 0
-    max_valid_index = 1023
-    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
-
-    num_invalid = torch.sum(invalid_mask).item()
-    if num_invalid > 0:
-        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
-
-    # Set invalid values to 0 (modify the tensor in-place)
-    codebook[invalid_mask] = 0
-    audio_array = decode(model, codebook)
-
-    return audio_array
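The delay helpers above implement the relation from their docstring, out[t, c] = in[t - delay[c], c], with BOS before the start of the sequence and PAD past its end. A tiny loop-based illustration of that rule (toy values; the real code precomputes flat gather indices instead of looping):

import torch

T, C = 6, 3
delay = [0, 1, 2]          # per-channel delays (toy values)
BOS, PAD = 1026, 1025      # sentinels matching the DataConfig defaults
codes = torch.arange(T * C).reshape(T, C)  # fake codebook frames

delayed = torch.empty_like(codes)
for t in range(T):
    for c in range(C):
        src = t - delay[c]
        # Negative source index -> BOS; beyond the end -> PAD; otherwise shift.
        delayed[t, c] = BOS if src < 0 else (PAD if src >= T else codes[src, c])

print(delayed[:3])  # channel c lags the original frames by delay[c] steps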
dia/config.py
CHANGED
@@ -33,14 +33,20 @@ class DataConfig(BaseModel, frozen=True):
         delay_pattern: List of delay values for each audio channel.
     """

-    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] =
-
+    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
+    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
     channels: int = Field(default=9, gt=0, multiple_of=1)
     text_pad_value: int = Field(default=0)
     audio_eos_value: int = Field(default=1024)
     audio_pad_value: int = Field(default=1025)
     audio_bos_value: int = Field(default=1026)
-    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
+    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
+        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
+    )

     def __hash__(self) -> int:
         """Generate a hash based on all fields of the config."""
@@ -67,8 +73,6 @@ class EncoderConfig(BaseModel, frozen=True):
         n_hidden: Hidden dimension size in the MLP layers.
         n_head: Number of attention heads.
         head_dim: Dimension per attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
     """

     n_layer: int = Field(gt=0)
@@ -76,8 +80,6 @@ class EncoderConfig(BaseModel, frozen=True):
     n_hidden: int = Field(gt=0)
     n_head: int = Field(gt=0)
     head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)


 class DecoderConfig(BaseModel, frozen=True):
@@ -92,8 +94,6 @@ class DecoderConfig(BaseModel, frozen=True):
         gqa_head_dim: Dimension per query head for grouped-query self-attention.
         cross_query_heads: Number of query heads for cross-attention.
         cross_head_dim: Dimension per cross-attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization.
     """

     n_layer: int = Field(gt=0)
@@ -104,8 +104,6 @@ class DecoderConfig(BaseModel, frozen=True):
     gqa_head_dim: int = Field(gt=0)
     cross_query_heads: int = Field(gt=0)
     cross_head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)


 class ModelConfig(BaseModel, frozen=True):
@@ -130,24 +128,16 @@ class ModelConfig(BaseModel, frozen=True):
     dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
     normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
     weight_dtype: str = Field(default="float32", description="Weight precision")
-    rope_min_timescale: int = Field(
-
+    rope_min_timescale: int = Field(
+        default=1, description="Timescale For global Attention"
+    )
+    rope_max_timescale: int = Field(
+        default=10_000, description="Timescale For global Attention"
+    )


 class TrainingConfig(BaseModel, frozen=True):
-
-
-    Note: This configuration currently only includes precision settings.
-    Other training parameters (like batch size, learning rate, optimizer settings)
-    are assumed to be handled externally.
-
-    Attributes:
-        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
-        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
-    """
-
-    dtype: str = Field(default="bfloat16", description="Activation precision")
-    logits_dot_in_fp32: bool = Field(default=False)
+    pass


 class DiaConfig(BaseModel, frozen=True):
@@ -164,6 +154,7 @@ class DiaConfig(BaseModel, frozen=True):

     version: str = Field(default="1.0")
     model: ModelConfig
+    # TODO: remove training. this is just for backwards-compatability
     training: TrainingConfig
     data: DataConfig
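The BeforeValidator on text_length and audio_length rounds any requested length up to the next multiple of 128 before the multiple_of=128 constraint is checked. A quick illustration of that arithmetic (example values chosen arbitrarily):

def round_up_to_128(x: int) -> int:
    # Same expression as the validator: (x + 127) // 128 * 128
    return (x + 127) // 128 * 128

print(round_up_to_128(1000))  # 1024
print(round_up_to_128(1024))  # 1024 (already a multiple, unchanged)
print(round_up_to_128(3072))  # 3072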
dia/layers.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import Any
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,26 +5,13 @@ from torch import Tensor
 from torch.nn import RMSNorm

 from .config import DiaConfig
+from .state import DecoderInferenceState, EncoderInferenceState, KVCache


 def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
     return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


-def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
-    # Allow None for default behavior
-    if dtype_str is None or dtype_str.lower() == "none":
-        return None
-    if dtype_str == "float32":
-        return torch.float32
-    elif dtype_str == "float16":
-        return torch.float16
-    elif dtype_str == "bfloat16":
-        return torch.bfloat16
-    else:
-        raise ValueError(f"Unsupported dtype string: {dtype_str}")
-
-
 class DenseGeneral(nn.Module):
     """
     PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
@@ -50,7 +35,6 @@ class DenseGeneral(nn.Module):
         in_shapes: tuple[int, ...],
         out_features: tuple[int, ...],
         axis: tuple[int, ...] = (-1,),
-        dtype: torch.dtype | None = None,
         weight_dtype: torch.dtype | None = None,
         device: torch.device | None = None,
     ):
@@ -58,7 +42,6 @@ class DenseGeneral(nn.Module):
         self.in_shapes = in_shapes
         self.out_features = out_features
         self.axis = axis
-        self.dtype = dtype
         self.kernel_shape = self.in_shapes + self.out_features

         factory_kwargs = {"device": device, "dtype": weight_dtype}
@@ -70,95 +53,44 @@ class DenseGeneral(nn.Module):
         kernel_contract_axes = tuple(range(len(norm_axis)))

         output = torch.tensordot(
-            inputs.
-            self.weight
+            inputs.to(self.weight.dtype),
+            self.weight,
             dims=(norm_axis, kernel_contract_axes),
         ).to(inputs.dtype)
         return output


-def get_activation_fn(activation_string: str) -> nn.Module:  # Return Module instance
-    """Maps activation string to PyTorch activation function module."""
-    if activation_string == "gelu":
-        return nn.GELU()
-    elif activation_string == "relu":
-        return nn.ReLU()
-    elif activation_string == "silu" or activation_string == "swish":
-        return nn.SiLU()
-    elif activation_string == "linear":
-        return nn.Identity()
-    else:
-        raise ValueError(f"Unsupported activation function: {activation_string}")
-
-
 class MlpBlock(nn.Module):
     """MLP block using DenseGeneral."""

     def __init__(
-        self,
-        config: DiaConfig,
-        embed_dim: int,
-        intermediate_dim: int,
-        dropout_rate: float,
-        activations: list[str] = ["silu", "linear"],
-        use_pre_norm: bool = False,
+        self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
     ):
         super().__init__()
-        self.use_pre_norm = use_pre_norm
-        num_activations = len(activations)
-        compute_dtype = _str_to_dtype(config.training.dtype)
-        weight_dtype = _str_to_dtype(config.model.weight_dtype)
         self.dtype = compute_dtype
-        # Assume default device for now, could be passed in config
-
-        if use_pre_norm:
-            self.pre_norm = RMSNorm(
-                embed_dim,
-                eps=config.model.normalization_layer_epsilon,
-                dtype=torch.float32,
-            )

         self.wi_fused = DenseGeneral(
             in_shapes=(embed_dim,),
-            out_features=(
-                num_activations,
-                intermediate_dim,
-            ),
+            out_features=(2, intermediate_dim),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )

-        self.activation_fn_0 = get_activation_fn(activations[0])  # silu
-        self.activation_fn_1 = get_activation_fn(activations[1])  # linear
-
-        self.dropout = nn.Dropout(dropout_rate)
-
-        # Output layer using DenseGeneral
         self.wo = DenseGeneral(
             in_shapes=(intermediate_dim,),
             out_features=(embed_dim,),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )

-    def forward(self, x: torch.Tensor
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass."""
-        if self.use_pre_norm and hasattr(self, "pre_norm"):
-            x = self.pre_norm(x)
-
         fused_x = self.wi_fused(x)

-
-
-
-        gate = self.activation_fn_0(gate_input)
-        up = self.activation_fn_1(up_input)
-        hidden = torch.mul(gate, up).to(self.dtype)
-
-
-        hidden = self.dropout(hidden)
+        gate = fused_x[..., 0, :]
+        up = fused_x[..., 1, :]
+
+        hidden = torch.mul(F.silu(gate), up).to(self.dtype)

         output = self.wo(hidden)
         return output
@@ -207,37 +139,6 @@ class RotaryEmbedding(nn.Module):
         return torch.cat((first_part, second_part), dim=-1)


-class KVCache:
-    def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
-        self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
-        self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
-        self.current_idx = 0
-        self.max_len = max_len
-
-    def get_kv_for_attention(self, current_k, current_v):
-        if self.current_idx == 0:
-            return current_k, current_v
-        else:
-            past_k = self.k[:, :, : self.current_idx, :]
-            past_v = self.v[:, :, : self.current_idx, :]
-            attn_k = torch.cat((past_k, current_k), dim=2)
-            attn_v = torch.cat((past_v, current_v), dim=2)
-            return attn_k, attn_v
-
-    def update_cache(self, k, v):
-        assert self.current_idx < self.max_len
-        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
-        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
-        self.current_idx += 1
-
-    def prefill_kv(self, k, v):
-        prefill_len = k.shape[2]
-        assert prefill_len <= self.max_len
-        self.k[:, :, :prefill_len, :] = k
-        self.v[:, :, :prefill_len, :] = v
-        self.current_idx = prefill_len
-
-
 class Attention(nn.Module):
     """Attention using DenseGeneral."""

@@ -249,7 +150,7 @@ class Attention(nn.Module):
         num_query_heads: int,
         num_kv_heads: int,
         head_dim: int,
-
+        compute_dtype: torch.dtype,
         is_cross_attn: bool = False,
         out_embed_dim: int | None = None,
     ):
@@ -258,13 +159,12 @@ class Attention(nn.Module):
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
         self.is_cross_attn = is_cross_attn
-        self.dropout_rate = dropout_rate
-        compute_dtype = _str_to_dtype(config.training.dtype)
-        weight_dtype = _str_to_dtype(config.model.weight_dtype)
         self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
         self.projected_query_dim = num_query_heads * head_dim
         if num_query_heads % num_kv_heads != 0:
-            raise ValueError(
+            raise ValueError(
+                f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
+            )
         self.num_gqa_groups = num_query_heads // num_kv_heads

         # --- Projection Layers using DenseGeneral ---
@@ -272,29 +172,25 @@ class Attention(nn.Module):
             in_shapes=(q_embed_dim,),
             out_features=(num_query_heads, head_dim),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )
         self.k_proj = DenseGeneral(
             in_shapes=(kv_embed_dim,),
             out_features=(num_kv_heads, head_dim),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )
         self.v_proj = DenseGeneral(
             in_shapes=(kv_embed_dim,),
             out_features=(num_kv_heads, head_dim),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )
         self.o_proj = DenseGeneral(
             in_shapes=(num_query_heads, head_dim),
             out_features=(self.output_dim,),
             axis=(-2, -1),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )

         # --- Rotary Embedding ---
@@ -311,10 +207,11 @@ class Attention(nn.Module):
         Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
         q_positions: torch.Tensor,  # (B, T)
         kv_positions: torch.Tensor | None = None,  # (B, S)
-
-
+        attn_mask: torch.Tensor
+        | None = None,  # None in Decoder Self Attention, Valid mask in Others
         cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
-        prefill: bool = False,
+        prefill: bool = False,
+        is_causal: bool = False,
     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
         """
         Performs attention calculation with optional KV caching.
@@ -324,7 +221,6 @@ class Attention(nn.Module):
         Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
         q_positions: Positions for queries (B, T).
         kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
-        deterministic: If True, disable dropout.
         attn_mask: Attention mask.
         cache: KVCache.
         prefill: If True, use prefill mode.
@@ -342,72 +238,51 @@ class Attention(nn.Module):
         Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
         Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

-        # Input values into attention calculation
         attn_k: torch.Tensor | None = None
         attn_v: torch.Tensor | None = None
-        new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

-        # Decoder Cross Attention
         if self.is_cross_attn:
-            # Directly use cache (no need to check index)
             attn_k, attn_v = cache.k, cache.v
-            if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
-                raise ValueError(
-                    f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
-                    f"does not match num_query_heads ({self.num_query_heads}). "
-                    "Cache should be pre-repeated for GQA."
-                )
-        # Self Attention
         else:
             Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
             Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
-            Xk_BxSxKxH = self.rotary_emb(
+            Xk_BxSxKxH = self.rotary_emb(
+                Xk_BxSxKxH, position=kv_positions
+            )  # (B, S, K, H)

             Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
             Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
-            # S=1 for Decode Step
-
-            if self.num_gqa_groups > 1:
-                Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
-                Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
-            else:
-                Xk_BxNxSxH = Xk_BxKxSxH
-                Xv_BxNxSxH = Xv_BxKxSxH

-            # Encoder Self Attention
             if cache is None:
-                attn_k =
-                attn_v =
-            # Decoder Self Attention
+                attn_k = Xk_BxKxSxH
+                attn_v = Xv_BxKxSxH
             else:
-                # In prefill mode, we fill in cache until prefill length
                 if prefill:
-                    attn_k, attn_v =
-                    cache.
-                # In decode step, we add current K/V to cache step by step
+                    attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
+                    cache.prefill(attn_k, attn_v)
                 else:
-
-                    attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
+                    attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)

         attn_output = F.scaled_dot_product_attention(
             Xq_BxNxTxH,
             attn_k,
             attn_v,
             attn_mask=attn_mask,
-            dropout_p=self.dropout_rate if not deterministic else 0.0,
             scale=1.0,
+            enable_gqa=self.num_gqa_groups > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
         output = self.o_proj(attn_output)

-        return output.to(original_dtype)
+        return output.to(original_dtype)


 class EncoderLayer(nn.Module):
     """Transformer Encoder Layer using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
@@ -420,13 +295,13 @@ class EncoderLayer(nn.Module):
             dtype=torch.float32,
         )
         self.self_attention = Attention(
-            config
+            config,
             q_embed_dim=embed_dim,
             kv_embed_dim=embed_dim,
             num_query_heads=enc_config.n_head,
             num_kv_heads=enc_config.n_head,
             head_dim=enc_config.head_dim,
-
+            compute_dtype=compute_dtype,
             is_cross_attn=False,
             out_embed_dim=embed_dim,
         )
@@ -436,62 +311,52 @@ class EncoderLayer(nn.Module):
             dtype=torch.float32,
         )
         self.mlp = MlpBlock(
-            config=config,
             embed_dim=embed_dim,
             intermediate_dim=enc_config.n_hidden,
-
-            dropout_rate=model_config.dropout,
-            use_pre_norm=enc_config.use_pre_norm,
+            compute_dtype=compute_dtype,
         )
-        self.dropout = nn.Dropout(model_config.dropout)

     def forward(
         self,
         x: torch.Tensor,
-
-        deterministic: bool = True,
-        attn_mask: torch.Tensor | None = None,
+        state: EncoderInferenceState,
     ) -> torch.Tensor:
         residual = x
         x_norm = self.pre_sa_norm(x)
-
-        sa_out, _ = self.self_attention(
+        sa_out = self.self_attention(
             Xq=x_norm,
             Xkv=x_norm,
-            q_positions=
-            kv_positions=
-
-            attn_mask=attn_mask,
+            q_positions=state.positions,
+            kv_positions=state.positions,
+            attn_mask=state.attn_mask,
         )
         x = residual + sa_out

         residual = x
         x_norm = self.post_sa_norm(x)
-        mlp_out = self.mlp(x_norm
+        mlp_out = self.mlp(x_norm)
         x = residual + mlp_out

-        if not deterministic:
-            x = self.dropout(x)
         return x


 class Encoder(nn.Module):
     """Transformer Encoder Stack using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
         enc_config = config.model.encoder
-        compute_dtype = _str_to_dtype(config.training.dtype)

         self.embedding = nn.Embedding(
             model_config.src_vocab_size,
             enc_config.n_embd,
             dtype=compute_dtype,
         )
-        self.
-
+        self.layers = nn.ModuleList(
+            [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
+        )
         self.norm = RMSNorm(
             enc_config.n_embd,
             eps=model_config.normalization_layer_epsilon,
@@ -501,32 +366,21 @@ class Encoder(nn.Module):
     def forward(
         self,
         x_ids: torch.Tensor,
-
-        deterministic: bool = True,
-        attn_mask: torch.Tensor | None = None,
+        state: EncoderInferenceState,
     ) -> torch.Tensor:
         x = self.embedding(x_ids)

-        if not deterministic:
-            x = self.dropout(x)
-
         for layer in self.layers:
-            x = layer(
-
-                src_positions=src_positions,
-                deterministic=deterministic,
-                attn_mask=attn_mask,
-            )
+            x = layer(x, state)
+
         x = self.norm(x)
-        if not deterministic:
-            x = self.dropout(x)
         return x


 class DecoderLayer(nn.Module):
     """Transformer Decoder Layer using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
@@ -554,13 +408,13 @@ class DecoderLayer(nn.Module):

         # Self-Attention (GQA) with Causal Masking
         self.self_attention = Attention(
-            config
+            config,
             q_embed_dim=dec_embed_dim,
             kv_embed_dim=dec_embed_dim,
             num_query_heads=dec_config.gqa_query_heads,
             num_kv_heads=dec_config.kv_heads,
             head_dim=dec_config.gqa_head_dim,
-
+            compute_dtype=compute_dtype,
             is_cross_attn=False,
             out_embed_dim=dec_embed_dim,
         )
@@ -572,116 +426,105 @@
             num_query_heads=dec_config.cross_query_heads,
             num_kv_heads=dec_config.cross_query_heads,
             head_dim=dec_config.cross_head_dim,
-
+            compute_dtype=compute_dtype,
             is_cross_attn=True,
             out_embed_dim=dec_embed_dim,
         )
         # MLP
         self.mlp = MlpBlock(
-            config=config,
             embed_dim=dec_embed_dim,
             intermediate_dim=dec_config.n_hidden,
-
-            dropout_rate=model_config.dropout,
-            use_pre_norm=dec_config.use_pre_norm,
+            compute_dtype=compute_dtype,
         )

     def forward(
         self,
         x: torch.Tensor,
-
-
-
-        deterministic: bool,
-        self_attn_mask: torch.Tensor,
-        cross_attn_mask: torch.Tensor,
-        self_attn_cache: KVCache,
-        cross_attn_cache: KVCache,
+        state: DecoderInferenceState,
+        self_attn_cache: KVCache | None = None,
+        cross_attn_cache: KVCache | None = None,
         prefill: bool = False,
     ) -> torch.Tensor:
         residual = x
         x_norm = self.pre_sa_norm(x)

-        sa_out
+        sa_out = self.self_attention(
             Xq=x_norm,  # (2, 1, D)
             Xkv=x_norm,  # (2, 1, D)
-            q_positions=
-            kv_positions=
-
-            attn_mask=self_attn_mask,  # (2, 1, 1, S_max)
+            q_positions=state.dec_positions,  # (2, 1)
+            kv_positions=state.dec_positions,  # (2, 1)
+            attn_mask=None,
             cache=self_attn_cache,
             prefill=prefill,
+            is_causal=prefill,
         )

         x = residual + sa_out

-        # 2. Cross-Attention
         residual = x
         x_norm = self.pre_ca_norm(x)
-        ca_out
+        ca_out = self.cross_attention(
             Xq=x_norm,
-            Xkv=
-            q_positions=
-            kv_positions=
-
-            attn_mask=cross_attn_mask,
+            Xkv=state.enc_out,
+            q_positions=state.dec_positions,
+            kv_positions=state.enc_positions,
+            attn_mask=state.dec_cross_attn_mask,
             cache=cross_attn_cache,
         )
         x = residual + ca_out

-        # 3. MLP
         residual = x
         x_norm = self.pre_mlp_norm(x)
-        mlp_out = self.mlp(x_norm
+        mlp_out = self.mlp(x_norm)
         x = residual + mlp_out

-        return x
+        return x


 class Decoder(nn.Module):
     """Transformer Decoder Stack using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
+    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
         super().__init__()
         self.config = config
         model_config = config.model
         dec_config = config.model.decoder
-        train_config = config.training
         data_config = config.data
-        compute_dtype = _str_to_dtype(config.training.dtype)
-        weight_dtype = _str_to_dtype(config.model.weight_dtype)
         self.num_channels = data_config.channels
         self.num_layers = dec_config.n_layer

         self.embeddings = nn.ModuleList(
             [
-                nn.Embedding(
+                nn.Embedding(
+                    model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
+                )
                 for _ in range(self.num_channels)
             ]
         )
-        self.
-
+        self.layers = nn.ModuleList(
+            [
+                DecoderLayer(config=config, compute_dtype=compute_dtype)
+                for _ in range(self.num_layers)
+            ]
+        )
+
         self.norm = RMSNorm(
             dec_config.n_embd,
             eps=model_config.normalization_layer_epsilon,
             dtype=torch.float32,
         )

-        # Final Logits Projection using DenseGeneral
         self.logits_dense = DenseGeneral(
             in_shapes=(dec_config.n_embd,),
             out_features=(self.num_channels, model_config.tgt_vocab_size),
             axis=(-1,),
-
-            weight_dtype=weight_dtype,
+            weight_dtype=compute_dtype,
         )
-        self.logits_in_fp32 = train_config.logits_dot_in_fp32

-    def
+    def precompute_cross_attn_cache(
         self,
-
-
-        src_positions: torch.Tensor | None,  # (B, S)
+        enc_out: torch.Tensor,  # (B, S, E)
+        enc_positions: torch.Tensor,  # (B, S)
     ) -> list[KVCache]:
         """
         Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
@@ -690,35 +533,21 @@ class Decoder(nn.Module):

         for layer in self.layers:
             cross_attn_module = layer.cross_attention
-            k_proj = cross_attn_module.k_proj(
-            v_proj = cross_attn_module.v_proj(
+            k_proj = cross_attn_module.k_proj(enc_out)
+            v_proj = cross_attn_module.v_proj(enc_out)

-            k_proj = cross_attn_module.rotary_emb(k_proj, position=
+            k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
             k = k_proj.transpose(1, 2)
             v = v_proj.transpose(1, 2)

-            per_layer_kv_cache.append(
-                KVCache(
-                    cross_attn_module.num_kv_heads,
-                    max_len,
-                    cross_attn_module.head_dim,
-                    k.device,
-                    k=k,
-                    v=v,
-                )
-            )
+            per_layer_kv_cache.append(KVCache.from_kv(k, v))

         return per_layer_kv_cache

     def decode_step(
         self,
         tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
-
-        encoder_out: torch.Tensor,  # [B, S, E]
-        self_attn_mask: Any,  # None
-        cross_attn_mask: torch.Tensor,  # [B, 1, 1, S]
-        self_attention_cache: list[KVCache],
-        cross_attention_cache: list[KVCache],
+        state: DecoderInferenceState,
     ) -> torch.Tensor:
         """
         Performs a single decoding step, managing KV caches layer by layer.
@@ -727,7 +556,6 @@ class Decoder(nn.Module):
         A tuple containing:
         - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
         """
-        assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"

         x = None
         for i in range(self.num_channels):
@@ -735,40 +563,23 @@ class Decoder(nn.Module):
             channel_embed = self.embeddings[i](channel_tokens)
             x = channel_embed if x is None else x + channel_embed

-        new_cache = []
-
         for i, layer in enumerate(self.layers):
-            self_cache =
-            cross_cache =
-            x
+            self_cache = state.self_attn_cache[i]
+            cross_cache = state.cross_attn_cache[i]
+            x = layer(
                 x,  # (2, 1, D)
-
-                src_positions=None,  # CA KV is already computed
-                tgt_positions=tgt_pos_Bx1,  # (2, 1)
-                deterministic=True,
-                self_attn_mask=None,
-                cross_attn_mask=cross_attn_mask,
+                state,
                 self_attn_cache=self_cache,
                 cross_attn_cache=cross_cache,
             )
-            new_cache.append(new_kv_cache)

         x = self.norm(x)
         logits_Bx1xCxV = self.logits_dense(x)

-        return logits_Bx1xCxV.to(torch.float32)
+        return logits_Bx1xCxV.to(torch.float32)

     def forward(
-        self,
-        tgt_ids_BxTxC: torch.Tensor,
-        encoder_out: torch.Tensor,
-        tgt_positions: torch.Tensor,
-        src_positions: torch.Tensor,
-        deterministic: bool,
-        self_attn_mask: torch.Tensor,
-        cross_attn_mask: torch.Tensor,
-        self_attention_cache: list[KVCache],
-        cross_attention_cache: list[KVCache],
+        self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
     ) -> torch.Tensor:
         """
         Forward pass for the Decoder stack, managing KV caches.
@@ -778,7 +589,6 @@ class Decoder(nn.Module):
         encoder_out: Output from the encoder (B, S, E).
         tgt_positions: Positions for target sequence (B, T).
         src_positions: Positions for source sequence (B, S).
-        deterministic: Disable dropout if True.
         self_attn_mask: Mask for self-attention.
         cross_attn_mask: Mask for cross-attention.
         past_key_values: List containing the self-attention KV cache for each layer
@@ -804,20 +614,14 @@ class Decoder(nn.Module):
             channel_embed = self.embeddings[i](channel_tokens)
             x = channel_embed if x is None else x + channel_embed

-        if not deterministic:
-            x = self.dropout(x)
-
         for i, layer in enumerate(self.layers):
-
                 x,
-
-
-
-                deterministic=deterministic,
-                self_attn_mask=self_attn_mask,
-                cross_attn_mask=cross_attn_mask,
-                self_attn_cache=self_attention_cache[i],
-                cross_attn_cache=cross_attention_cache[i],
                 prefill=True,
             )

@@ -831,43 +635,8 @@ class Decoder(nn.Module):
 class DiaModel(nn.Module):
     """PyTorch Dia Model using DenseGeneral."""

-    def __init__(self, config: DiaConfig):
         super().__init__()
         self.config = config
-        self.encoder = Encoder(config)
-        self.decoder = Decoder(config)
-
-    def forward(
-        self,
-        src_BxS: torch.Tensor,
-        tgt_BxTxC: torch.Tensor,
-        src_positions: torch.Tensor | None = None,
-        tgt_positions: torch.Tensor | None = None,
-        enc_self_attn_mask: torch.Tensor | None = None,
-        dec_self_attn_mask: torch.Tensor | None = None,
-        dec_cross_attn_mask: torch.Tensor | None = None,
-        enable_dropout: bool = True,
-    ):
-        deterministic = not enable_dropout
-
-        # --- Encoder Pass ---
-        encoder_out = self.encoder(
-            x_ids=src_BxS,
-            src_positions=src_positions,
-            deterministic=deterministic,
-            attn_mask=enc_self_attn_mask,
-        )
-
-        # --- Decoder Pass ---
-        logits, _ = self.decoder(
-            tgt_ids_BxTxC=tgt_BxTxC,
-            encoder_out=encoder_out,
-            tgt_positions=tgt_positions,
-            src_positions=src_positions,
-            deterministic=deterministic,
-            self_attn_mask=dec_self_attn_mask,
-            cross_attn_mask=dec_cross_attn_mask,
-            precomputed_cross_attn_kv=None,
-        )
-
-        return logits
|
|
|
614 |
channel_embed = self.embeddings[i](channel_tokens)
|
615 |
x = channel_embed if x is None else x + channel_embed
|
616 |
|
|
|
|
|
|
|
617 |
for i, layer in enumerate(self.layers):
|
618 |
+
self_cache = state.self_attn_cache[i]
|
619 |
+
cross_cache = state.cross_attn_cache[i]
|
620 |
+
x = layer(
|
621 |
x,
|
622 |
+
state,
|
623 |
+
self_attn_cache=self_cache,
|
624 |
+
cross_attn_cache=cross_cache,
|
|
|
|
|
|
|
|
|
|
|
625 |
prefill=True,
|
626 |
)
|
627 |
|
|
|
635 |
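In `Decoder.decode_step` above, each of the audio codebook channels has its own embedding table and the per-channel embeddings are summed before the transformer stack. A minimal, self-contained sketch of that summation, with hypothetical sizes chosen only for illustration:

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only: 9 codebook channels, vocab 1028, width 2048.
num_channels, vocab_size, embed_dim = 9, 1028, 2048
embeddings = nn.ModuleList(
    [nn.Embedding(vocab_size, embed_dim) for _ in range(num_channels)]
)

tgt_ids_Bx1xC = torch.randint(0, vocab_size, (2, 1, num_channels))  # one step, CFG batch of 2

x = None
for i in range(num_channels):
    channel_tokens = tgt_ids_Bx1xC[..., i]           # (B, 1)
    channel_embed = embeddings[i](channel_tokens)    # (B, 1, D)
    x = channel_embed if x is None else x + channel_embed  # summed channel embeddings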
class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral."""

+   def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
+       self.encoder = Encoder(config, compute_dtype)
+       self.decoder = Decoder(config, compute_dtype)
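`Decoder.precompute_cross_attn_cache` above exploits the fact that the encoder output never changes while tokens are being decoded, so the cross-attention keys and values can be projected once per generation and reused at every step. A self-contained sketch of the idea, with toy projection sizes that are not the model's real dimensions:

import torch
import torch.nn as nn

embed_dim, num_heads, head_dim, seq_len = 1024, 16, 64, 128  # toy sizes
k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)

enc_out = torch.randn(2, seq_len, embed_dim)  # fixed encoder output (CFG batch of 2)

# Project once, reshape to (B, H, S, Dh), and keep the result for every decode step.
k = k_proj(enc_out).view(2, seq_len, num_heads, head_dim).transpose(1, 2)
v = v_proj(enc_out).view(2, seq_len, num_heads, head_dim).transpose(1, 2)
cross_kv = (k, v)  # reused by each layer's cross-attention at every decoding step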
dia/model.py
CHANGED
@@ -1,26 +1,46 @@
import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

-from .audio import
from .config import DiaConfig
-from .layers import DiaModel


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
-   use_cfg_filter: bool,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
-   if
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
@@ -28,17 +48,21 @@ def _sample_next_token(

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
-       sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

-       # Calculate indices to remove based on top_p
        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
-       indices_to_remove_BCxV.scatter_(
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
@@ -48,31 +72,61 @@ def _sample_next_token(
    return sampled_indices_C


class Dia:
-   def __init__(
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
-           device: The device to load the model onto.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
-       self.device = device
        self.dac_model = None

    @classmethod
-   def from_local(
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
-           device: The device to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.
@@ -85,23 +139,29 @@ class Dia:
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

-       dia = cls(config, device)

        try:
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
-           raise RuntimeError(

-       dia.model.to(device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
-       cls,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.
@@ -110,7 +170,7 @@ class Dia:

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
-           device: The device to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.
@@ -121,7 +181,7 @@ class Dia:
        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
-       return cls.from_local(config_path, checkpoint_path, device)

    def _load_dac_model(self):
        try:
@@ -131,44 +191,7 @@ class Dia:
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

-   def
-       self,
-       q_padding_mask_1d: torch.Tensor,
-       k_padding_mask_1d: torch.Tensor,
-       is_causal: bool = False,
-   ) -> torch.Tensor:
-       """
-       Creates the attention mask (self or cross) mimicking JAX segment ID logic.
-       """
-       B1, Tq = q_padding_mask_1d.shape
-       B2, Tk = k_padding_mask_1d.shape
-       assert B1 == B2, "Query and key batch dimensions must match"
-
-       p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
-       p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]
-
-       # Condition A: Non-padding query attends to non-padding key
-       non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]
-
-       # Condition B: Padding query attends to padding key
-       pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]
-
-       # Combine: True if padding status is compatible (both non-pad OR both pad)
-       # This implementation follows Jax TPU splash attention kernel
-       mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]
-
-       if is_causal:
-           # Ensure causality for self-attention (Tq == Tk)
-           assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
-           # Standard lower-triangular causal mask (True means allow)
-           causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device))  # Shape [Tq, Tk]
-           causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
-           return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
-       else:
-           # For cross-attention or non-causal self-attention
-           return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
-
-   def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length
@@ -190,14 +213,168 @@ class Dia:
            constant_values=text_pad_value,
        ).astype(np.uint8)

-       src_tokens =

    @torch.inference_mode()
    def generate(
@@ -207,225 +384,105 @@ class Dia:
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        audio_prompt_path: str | None = None,
    ) -> np.ndarray:
-       """
-       Generates audio from a text prompt (and optional audio prompt) using the Nari model.
-
-       Returns:
-           A tensor of generated audio codes (shape: [max_tokens, num_channels]).
-       """
-       num_channels = self.config.data.channels
-       audio_bos_value = self.config.data.audio_bos_value
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
-       delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

-       ) = self._prepare_text_input(text)
-
-       unc_src_BxS = torch.zeros_like(cond_src_BxS)
-       src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
-       src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
-       src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
-       enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
-
-       # 2. Encoder Pass
-       # with torch.autocast(device_type="cuda", dtype=forward_dtype):
-       encoder_out = self.model.encoder(
-           x_ids=src_BxS,
-           src_positions=src_positions_BxS,
-           deterministic=True,
-           attn_mask=enc_self_attn_mask_Bx1xSxS,
-       )  # Shape: (B, S, E)
-
-       # 3. Prepare Decoder Inputs
-       # 3-1. Allocate KV Cache (Static)
-       decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
-           max_tokens, encoder_out, src_positions_BxS
-       )
-
-       decoder_self_attention_cache: list[KVCache] = []
-       for _ in range(self.model.decoder.num_layers):
-           decoder_self_attention_cache.append(
-               KVCache(
-                   self.config.model.decoder.gqa_query_heads,
-                   max_tokens,
-                   self.config.model.decoder.gqa_head_dim,
-                   self.device,
-               )
-           )
-
-       # 3-2. Initialize Decoder Inputs
-       generated_BxTxC = torch.full(
-           (2, 1, num_channels),
-           fill_value=audio_bos_value,
-           dtype=torch.long,
-           device=self.device,
-       )
-
-       current_step = 0
-       prompt_len_inc_bos = 1  # Start with BOS length
-
-       # 3-3. Load Audio Prompt (if provided)
-       if audio_prompt_path is not None:
-           audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True)  # C, T
-           if sr != 44100:  # Resample to 44.1kHz
-               audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
-           audio_prompt = audio_prompt.to(self.device).unsqueeze(0)  # 1, C, T
-           audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
-           generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
-
-           prefill_len = generated_BxTxC.shape[1]
-           prompt_len_inc_bos = prefill_len
-           prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
-           prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
-
-           prefill_self_attn_mask = self._create_attn_mask(
-               prefill_tgt_padding_mask,
-               prefill_tgt_padding_mask,
-               is_causal=True,
-           )
-           prefill_cross_attn_mask = self._create_attn_mask(
-               prefill_tgt_padding_mask,
-               src_padding_mask_BxS,
-               is_causal=False,
-           )

-           encoder_out=encoder_out,
-           tgt_positions=prefill_tgt_pos,
-           src_positions=src_positions_BxS,
-           deterministic=True,
-           self_attn_mask=prefill_self_attn_mask,
-           cross_attn_mask=prefill_cross_attn_mask,
-           self_attention_cache=decoder_self_attention_cache,
-           cross_attention_cache=decoder_cross_attention_cache,
-       )

        eos_countdown = -1
-       extra_steps_after_eos = 30
-       # Make generated_BxTxC a fixed size tensor
-       # Length is either 1 + max tokens or 1 + prompt len + max tokens
-       generated_BxTxC = torch.cat(
-           [
-               generated_BxTxC,
-               torch.full(
-                   (2, max_tokens, num_channels),
-                   fill_value=-1,
-                   dtype=torch.long,
-                   device=self.device,
-               ),
-           ],
-           dim=1,
-       )

-       decode_step = self.model.decoder.decode_step
        if use_torch_compile:
-       )

-       (
-           is_causal=False,
-       )  # [B, 1, 1, S]
-
-       for step in range(current_step, current_step + max_tokens):
-           tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
-           tgt_pos_Bx1 = torch.full(
-               (2, 1),
-               fill_value=step,
-               dtype=torch.long,
-               device=self.device,
-           )

-           self_attn_mask=None,
-           cross_attn_mask=decoder_cross_attn_mask,
-           self_attention_cache=decoder_self_attention_cache,
-           cross_attention_cache=decoder_cross_attention_cache,
        )

-           cond_logits_CxV = logits_last_BxCxV[1, :, :]
-
-           cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
-
-           logits_CxV = cfg_logits_CxV.reshape((-1, V))  # C, V
-           logits_CxV[:, 1025:] = -torch.inf
-
-           # Sample next token
-           pred_C = _sample_next_token(
-               logits_CxV.float(),
-               temperature=temperature,
-               top_p=top_p,
-               use_cfg_filter=use_cfg_filter,
-               cfg_filter_top_k=cfg_filter_top_k,
-           )

-               audio_bos_value,
-           )
-
-           generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
-
-           if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
-               eos_detected_channel_0 = True
-               eos_countdown = extra_steps_after_eos

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                    elif step_after_eos > d:
                eos_countdown -= 1
-               if eos_countdown == 0:
-                   break
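The rewritten file below keeps the same sampling idea as the removed code: combine conditional and unconditional logits with the CFG scale, restrict to the top-k candidates, then apply nucleus (top-p) filtering before drawing a token per channel. A self-contained sketch of that pipeline, with an illustrative helper name and tensor shapes that are not part of the commit:

import torch

def cfg_top_p_sample(cond_logits, uncond_logits, cfg_scale=3.0, temperature=1.3, top_p=0.95, top_k=35):
    # Classifier-free guidance: push conditional logits away from unconditional ones.
    logits = cond_logits + cfg_scale * (cond_logits - uncond_logits)
    logits = logits / temperature

    # Keep only the top-k candidates per row.
    _, top_k_idx = torch.topk(logits, k=top_k, dim=-1)
    keep = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, top_k_idx, True)
    logits = logits.masked_fill(~keep, -torch.inf)

    # Nucleus filtering: drop tokens outside the smallest set whose cumulative prob exceeds top_p.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    remove_sorted = cum > top_p
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False
    remove = torch.zeros_like(remove_sorted).scatter_(-1, sorted_idx, remove_sorted)
    logits = logits.masked_fill(remove, -torch.inf)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

pred = cfg_top_p_sample(torch.randn(9, 1028), torch.randn(9, 1028))  # one token per channel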
+import time
+from enum import Enum
+
import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

+from .audio import (
+    apply_audio_delay,
+    build_delay_indices,
+    build_revert_indices,
+    decode,
+    revert_audio_delay,
+)
from .config import DiaConfig
+from .layers import DiaModel
+from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
+
+
+DEFAULT_SAMPLE_RATE = 44100
+
+
+def _get_default_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
+   if cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
+       sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
+           probs_BCxV, dim=-1, descending=True
+       )
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
+       sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
+           ..., :-1
+       ].clone()
+       sorted_indices_to_remove_BCxV[..., 0] = 0

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
+       indices_to_remove_BCxV.scatter_(
+           dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
+       )
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
...
    return sampled_indices_C


+class ComputeDtype(str, Enum):
+    FLOAT32 = "float32"
+    FLOAT16 = "float16"
+    BFLOAT16 = "bfloat16"
+
+    def to_dtype(self) -> torch.dtype:
+        if self == ComputeDtype.FLOAT32:
+            return torch.float32
+        elif self == ComputeDtype.FLOAT16:
+            return torch.float16
+        elif self == ComputeDtype.BFLOAT16:
+            return torch.bfloat16
+        else:
+            raise ValueError(f"Unsupported compute dtype: {self}")


class Dia:
+   def __init__(
+       self,
+       config: DiaConfig,
+       compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+       device: torch.device | None = None,
+   ):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
+           device: The device to load the model onto. If None, will automatically select the best available device.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
+       self.device = device if device is not None else _get_default_device()
+       if isinstance(compute_dtype, str):
+           compute_dtype = ComputeDtype(compute_dtype)
+       self.compute_dtype = compute_dtype.to_dtype()
+       self.model = DiaModel(config, self.compute_dtype)
        self.dac_model = None

    @classmethod
+   def from_local(
+       cls,
+       config_path: str,
+       checkpoint_path: str,
+       compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+       device: torch.device | None = None,
+   ) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
+           device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.
...
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

+       dia = cls(config, compute_dtype, device)

        try:
+           state_dict = torch.load(checkpoint_path, map_location=dia.device)
+           dia.model.load_state_dict(state_dict)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
+           raise RuntimeError(
+               f"Error loading checkpoint from {checkpoint_path}"
+           ) from e

+       dia.model.to(dia.device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
+       cls,
+       model_name: str = "nari-labs/Dia-1.6B",
+       compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+       device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.
...
        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
+           device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.
...
        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
+       return cls.from_local(config_path, checkpoint_path, compute_dtype, device)

    def _load_dac_model(self):
        try:
...
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

+   def _prepare_text_input(self, text: str) -> torch.Tensor:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length
...
            constant_values=text_pad_value,
        ).astype(np.uint8)

+       src_tokens = (
+           torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
+       )  # [1, S]
+       return src_tokens

+   def _prepare_audio_prompt(
+       self, audio_prompt: torch.Tensor | None
+   ) -> tuple[torch.Tensor, int]:
+       num_channels = self.config.data.channels
+       audio_bos_value = self.config.data.audio_bos_value
+       audio_pad_value = self.config.data.audio_pad_value
+       delay_pattern = self.config.data.delay_pattern
+       max_delay_pattern = max(delay_pattern)

+       prefill = torch.full(
+           (1, num_channels),
+           fill_value=audio_bos_value,
+           dtype=torch.int,
+           device=self.device,
+       )

+       prefill_step = 1
+
+       if audio_prompt is not None:
+           prefill_step += audio_prompt.shape[0]
+           prefill = torch.cat([prefill, audio_prompt], dim=0)
+
+       delay_pad_tensor = torch.full(
+           (max_delay_pattern, num_channels),
+           fill_value=-1,
+           dtype=torch.int,
+           device=self.device,
+       )
+       prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
+
+       delay_precomp = build_delay_indices(
+           B=1,
+           T=prefill.shape[0],
+           C=num_channels,
+           delay_pattern=delay_pattern,
+       )
+
+       prefill = apply_audio_delay(
+           audio_BxTxC=prefill.unsqueeze(0),
+           pad_value=audio_pad_value,
+           bos_value=audio_bos_value,
+           precomp=delay_precomp,
+       ).squeeze(0)
+
+       return prefill, prefill_step
+
+   def _prepare_generation(
+       self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
+   ):
+       enc_input_cond = self._prepare_text_input(text)
+       enc_input_uncond = torch.zeros_like(enc_input_cond)
+       enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
+
+       if isinstance(audio_prompt, str):
+           audio_prompt = self.load_audio(audio_prompt)
+       prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
+
+       if verbose:
+           print("generate: data loaded")
+
+       enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
+       encoder_out = self.model.encoder(enc_input, enc_state)
+
+       dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
+           encoder_out, enc_state.positions
+       )
+       dec_state = DecoderInferenceState.new(
+           self.config,
+           enc_state,
+           encoder_out,
+           dec_cross_attn_cache,
+           self.compute_dtype,
+       )
+       dec_output = DecoderOutput.new(self.config, self.device)
+       dec_output.prefill(prefill, prefill_step)
+
+       dec_step = prefill_step - 1
+       if dec_step > 0:
+           dec_state.prepare_step(0, dec_step)
+           tokens_BxTxC = (
+               dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
+           )
+           self.model.decoder.forward(tokens_BxTxC, dec_state)
+
+       return dec_state, dec_output
+
+   def _decoder_step(
+       self,
+       tokens_Bx1xC: torch.Tensor,
+       dec_state: DecoderInferenceState,
+       cfg_scale: float,
+       temperature: float,
+       top_p: float,
+       cfg_filter_top_k: int,
+   ) -> torch.Tensor:
+       audio_eos_value = self.config.data.audio_eos_value
+       logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
+
+       logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
+       uncond_logits_CxV = logits_last_BxCxV[0, :, :]
+       cond_logits_CxV = logits_last_BxCxV[1, :, :]
+
+       logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
+       logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
+       logits_CxV[1:, audio_eos_value:] = -torch.inf
+
+       pred_C = _sample_next_token(
+           logits_CxV.float(),
+           temperature=temperature,
+           top_p=top_p,
+           cfg_filter_top_k=cfg_filter_top_k,
+       )
+       return pred_C
+
+   def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
+       num_channels = self.config.data.channels
+       seq_length = generated_codes.shape[0]
+       delay_pattern = self.config.data.delay_pattern
+       audio_pad_value = self.config.data.audio_pad_value
+       max_delay_pattern = max(delay_pattern)
+
+       revert_precomp = build_revert_indices(
+           B=1,
+           T=seq_length,
+           C=num_channels,
+           delay_pattern=delay_pattern,
+       )
+
+       codebook = revert_audio_delay(
+           audio_BxTxC=generated_codes.unsqueeze(0),
+           pad_value=audio_pad_value,
+           precomp=revert_precomp,
+           T=seq_length,
+       )[:, :-max_delay_pattern, :]
+
+       min_valid_index = 0
+       max_valid_index = 1023
+       invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
+       codebook[invalid_mask] = 0
+
+       audio = decode(self.dac_model, codebook.transpose(1, 2))
+
+       return audio.squeeze().cpu().numpy()
+
+   def load_audio(self, audio_path: str) -> torch.Tensor:
+       audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
+       if sr != DEFAULT_SAMPLE_RATE:
+           audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
+       audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
+       audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
+       _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
+       return encoded_frame.squeeze(0).transpose(0, 1)
+
+   def save_audio(self, path: str, audio: np.ndarray):
+       import soundfile as sf
+
+       sf.write(path, audio, DEFAULT_SAMPLE_RATE)

    @torch.inference_mode()
    def generate(
...
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
+       use_torch_compile: bool = False,
+       cfg_filter_top_k: int = 35,
+       audio_prompt: str | torch.Tensor | None = None,
        audio_prompt_path: str | None = None,
+       use_cfg_filter: bool | None = None,
+       verbose: bool = False,
    ) -> np.ndarray:
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

+       if audio_prompt_path:
+           print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
+           audio_prompt = audio_prompt_path
+       if use_cfg_filter is not None:
+           print("Warning: use_cfg_filter is deprecated.")

+       if verbose:
+           total_start_time = time.time()

+       dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
+       dec_step = dec_output.prefill_step - 1

+       bos_countdown = max_delay_pattern
+       eos_detected = False
        eos_countdown = -1

        if use_torch_compile:
+           step_fn = torch.compile(self._decoder_step, mode="default")
+       else:
+           step_fn = self._decoder_step

+       if verbose:
+           print("generate: starting generation loop")
+           if use_torch_compile:
+               print(
+                   "generate: by using use_torch_compile=True, the first step would take long"
+               )
+           start_time = time.time()

+       while dec_step < max_tokens:
+           dec_state.prepare_step(dec_step)
+           tokens_Bx1xC = (
+               dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
            )
+           pred_C = step_fn(
+               tokens_Bx1xC,
+               dec_state,
+               cfg_scale,
+               temperature,
+               top_p,
+               cfg_filter_top_k,
            )

+           if (
+               not eos_detected and pred_C[0] == audio_eos_value
+           ) or dec_step == max_tokens - max_delay_pattern - 1:
+               eos_detected = True
+               eos_countdown = max_delay_pattern

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
+                       pred_C[i] = audio_eos_value
                    elif step_after_eos > d:
+                       pred_C[i] = audio_pad_value
                eos_countdown -= 1

+           bos_countdown = max(0, bos_countdown - 1)
+           dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)

+           if eos_countdown == 0:
+               break

+           dec_step += 1
+           if verbose and dec_step % 86 == 0:
+               duration = time.time() - start_time
+               print(
+                   f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
+               )
+               start_time = time.time()

+       if dec_output.prefill_step >= dec_step + 1:
+           print("Warning: Nothing generated")
+           return None

+       generated_codes = dec_output.generated_tokens[
+           dec_output.prefill_step : dec_step + 1, :
+       ]

+       if verbose:
+           total_step = dec_step + 1 - dec_output.prefill_step
+           total_duration = time.time() - total_start_time
+           print(
+               f"generate: total step={total_step}, total duration={total_duration:.3f}s"
+           )

+       return self._generate_output(generated_codes)
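Taken together, the rewritten dia/model.py exposes a much smaller public surface: load, generate, save. A typical call might look like the sketch below; the script itself is not part of this commit, and the prompt text and dtype choice are only examples.

import torch

from dia.model import Dia

model = Dia.from_pretrained(
    "nari-labs/Dia-1.6B",
    compute_dtype="float16",        # "float32", "float16", or "bfloat16"
    device=torch.device("cuda"),    # omit to auto-select cuda / mps / cpu
)

text = "[S1] Dia is an open weights text to dialogue model. [S2] Try it out!"
audio = model.generate(
    text,
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    use_torch_compile=False,
    verbose=True,
)
model.save_audio("output.wav", audio)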
dia/state.py
ADDED
@@ -0,0 +1,234 @@
+from dataclasses import dataclass
+
+import torch
+
+from .config import DiaConfig
+
+
+def create_attn_mask(
+    q_padding_mask_1d: torch.Tensor,
+    k_padding_mask_1d: torch.Tensor,
+    device: torch.device,
+    is_causal: bool = False,
+) -> torch.Tensor:
+    """
+    Creates the attention mask (self or cross) mimicking JAX segment ID logic.
+    """
+    B1, Tq = q_padding_mask_1d.shape
+    B2, Tk = k_padding_mask_1d.shape
+    assert B1 == B2, "Query and key batch dimensions must match"
+
+    p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
+    p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]
+
+    # Condition A: Non-padding query attends to non-padding key
+    non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]
+
+    # Condition B: Padding query attends to padding key
+    pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]
+
+    # Combine: True if padding status is compatible (both non-pad OR both pad)
+    mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]
+
+    if is_causal:
+        assert Tq == Tk, (
+            "Causal mask requires query and key sequence lengths to be equal"
+        )
+        causal_mask_2d = torch.tril(
+            torch.ones((Tq, Tk), dtype=torch.bool, device=device)
+        )  # Shape [Tq, Tk]
+        causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
+        return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
+    else:
+        return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
+
+
+@dataclass
+class EncoderInferenceState:
+    """Parameters specifically for encoder inference."""
+
+    max_seq_len: int
+    device: torch.device
+    positions: torch.Tensor
+    padding_mask: torch.Tensor
+    attn_mask: torch.Tensor
+
+    @classmethod
+    def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
+        """Creates EncoderInferenceState from DiaConfig and a device."""
+        device = cond_src.device
+
+        positions = (
+            torch.arange(config.data.text_length, device=device)
+            .to(torch.long)
+            .unsqueeze(0)
+            .expand(2, -1)
+        )
+        padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
+        attn_mask = create_attn_mask(
+            padding_mask, padding_mask, device, is_causal=False
+        )
+
+        return cls(
+            max_seq_len=config.data.text_length,
+            device=device,
+            positions=positions,
+            padding_mask=padding_mask,
+            attn_mask=attn_mask,
+        )
+
+
+class KVCache:
+    def __init__(
+        self,
+        num_heads: int,
+        max_len: int,
+        head_dim: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        k: torch.Tensor | None = None,
+        v: torch.Tensor | None = None,
+    ):
+        self.k = (
+            torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
+            if k is None
+            else k
+        )
+        self.v = (
+            torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
+            if v is None
+            else v
+        )
+        self.current_idx = torch.tensor(0)
+
+    @classmethod
+    def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
+        return cls(
+            num_heads=k.shape[1],
+            max_len=k.shape[2],
+            head_dim=k.shape[3],
+            dtype=k.dtype,
+            device=k.device,
+            k=k,
+            v=v,
+        )
+
+    def update(
+        self, k: torch.Tensor, v: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
+        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
+        self.current_idx += 1
+        return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]
+
+    def prefill(
+        self, k: torch.Tensor, v: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        prefill_len = k.shape[2]
+        self.k[:, :, :prefill_len, :] = k
+        self.v[:, :, :prefill_len, :] = v
+        self.current_idx = prefill_len - 1
+
+
+@dataclass
+class DecoderInferenceState:
+    """Parameters specifically for decoder inference."""
+
+    device: torch.device
+    dtype: torch.dtype
+    enc_out: torch.Tensor
+    enc_positions: torch.Tensor
+    dec_positions: torch.Tensor
+    dec_cross_attn_mask: torch.Tensor
+    self_attn_cache: list[KVCache]
+    cross_attn_cache: list[KVCache]
+
+    @classmethod
+    def new(
+        cls,
+        config: DiaConfig,
+        enc_state: EncoderInferenceState,
+        enc_out: torch.Tensor,
+        dec_cross_attn_cache: list[KVCache],
+        compute_dtype: torch.dtype,
+    ) -> "DecoderInferenceState":
+        """Creates DecoderInferenceState from DiaConfig and a device."""
+        device = enc_out.device
+        max_audio_len = config.data.audio_length
+
+        dec_positions = torch.full(
+            (2, 1), fill_value=0, dtype=torch.long, device=device
+        )
+        tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
+        dec_cross_attn_mask = create_attn_mask(
+            tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
+        )
+
+        self_attn_cache = [
+            KVCache(
+                config.model.decoder.kv_heads,
+                max_audio_len,
+                config.model.decoder.gqa_head_dim,
+                compute_dtype,
+                device,
+            )
+            for _ in range(config.model.decoder.n_layer)
+        ]
+
+        return cls(
+            device=device,
+            dtype=compute_dtype,
+            enc_out=enc_out,
+            enc_positions=enc_state.positions,
+            dec_positions=dec_positions,
+            dec_cross_attn_mask=dec_cross_attn_mask,
+            self_attn_cache=self_attn_cache,
+            cross_attn_cache=dec_cross_attn_cache,
+        )
+
+    def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
+        if step_to is None:
+            step_to = step_from + 1
+        self.dec_positions = (
+            torch.arange(step_from, step_to, device=self.device)
+            .unsqueeze(0)
+            .expand(2, -1)
+        )
+
+
+@dataclass
+class DecoderOutput:
+    generated_tokens: torch.Tensor
+    prefill_step: int
+
+    @classmethod
+    def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
+        max_audio_len = config.data.audio_length
+        return cls(
+            generated_tokens=torch.full(
+                (max_audio_len, config.data.channels),
+                fill_value=-1,
+                dtype=torch.int,
+                device=device,
+            ),
+            prefill_step=0,
+        )
+
+    def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
+        if step_to is None:
+            step_to = step_from + 1
+        return self.generated_tokens[step_from:step_to, :]
+
+    def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
+        if apply_mask:
+            mask = self.generated_tokens[step : step + 1, :] == -1
+            self.generated_tokens[step : step + 1, :] = torch.where(
+                mask, dec_out, self.generated_tokens[step : step + 1, :]
+            )
+        else:
+            self.generated_tokens[step : step + 1, :] = dec_out
+
+    def prefill(self, dec_out: torch.Tensor, prefill_step: int):
+        length = dec_out.shape[0]
+        self.generated_tokens[0:length, :] = dec_out
+        self.prefill_step = prefill_step
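As a quick illustration of how the cache added in dia/state.py behaves, prefill writes a whole prompt's keys and values at once, while update writes one step at current_idx and returns the valid prefix. The sketch below uses toy shapes outside the real decoder; note that after prefill, current_idx points at the last prompt position, so the first update recomputes that position's K/V.

import torch

from dia.state import KVCache

num_heads, max_len, head_dim = 4, 16, 8  # toy sizes
cache = KVCache(num_heads, max_len, head_dim, torch.float32, torch.device("cpu"))

# Prefill with a 5-step prompt: shape (2, H, T, Dh) for the CFG batch of 2.
k0 = torch.randn(2, num_heads, 5, head_dim)
v0 = torch.randn(2, num_heads, 5, head_dim)
cache.prefill(k0, v0)  # current_idx now points at the last prompt step (index 4)

# One decode step writes K/V at current_idx and returns the prefix seen so far.
k1 = torch.randn(2, num_heads, 1, head_dim)
v1 = torch.randn(2, num_heads, 1, head_dim)
k_all, v_all = cache.update(k1, v1)
print(k_all.shape)  # torch.Size([2, 4, 5, 8]) since current_idx advanced from 4 to 5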