# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import json5
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange


class Diffusion(nn.Module):
    def __init__(self, cfg, diff_model):
        super().__init__()
        self.cfg = cfg
        self.diff_estimator = diff_model
        self.beta_min = cfg.beta_min
        self.beta_max = cfg.beta_max
        self.sigma = cfg.sigma
        self.noise_factor = cfg.noise_factor
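
    # Training pass (forward below): sample a per-example diffusion time
    # t ~ U(0, 1) clamped to [offset, 1 - offset], noise x0 to xt with the
    # closed-form forward process, and let the estimator predict x0 directly.
    # The equivalent noise prediction is recovered analytically from the
    # predicted mean and variance, so either an x0 loss or a noise loss can be
    # applied downstream.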
    def forward(self, x, condition_embedding, x_mask, reference_embedding, offset=1e-5):
        diffusion_step = torch.rand(
            x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
        )
        diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)
        xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)
        cum_beta = self.get_cum_beta(diffusion_step.unsqueeze(-1).unsqueeze(-1))
        x0_pred = self.diff_estimator(
            xt, condition_embedding, x_mask, reference_embedding, diffusion_step
        )
        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
        variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
        noise_pred = (xt - mean_pred) / (torch.sqrt(variance) * self.noise_factor)
        noise = z
        diff_out = {"x0_pred": x0_pred, "noise_pred": noise_pred, "noise": noise}
        return diff_out
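
    # Linear noise schedule: beta(t) = beta_min + (beta_max - beta_min) * t,
    # and get_cum_beta returns its integral over [0, t]:
    #   int_0^t beta(s) ds = beta_min * t + 0.5 * (beta_max - beta_min) * t**2,
    # the quantity the closed-form mean/variance expressions above depend on.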
    def get_cum_beta(self, time_step):
        return self.beta_min * time_step + 0.5 * (self.beta_max - self.beta_min) * (
            time_step**2
        )

    def get_beta_t(self, time_step):
        return self.beta_min + (self.beta_max - self.beta_min) * time_step

    def forward_diffusion(self, x0, diffusion_step):
        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        cum_beta = self.get_cum_beta(time_step)
        mean = x0 * torch.exp(-0.5 * cum_beta / (self.sigma**2))
        variance = (self.sigma**2) * (1 - torch.exp(-cum_beta / (self.sigma**2)))
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt = mean + z * torch.sqrt(variance) * self.noise_factor
        return xt, z
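
    # One probability-flow (ODE) step of size h. The score is approximated from
    # the predicted x0 as -(xt - mean_pred) / variance, and the update applies
    # the drift -0.5 * beta(t) * (score + xt / sigma**2) per unit time, matching
    # the forward process defined above.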
    def cal_dxt(
        self, xt, condition_embedding, x_mask, reference_embedding, diffusion_step, h
    ):
        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        cum_beta = self.get_cum_beta(time_step=time_step)
        beta_t = self.get_beta_t(time_step=time_step)
        x0_pred = self.diff_estimator(
            xt, condition_embedding, x_mask, reference_embedding, diffusion_step
        )
        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
        noise_pred = xt - mean_pred
        variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
        logp = -noise_pred / (variance + 1e-8)
        dxt = -0.5 * h * beta_t * (logp + xt / (self.sigma**2))
        return dxt

    def reverse_diffusion(
        self, z, condition_embedding, x_mask, reference_embedding, n_timesteps
    ):
        h = 1.0 / max(n_timesteps, 1)
        xt = z
        for i in range(n_timesteps):
            t = (1.0 - (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            dxt = self.cal_dxt(
                xt,
                condition_embedding,
                x_mask,
                reference_embedding,
                diffusion_step=t,
                h=h,
            )
            xt_ = xt - dxt
            if self.cfg.ode_solve_method == "midpoint":
                x_mid = 0.5 * (xt_ + xt)
                dxt = self.cal_dxt(
                    x_mid,
                    condition_embedding,
                    x_mask,
                    reference_embedding,
                    diffusion_step=t + 0.5 * h,
                    h=h,
                )
                xt = xt - dxt
            elif self.cfg.ode_solve_method == "euler":
                xt = xt_
        return xt
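
    # A minimal usage sketch for reverse_diffusion (shapes assumed; B, T, D are
    # hypothetical):
    #   z = torch.randn(B, T, D)
    #   x0 = diffusion.reverse_diffusion(
    #       z, condition_embedding, x_mask=None,
    #       reference_embedding=ref_emb, n_timesteps=50,
    #   )
    # This integrates from t = 1 down to t = 0 with the Euler or midpoint rule,
    # selected by cfg.ode_solve_method.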

    def reverse_diffusion_from_t(
        self, z, condition_embedding, x_mask, reference_embedding, n_timesteps, t_start
    ):
        h = t_start / max(n_timesteps, 1)
        xt = z
        for i in range(n_timesteps):
            t = (t_start - (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            dxt = self.cal_dxt(
                xt,
                condition_embedding,
                x_mask,
                reference_embedding,
                diffusion_step=t,
                h=h,
            )
            xt_ = xt - dxt
            if self.cfg.ode_solve_method == "midpoint":
                x_mid = 0.5 * (xt_ + xt)
                dxt = self.cal_dxt(
                    x_mid,
                    condition_embedding,
                    x_mask,
                    reference_embedding,
                    diffusion_step=t + 0.5 * h,
                    h=h,
                )
                xt = xt - dxt
            elif self.cfg.ode_solve_method == "euler":
                xt = xt_
        return xt


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
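
# SinusoidalPosEmb maps a batch of scalar timesteps (B,) to embeddings (B, dim)
# with the standard transformer sinusoid construction; it is used below both for
# diffusion-step conditioning and inside DiffWaveNet.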


class Linear2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.linear_1 = nn.Linear(dim, dim * 2)
        self.linear_2 = nn.Linear(dim * 2, dim)
        self.linear_1.weight.data.normal_(0.0, 0.02)
        self.linear_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        x = self.linear_1(x)
        x = F.silu(x)
        x = self.linear_2(x)
        return x


class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.in_dim = normalized_shape
        self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
        self.style = nn.Linear(self.in_dim, self.in_dim * 2)
        self.style.bias.data[: self.in_dim] = 1
        self.style.bias.data[self.in_dim :] = 0

    def forward(self, x, condition):
        # x: (B, T, d); condition: (B, T, d)
        style = self.style(torch.mean(condition, dim=1, keepdim=True))
        gamma, beta = style.chunk(2, -1)
        out = self.norm(x)
        out = gamma * out + beta
        return out
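
# Style-adaptive LayerNorm: the condition sequence is mean-pooled over time and
# projected to per-channel (gamma, beta). The bias initialization above (gamma
# half set to 1, beta half to 0) makes the layer start out as a plain LayerNorm.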


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = dropout
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (B, T, d); pe is stored as (max_len, 1, d), so slice along the
        # time axis and broadcast over the batch
        x = x + self.pe[: x.size(1)].transpose(0, 1)
        return F.dropout(x, self.dropout, training=self.training)


class TransformerFFNLayer(nn.Module):
    def __init__(
        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.ffn_1 = nn.Conv1d(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            padding=self.conv_kernel_size // 2,
        )
        self.ffn_1.weight.data.normal_(0.0, 0.02)
        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
        self.ffn_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        # x: (B, T, d) -> (B, d, T) for the conv, then back to (B, T, d)
        x = self.ffn_1(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = F.silu(x)
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = self.ffn_2(x)
        return x


class TransformerFFNLayerOld(nn.Module):
    def __init__(
        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.ffn_1 = nn.Linear(self.encoder_hidden, self.conv_filter_size)
        self.ffn_1.weight.data.normal_(0.0, 0.02)
        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
        self.ffn_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        x = self.ffn_1(x)
        x = F.silu(x)
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = self.ffn_2(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        encoder_hidden,
        encoder_head,
        conv_filter_size,
        conv_kernel_size,
        encoder_dropout,
        use_cln,
        use_skip_connection,
        use_new_ffn,
        add_diff_step,
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.encoder_head = encoder_head
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.use_cln = use_cln
        self.use_skip_connection = use_skip_connection
        self.use_new_ffn = use_new_ffn
        self.add_diff_step = add_diff_step
        if not self.use_cln:
            self.ln_1 = nn.LayerNorm(self.encoder_hidden)
            self.ln_2 = nn.LayerNorm(self.encoder_hidden)
        else:
            self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
            self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
        self.self_attn = nn.MultiheadAttention(
            self.encoder_hidden, self.encoder_head, batch_first=True
        )
        if self.use_new_ffn:
            self.ffn = TransformerFFNLayer(
                self.encoder_hidden,
                self.conv_filter_size,
                self.conv_kernel_size,
                self.encoder_dropout,
            )
        else:
            self.ffn = TransformerFFNLayerOld(
                self.encoder_hidden,
                self.conv_filter_size,
                self.conv_kernel_size,
                self.encoder_dropout,
            )
        if self.use_skip_connection:
            self.skip_linear = nn.Linear(self.encoder_hidden * 2, self.encoder_hidden)
            self.skip_linear.weight.data.normal_(0.0, 0.02)
            self.skip_layernorm = nn.LayerNorm(self.encoder_hidden)
        if self.add_diff_step:
            self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden)
            self.diff_step_projection = Linear2(self.encoder_hidden)

    def forward(
        self, x, key_padding_mask, condition=None, skip_res=None, diffusion_step=None
    ):
        # x: (B, T, d); key_padding_mask: (B, T), 0 marks padding;
        # condition: (B, T, d); skip_res: (B, T, d); diffusion_step: (B,)
        if self.use_skip_connection and skip_res is not None:
            x = torch.cat([x, skip_res], dim=-1)  # (B, T, 2*d)
            x = self.skip_linear(x)
            x = self.skip_layernorm(x)
        if self.add_diff_step and diffusion_step is not None:
            diff_step_embedding = self.diff_step_emb(diffusion_step)
            diff_step_embedding = self.diff_step_projection(diff_step_embedding)
            x = x + diff_step_embedding.unsqueeze(1)
        residual = x
        # pre-norm
        if self.use_cln:
            x = self.ln_1(x, condition)
        else:
            x = self.ln_1(x)
        # self-attention; nn.MultiheadAttention expects True at padded positions
        if key_padding_mask is not None:
            key_padding_mask_input = ~(key_padding_mask.bool())
        else:
            key_padding_mask_input = None
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
        )
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = residual + x
        # pre-norm
        residual = x
        if self.use_cln:
            x = self.ln_2(x, condition)
        else:
            x = self.ln_2(x)
        # ffn
        x = self.ffn(x)
        x = residual + x
        return x
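
# Each encoder layer is pre-norm: LayerNorm (or StyleAdaptiveLayerNorm when
# use_cln is set) precedes self-attention and the FFN, with residual connections
# around both sublayers. The optional skip_linear path fuses a U-Net-style skip
# input before the layer body runs.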


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        enc_emb_tokens=None,
        encoder_layer=None,
        encoder_hidden=None,
        encoder_head=None,
        conv_filter_size=None,
        conv_kernel_size=None,
        encoder_dropout=None,
        use_cln=None,
        use_skip_connection=None,
        use_new_ffn=None,
        add_diff_step=None,
        cfg=None,
    ):
        super().__init__()
        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_cln = use_cln if use_cln is not None else cfg.use_cln
        self.use_skip_connection = (
            use_skip_connection
            if use_skip_connection is not None
            else cfg.use_skip_connection
        )
        self.add_diff_step = (
            add_diff_step if add_diff_step is not None else cfg.add_diff_step
        )
        self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
        if enc_emb_tokens is not None:
            self.use_enc_emb = True
            self.enc_emb_tokens = enc_emb_tokens
        else:
            self.use_enc_emb = False
        self.position_emb = PositionalEncoding(
            self.encoder_hidden, self.encoder_dropout
        )
        self.layers = nn.ModuleList([])
        if self.use_skip_connection:
            self.layers.extend(
                [
                    TransformerEncoderLayer(
                        self.encoder_hidden,
                        self.encoder_head,
                        self.conv_filter_size,
                        self.conv_kernel_size,
                        self.encoder_dropout,
                        self.use_cln,
                        use_skip_connection=False,
                        use_new_ffn=self.use_new_ffn,
                        add_diff_step=self.add_diff_step,
                    )
                    for i in range(
                        (self.encoder_layer + 1) // 2
                    )  # for example: 12 -> 6; 13 -> 7
                ]
            )
            self.layers.extend(
                [
                    TransformerEncoderLayer(
                        self.encoder_hidden,
                        self.encoder_head,
                        self.conv_filter_size,
                        self.conv_kernel_size,
                        self.encoder_dropout,
                        self.use_cln,
                        use_skip_connection=True,
                        use_new_ffn=self.use_new_ffn,
                        add_diff_step=self.add_diff_step,
                    )
                    for i in range(
                        self.encoder_layer - (self.encoder_layer + 1) // 2
                    )  # 12 -> 6; 13 -> 6
                ]
            )
        else:
            self.layers.extend(
                [
                    TransformerEncoderLayer(
                        self.encoder_hidden,
                        self.encoder_head,
                        self.conv_filter_size,
                        self.conv_kernel_size,
                        self.encoder_dropout,
                        self.use_cln,
                        use_new_ffn=self.use_new_ffn,
                        add_diff_step=self.add_diff_step,
                        use_skip_connection=False,
                    )
                    for i in range(self.encoder_layer)
                ]
            )
        if self.use_cln:
            self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
        else:
            self.last_ln = nn.LayerNorm(self.encoder_hidden)
        if self.add_diff_step:
            self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden)
            self.diff_step_projection = Linear2(self.encoder_hidden)

    def forward(self, x, key_padding_mask, condition=None, diffusion_step=None):
        if len(x.shape) == 2 and self.use_enc_emb:
            x = self.enc_emb_tokens(x)
        x = self.position_emb(x)  # (B, T, d)
        if self.add_diff_step and diffusion_step is not None:
            diff_step_embedding = self.diff_step_emb(diffusion_step)
            diff_step_embedding = self.diff_step_projection(diff_step_embedding)
            x = x + diff_step_embedding.unsqueeze(1)
        if self.use_skip_connection:
            skip_res_list = []
            # down
            for layer in self.layers[: self.encoder_layer // 2]:
                x = layer(x, key_padding_mask, condition)
                skip_res_list.append(x)
            # middle (one layer when encoder_layer is odd, none when even)
            for layer in self.layers[
                self.encoder_layer // 2 : (self.encoder_layer + 1) // 2
            ]:
                x = layer(x, key_padding_mask, condition)
            # up
            for layer in self.layers[(self.encoder_layer + 1) // 2 :]:
                skip_res = skip_res_list.pop()
                x = layer(x, key_padding_mask, condition, skip_res)
        else:
            for layer in self.layers:
                x = layer(x, key_padding_mask, condition)
        if self.use_cln:
            x = self.last_ln(x, condition)
        else:
            x = self.last_ln(x)
        return x
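
# With use_skip_connection, the stack is wired like a U-Net over depth: the
# first encoder_layer // 2 layers push their outputs onto a stack, a middle
# layer runs only when the depth is odd, and each remaining layer pops the
# matching "down" output and fuses it through its skip_linear.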


class DiffTransformer(nn.Module):
    def __init__(
        self,
        encoder_layer=None,
        encoder_hidden=None,
        encoder_head=None,
        conv_filter_size=None,
        conv_kernel_size=None,
        encoder_dropout=None,
        use_cln=None,
        use_skip_connection=None,
        use_new_ffn=None,
        add_diff_step=None,
        cat_diff_step=None,
        in_dim=None,
        out_dim=None,
        cond_dim=None,
        cfg=None,
    ):
        super().__init__()
        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_cln = use_cln if use_cln is not None else cfg.use_cln
        self.use_skip_connection = (
            use_skip_connection
            if use_skip_connection is not None
            else cfg.use_skip_connection
        )
        self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
        self.add_diff_step = (
            add_diff_step if add_diff_step is not None else cfg.add_diff_step
        )
        self.cat_diff_step = (
            cat_diff_step if cat_diff_step is not None else cfg.cat_diff_step
        )
        self.in_dim = in_dim if in_dim is not None else cfg.in_dim
        self.out_dim = out_dim if out_dim is not None else cfg.out_dim
        self.cond_dim = cond_dim if cond_dim is not None else cfg.cond_dim
        if self.in_dim != self.encoder_hidden:
            self.in_linear = nn.Linear(self.in_dim, self.encoder_hidden)
            self.in_linear.weight.data.normal_(0.0, 0.02)
        else:
            self.in_linear = None
        if self.out_dim != self.encoder_hidden:
            self.out_linear = nn.Linear(self.encoder_hidden, self.out_dim)
            self.out_linear.weight.data.normal_(0.0, 0.02)
        else:
            self.out_linear = None
        # The diffusion step is either added to the hidden states or prepended
        # as an extra token, never both.
        assert not (self.cat_diff_step and self.add_diff_step)
        self.transformer_encoder = TransformerEncoder(
            encoder_layer=self.encoder_layer,
            encoder_hidden=self.encoder_hidden,
            encoder_head=self.encoder_head,
            conv_kernel_size=self.conv_kernel_size,
            conv_filter_size=self.conv_filter_size,
            encoder_dropout=self.encoder_dropout,
            use_cln=self.use_cln,
            use_skip_connection=self.use_skip_connection,
            use_new_ffn=self.use_new_ffn,
            add_diff_step=self.add_diff_step,
        )
        self.cond_project = nn.Linear(self.cond_dim, self.encoder_hidden)
        self.cond_project.weight.data.normal_(0.0, 0.02)
        self.cat_linear = nn.Linear(self.encoder_hidden * 2, self.encoder_hidden)
        self.cat_linear.weight.data.normal_(0.0, 0.02)
        if self.cat_diff_step:
            self.diff_step_emb = SinusoidalPosEmb(dim=self.encoder_hidden)
            self.diff_step_projection = Linear2(self.encoder_hidden)

    def forward(
        self,
        x,
        condition_embedding,
        key_padding_mask=None,
        reference_embedding=None,
        diffusion_step=None,
    ):
        # x: (B, T, d_x)
        # key_padding_mask: (B, T), 0 marks padding
        # condition_embedding: from the condition adapter, (B, T, d_c)
        # reference_embedding: from the reference encoder,
        #     (B, N, d_r), (B, 1, d_r), or (B, d_r)
        if self.in_linear is not None:
            x = self.in_linear(x)
        condition_embedding = self.cond_project(condition_embedding)
        x = torch.cat([x, condition_embedding], dim=-1)
        x = self.cat_linear(x)
        if self.cat_diff_step and diffusion_step is not None:
            diff_step_embedding = self.diff_step_emb(diffusion_step)
            diff_step_embedding = self.diff_step_projection(
                diff_step_embedding
            ).unsqueeze(1)  # (B, 1, d)
            x = torch.cat([diff_step_embedding, x], dim=1)
            if key_padding_mask is not None:
                # extend the mask so it stays aligned with the prepended token
                key_padding_mask = torch.cat(
                    [
                        torch.ones(key_padding_mask.shape[0], 1).to(
                            key_padding_mask.device
                        ),
                        key_padding_mask,
                    ],
                    dim=1,
                )
        x = self.transformer_encoder(
            x,
            key_padding_mask=key_padding_mask,
            condition=reference_embedding,
            diffusion_step=diffusion_step,
        )
        if self.cat_diff_step and diffusion_step is not None:
            x = x[:, 1:, :]  # drop the prepended diffusion-step token
        if self.out_linear is not None:
            x = self.out_linear(x)
        return x


class ReferenceEncoder(nn.Module):
    def __init__(
        self,
        encoder_layer=None,
        encoder_hidden=None,
        encoder_head=None,
        conv_filter_size=None,
        conv_kernel_size=None,
        encoder_dropout=None,
        use_skip_connection=None,
        use_new_ffn=None,
        ref_in_dim=None,
        ref_out_dim=None,
        use_query_emb=None,
        num_query_emb=None,
        cfg=None,
    ):
        super().__init__()
        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_skip_connection = (
            use_skip_connection
            if use_skip_connection is not None
            else cfg.use_skip_connection
        )
        self.use_new_ffn = use_new_ffn if use_new_ffn is not None else cfg.use_new_ffn
        self.in_dim = ref_in_dim if ref_in_dim is not None else cfg.ref_in_dim
        self.out_dim = ref_out_dim if ref_out_dim is not None else cfg.ref_out_dim
        self.use_query_emb = (
            use_query_emb if use_query_emb is not None else cfg.use_query_emb
        )
        self.num_query_emb = (
            num_query_emb if num_query_emb is not None else cfg.num_query_emb
        )
        if self.in_dim != self.encoder_hidden:
            self.in_linear = nn.Linear(self.in_dim, self.encoder_hidden)
            self.in_linear.weight.data.normal_(0.0, 0.02)
        else:
            self.in_linear = None
        if self.out_dim != self.encoder_hidden:
            self.out_linear = nn.Linear(self.encoder_hidden, self.out_dim)
            self.out_linear.weight.data.normal_(0.0, 0.02)
        else:
            self.out_linear = None
        self.transformer_encoder = TransformerEncoder(
            encoder_layer=self.encoder_layer,
            encoder_hidden=self.encoder_hidden,
            encoder_head=self.encoder_head,
            conv_kernel_size=self.conv_kernel_size,
            conv_filter_size=self.conv_filter_size,
            encoder_dropout=self.encoder_dropout,
            use_new_ffn=self.use_new_ffn,
            use_cln=False,
            use_skip_connection=False,
            add_diff_step=False,
        )
        if self.use_query_emb:
            # num_query_emb learned query embeddings, e.g. 32 x 512
            self.query_embs = nn.Embedding(self.num_query_emb, self.encoder_hidden)
            self.query_attn = nn.MultiheadAttention(
                self.encoder_hidden, self.encoder_hidden // 64, batch_first=True
            )

    def forward(self, x_ref, key_padding_mask=None):
        # x_ref: (B, T, d_ref)
        # key_padding_mask: (B, T)
        # Returns (spk_embs, x):
        #   spk_embs: speaker embedding, (B, N_query, d_out) if use_query_emb,
        #             else the encoded reference sequence itself
        #   x: encoded reference sequence, (B, T, encoder_hidden)
        x = self.in_linear(x_ref) if self.in_linear is not None else x_ref
        x = self.transformer_encoder(
            x, key_padding_mask=key_padding_mask, condition=None, diffusion_step=None
        )  # (B, T, encoder_hidden)
        if self.use_query_emb:
            # queries: (B, N_query, d); keys/values: (B, T, d)
            spk_query_emb = self.query_embs(
                torch.arange(self.num_query_emb).to(x.device)
            ).repeat(x.shape[0], 1, 1)
            spk_embs, _ = self.query_attn(
                query=spk_query_emb,
                key=x,
                value=x,
                key_padding_mask=(
                    ~(key_padding_mask.bool()) if key_padding_mask is not None else None
                ),
            )  # (B, N_query, d)
            if self.out_linear is not None:
                spk_embs = self.out_linear(spk_embs)
        else:
            spk_embs = x
        return spk_embs, x
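
# Query-attention pooling: a bank of num_query_emb learned embeddings attends
# over the encoded reference sequence, turning a variable-length reference into
# a fixed-size (B, N_query, d) speaker representation.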


def pad(input_ele, mel_max_length=None):
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(ele.size(0) for ele in input_ele)
    out_list = list()
    for batch in input_ele:
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded
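
# A minimal usage sketch (hypothetical tensors): padding two mel-like tensors of
# shapes (T1, d) and (T2, d) along the first axis yields (2, max(T1, T2), d):
#   a, b = torch.randn(3, 80), torch.randn(5, 80)
#   pad([a, b]).shape  # torch.Size([2, 5, 80])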


class FiLM(nn.Module):
    def __init__(self, in_dim, cond_dim):
        super().__init__()
        self.gain = Linear(cond_dim, in_dim)
        self.bias = Linear(cond_dim, in_dim)
        nn.init.xavier_uniform_(self.gain.weight)
        nn.init.constant_(self.gain.bias, 1)
        nn.init.xavier_uniform_(self.bias.weight)
        nn.init.constant_(self.bias.bias, 0)

    def forward(self, x, condition):
        gain = self.gain(condition)
        bias = self.bias(condition)
        if gain.dim() == 2:
            gain = gain.unsqueeze(-1)
        if bias.dim() == 2:
            bias = bias.unsqueeze(-1)
        return x * gain + bias


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    layer.weight.data.normal_(0.0, 0.02)
    return layer


def Linear(*args, **kwargs):
    layer = nn.Linear(*args, **kwargs)
    layer.weight.data.normal_(0.0, 0.02)
    return layer


class ResidualBlock(nn.Module):
    def __init__(self, hidden_dim, attn_head, dilation, drop_out, has_cattn=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.dilation = dilation
        self.has_cattn = has_cattn
        self.attn_head = attn_head
        self.drop_out = drop_out
        self.dilated_conv = Conv1d(
            hidden_dim, 2 * hidden_dim, 3, padding=dilation, dilation=dilation
        )
        self.diffusion_proj = Linear(hidden_dim, hidden_dim)
        self.cond_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
        self.out_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
        if self.has_cattn:
            self.attn = nn.MultiheadAttention(
                hidden_dim, attn_head, 0.1, batch_first=True
            )
            self.film = FiLM(hidden_dim * 2, hidden_dim)
            self.ln = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(self.drop_out)

    def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
        diffusion_step = self.diffusion_proj(diffusion_step).unsqueeze(-1)  # (B, d, 1)
        cond = self.cond_proj(cond)  # (B, 2*d, T)
        y = x + diffusion_step
        if x_mask is not None:
            y = y * x_mask.to(y.dtype)[:, None, :]  # (B, d, T)
        if self.has_cattn:
            y_ = y.transpose(1, 2)
            y_ = self.ln(y_)
            y_, _ = self.attn(y_, spk_query_emb, spk_query_emb)  # (B, T, d)
        y = self.dilated_conv(y) + cond  # (B, 2*d, T)
        if self.has_cattn:
            y = self.film(y.transpose(1, 2), y_)  # (B, T, 2*d)
            y = y.transpose(1, 2)  # (B, 2*d, T)
        gate, filter_ = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter_)
        y = self.out_proj(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        if x_mask is not None:
            residual = residual * x_mask.to(y.dtype)[:, None, :]
            skip = skip * x_mask.to(y.dtype)[:, None, :]
        return (x + residual) / math.sqrt(2.0), skip
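
# WaveNet-style block: the diffusion-step embedding is broadcast-added to the
# input, a dilated conv doubles the channels for a gate/filter split
# (sigmoid * tanh), and out_proj emits a residual path (rescaled by 1/sqrt(2))
# plus a skip path that DiffWaveNet sums across layers.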


class DiffWaveNet(nn.Module):
    def __init__(
        self,
        cfg=None,
    ):
        super().__init__()
        self.cfg = cfg
        self.in_dim = cfg.input_size
        self.hidden_dim = cfg.hidden_size
        self.out_dim = cfg.out_size
        self.num_layers = cfg.num_layers
        self.cross_attn_per_layer = cfg.cross_attn_per_layer
        self.dilation_cycle = cfg.dilation_cycle
        self.attn_head = cfg.attn_head
        self.drop_out = cfg.drop_out
        self.in_proj = Conv1d(self.in_dim, self.hidden_dim, 1)
        self.diffusion_embedding = SinusoidalPosEmb(self.hidden_dim)
        self.mlp = nn.Sequential(
            Linear(self.hidden_dim, self.hidden_dim * 4),
            Mish(),
            Linear(self.hidden_dim * 4, self.hidden_dim),
        )
        self.cond_ln = nn.LayerNorm(self.hidden_dim)
        self.layers = nn.ModuleList(
            [
                ResidualBlock(
                    self.hidden_dim,
                    self.attn_head,
                    2 ** (i % self.dilation_cycle),
                    self.drop_out,
                    has_cattn=(i % self.cross_attn_per_layer == 0),
                )
                for i in range(self.num_layers)
            ]
        )
        self.skip_proj = Conv1d(self.hidden_dim, self.hidden_dim, 1)
        self.out_proj = Conv1d(self.hidden_dim, self.out_dim, 1)
        nn.init.zeros_(self.out_proj.weight)

    def forward(
        self,
        x,
        condition_embedding,
        key_padding_mask=None,
        reference_embedding=None,
        diffusion_step=None,
    ):
        x = x.transpose(1, 2)  # (B, T, d) -> (B, d, T)
        x_mask = key_padding_mask
        cond = condition_embedding
        spk_query_emb = reference_embedding
        cond = self.cond_ln(cond)
        cond_input = cond.transpose(1, 2)
        x_input = self.in_proj(x)
        x_input = F.relu(x_input)
        diffusion_step = self.diffusion_embedding(diffusion_step).to(x.dtype)
        diffusion_step = self.mlp(diffusion_step)
        skip = []
        for layer in self.layers:
            x_input, skip_connection = layer(
                x_input, x_mask, cond_input, diffusion_step, spk_query_emb
            )
            skip.append(skip_connection)
        x_input = torch.sum(torch.stack(skip), dim=0) / math.sqrt(self.num_layers)
        x_out = self.skip_proj(x_input)
        x_out = F.relu(x_out)
        x_out = self.out_proj(x_out)  # (B, out_dim, T)
        x_out = x_out.transpose(1, 2)  # (B, T, out_dim)
        return x_out


def override_config(base_config, new_config):
    """Recursively override entries of base_config with those in new_config.

    Args:
        base_config (dict): original dict to be overridden
        new_config (dict): dict with new configurations

    Returns:
        dict: updated configuration dict
    """
    for k, v in new_config.items():
        if isinstance(v, dict):
            if k not in base_config.keys():
                base_config[k] = {}
            base_config[k] = override_config(base_config[k], v)
        else:
            base_config[k] = v
    return base_config
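
# Nested keys merge rather than replace wholesale, e.g.:
#   override_config({"a": {"b": 1, "c": 2}}, {"a": {"c": 3}})
#   # -> {"a": {"b": 1, "c": 3}}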


def get_lowercase_keys_config(cfg):
    """Recursively lower-case all keys in cfg.

    Args:
        cfg (dict): dictionary that stores configurations

    Returns:
        dict: the same configurations with all keys lower-cased
    """
    updated_cfg = dict()
    for k, v in cfg.items():
        if isinstance(v, dict):
            v = get_lowercase_keys_config(v)
        updated_cfg[k.lower()] = v
    return updated_cfg


def save_config(save_path, cfg):
    """Save configurations into a json file.

    Args:
        save_path (str): path to save configurations
        cfg (dict): dictionary that stores configurations
    """
    with open(save_path, "w") as f:
        json5.dump(
            cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
        )


class JsonHParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
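
# JsonHParams wraps nested dicts in attribute-style accessors, so a config
# loaded from JSON can be read either way:
#   cfg = JsonHParams(**{"diffusion": {"sigma": 1.0}})
#   cfg.diffusion.sigma        # 1.0
#   cfg["diffusion"]["sigma"]  # 1.0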


class Noro_VCmodel(nn.Module):
    def __init__(self, cfg, use_ref_noise=False):
        super().__init__()
        self.cfg = cfg
        self.use_ref_noise = use_ref_noise
        self.reference_encoder = ReferenceEncoder(cfg=cfg.reference_encoder)
        if cfg.diffusion.diff_model_type == "WaveNet":
            self.diffusion = Diffusion(
                cfg=cfg.diffusion,
                diff_model=DiffWaveNet(cfg=cfg.diffusion.diff_wavenet),
            )
        else:
            raise NotImplementedError()
        pitch_dim = 1
        self.content_f0_enc = nn.Sequential(
            nn.LayerNorm(
                cfg.vc_feature.content_feature_dim + pitch_dim
            ),  # 768 (for mhubert) + 1 (for f0)
            Rearrange("b t d -> b d t"),
            nn.Conv1d(
                cfg.vc_feature.content_feature_dim + pitch_dim,
                cfg.vc_feature.hidden_dim,
                kernel_size=3,
                padding=1,
            ),
            Rearrange("b d t -> b t d"),
        )
        self.reset_parameters()

    def forward(
        self,
        x=None,
        content_feature=None,
        pitch=None,
        x_ref=None,
        x_mask=None,
        x_ref_mask=None,
        noisy_x_ref=None,
    ):
        noisy_reference_embedding = None
        noisy_condition_embedding = None
        reference_embedding, encoded_x = self.reference_encoder(
            x_ref=x_ref, key_padding_mask=x_ref_mask
        )
        # content_feature: (B, T, D); pitch: (B, T)
        # concatenating f0 gives (B, T, D + 1), encoded to (B, T, hidden_dim)
        condition_embedding = torch.cat([content_feature, pitch[:, :, None]], dim=-1)
        condition_embedding = self.content_f0_enc(condition_embedding)
        if self.use_ref_noise:
            # average the clean and noisy reference embeddings
            noisy_reference_embedding, _ = self.reference_encoder(
                x_ref=noisy_x_ref, key_padding_mask=x_ref_mask
            )
            combined_reference_embedding = (
                noisy_reference_embedding + reference_embedding
            ) / 2
        else:
            combined_reference_embedding = reference_embedding
        combined_condition_embedding = condition_embedding
        diff_out = self.diffusion(
            x=x,
            condition_embedding=combined_condition_embedding,
            x_mask=x_mask,
            reference_embedding=combined_reference_embedding,
        )
        return (
            diff_out,
            (reference_embedding, noisy_reference_embedding),
            (condition_embedding, noisy_condition_embedding),
        )

    def inference(
        self,
        content_feature=None,
        pitch=None,
        x_ref=None,
        x_ref_mask=None,
        inference_steps=1000,
        sigma=1.2,
    ):
        reference_embedding, _ = self.reference_encoder(
            x_ref=x_ref, key_padding_mask=x_ref_mask
        )
        condition_embedding = torch.cat([content_feature, pitch[:, :, None]], dim=-1)
        condition_embedding = self.content_f0_enc(condition_embedding)
        bsz, l, _ = condition_embedding.shape
        if self.cfg.diffusion.diff_model_type == "WaveNet":
            z = (
                torch.randn(bsz, l, self.cfg.diffusion.diff_wavenet.input_size).to(
                    condition_embedding.device
                )
                / sigma
            )
        x0 = self.diffusion.reverse_diffusion(
            z=z,
            condition_embedding=condition_embedding,
            x_mask=None,
            reference_embedding=reference_embedding,
            n_timesteps=inference_steps,
        )
        return x0
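
    # Note: the prior sample z above is scaled by 1 / sigma, and this `sigma`
    # argument (default 1.2) is independent of cfg.sigma used inside Diffusion.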

    def reset_parameters(self):
        def _reset_parameters(m):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)
                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)
            elif isinstance(
                m, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)