import torch
from torch import nn

from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding

class Diffusion(nn.Module):
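    """Conditional flow-matching latent decoder with a Transformer backbone.

    forward() expects `inputs` with 'lat', 'lat_ctx', 'ctx_mask', 'phone',
    'tone', 'mel2ph' and 'text_mel_mask'; inference() reads 'phone', 'tone',
    'dur', 'lat_ctx' and 'ctx_mask' and expects a 3-row CFG batch
    (speaker+text cond, text cond, uncond), matching the 3-way chunk in
    _forward().
    """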
    def __init__(self):
        super().__init__()
        # Conditioning / latent channel sizes
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        self.in_channels = 32
        self.out_channels = 32

        # Transformer backbone hyper-parameters
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256

        # Local conditioning: prompt latent frames + binary context mask
        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

        # Transformer backbone operating on the noisy latent sequence
        self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)

        # Projections into / out of the backbone dimension
        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)

        # Conditional flow matching (the training target is the flow field u_t)
        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)

        # Timestep embedding fed to the backbone
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)

        # Linguistic encoder: phone tokens + tone + positions, then 4x downsampling
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate([2, 2])
        ])

    def forward(self, inputs, sigmas=None, x_noisy=None):
        ctx_mask = inputs['ctx_mask']
        ctx_feature = inputs['lat_ctx'] * ctx_mask

        """ local conditioning (prompt_latent + spk_embed) """
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        """ diffusion target latent """
        x = inputs['lat']

        # Conditional flow matching: sample a timestep t, the interpolated
        # sample x_t and the target flow u_t between noise x0 and data x.
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)

        # Keep the prompt region clean: the model only sees noise outside the context.
        t = t.bfloat16()
        x_noisy = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # Linguistic conditioning, expanded to frame level and downsampled 4x.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)

        x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
        encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
        pred = self.postnet(encoder_out)
        return pred, target

    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_ph, 1], 0 = padding

        # Tone embedding plus absolute positional embedding, masked at padding
        # positions and injected into the phone encoder as extra embeddings.
        ph_enc_oembed = self.tone_embed(tone_tokens)
        ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
            torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding
        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0, 1.0]):
        """ When we use torchdiffeq, we need to include the CFG process inside _forward() """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
        pred = self.postnet(pred_v)

        """ Perform multi-cond CFG """
        # Rows are (speaker+text cond, text cond, uncond); the two guidance
        # weights scale the text and speaker terms separately.
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return pred

    @torch.no_grad()
    def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
        # Linguistic conditioning, expanded to frame level and downsampled 4x.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)

        # Only the first CFG row keeps the speaker prompt latent; the other
        # rows are zeroed so they act as the text-only and unconditional branches.
        ctx_feature = inputs['lat_ctx']
        ctx_feature[1:, :, :] = 0
        ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])

        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        ''' Euler ODE solver '''
        bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))

        # Sway sampling: warp the uniform timestep schedule so that more steps
        # are spent near small t (the noisy end of the trajectory).
        sway_sampling_coef = -1.0
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)

        def amo_sampling(z_t, t, t_next, v):
            # AMO-style stochastic step: overshoot along the predicted velocity
            # past t_next, then jump back to t_next by mixing in fresh noise.
            z_t = z_t.to(torch.float32)

            # Target time and overshoot strength
            s = t_next
            c = 3

            # Overshot time o and the Euler prediction at o
            o = min(t_next + c * (t_next - t), 1)
            pred_z_o = z_t + (o - t) * v

            # Scale back from o to s; b is chosen so the noise level matches
            # the flow marginal at time s.
            a = s / o
            b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
            noise_i = torch.randn(size=z_t.shape, device=z_t.device)
            z_t_next = a * pred_z_o + b * noise_i
            return z_t_next.to(v.dtype)

        # Start from pure noise and integrate the flow with the AMO sampler.
        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            sigma = t_schedule[step_index].to(x_ling.dtype)
            sigma_next = t_schedule[step_index + 1]
            # The single latent is tiled to the CFG batch; _forward collapses it back.
            model_out = self._forward(
                torch.cat([x] * bsz), local_cond, x_ling,
                timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'],
                dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            x = amo_sampling(x, sigma, sigma_next, model_out)
            x = x.to(model_out.dtype)

        return x
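

# ---------------------------------------------------------------------------
# Minimal usage sketch. The key names, shapes and the 3-row CFG batch layout
# (speaker+text cond, text cond, uncond) are inferred from inference() and
# _forward() above; all tensors are random placeholders, so this only checks
# shapes, not audio quality.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    T_lat, L_ph = 100, 50  # latent frames / phone tokens (arbitrary)
    model = Diffusion().eval()
    inputs = {
        "phone": torch.randint(1, 302, (3, L_ph)),          # phone ids, 0 = padding
        "tone": torch.randint(1, 32, (3, L_ph)),             # tone ids, 0 = padding
        "dur": torch.randint(1, L_ph + 1, (3, 4 * T_lat)),   # mel2ph-style alignment at 4x the latent rate
        "lat_ctx": torch.randn(3, T_lat, 32),                # prompt latent; rows 1-2 are zeroed inside inference
        "ctx_mask": torch.zeros(3, T_lat, 1),                # 1 marks prompt frames
    }
    inputs["ctx_mask"][:, :25] = 1.0  # treat the first 25 frames as the prompt
    lat = model.inference(inputs, timesteps=20, seq_cfg_w=[2.0, 2.5])  # arbitrary guidance weights
    print(lat.shape)  # expected: [1, T_lat, 32]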