# jzq11111's picture
# Upload folder using huggingface_hub
# a3e05e8 verified
import torch
import torch.nn as nn
from modules.audio_tokenizer.quantize import ResidualVQ
from modules.audio_tokenizer.vocos import VocosBackbone
from modules.audio_tokenizer.transformer import TransformerEncoder
def init_weights(m):
    """Initialize Conv1d/Linear weights with truncated normal (std=0.02) and zero biases.

    Intended for use with ``nn.Module.apply``; modules of other types are
    left untouched.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Guard: layers built with bias=False have m.bias is None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class RepCodec(nn.Module):
    """Residual-VQ autoencoder over frame-level hidden representations.

    A Vocos-style backbone encodes (B, T, hidden_size) features, a residual
    vector quantizer discretizes them, and a second Vocos-style backbone
    reconstructs the features. Optionally, a timbre encoder extracts a global
    (per-utterance) style embedding that modulates the quantized features via
    a learned scale/shift (AdaIN-like) before decoding.

    Values present on ``cfg`` override the keyword arguments of the same name.
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        use_timbre_encoder=False,
        cfg=None,
    ):
        super().__init__()

        def _cfg_or(name, default):
            # getattr-with-default replaces the per-field hasattr ternaries;
            # the original checked hasattr(cfg, "vocos_dim") for both
            # vocos_intermediate_dim and vocos_num_layers (copy-paste bug),
            # so those cfg fields were silently ignored.
            return getattr(cfg, name, default) if cfg is not None else default

        self.codebook_size = _cfg_or("codebook_size", codebook_size)
        self.codebook_dim = _cfg_or("codebook_dim", codebook_dim)
        self.hidden_size = _cfg_or("hidden_size", hidden_size)
        self.vocos_dim = _cfg_or("vocos_dim", vocos_dim)
        self.vocos_intermediate_dim = _cfg_or(
            "vocos_intermediate_dim", vocos_intermediate_dim
        )
        self.vocos_num_layers = _cfg_or("vocos_num_layers", vocos_num_layers)
        self.num_quantizers = _cfg_or("num_quantizers", num_quantizers)
        self.use_timbre_encoder = _cfg_or("use_timbre_encoder", use_timbre_encoder)

        # Encoder/decoder now honor the vocos_* hyperparameters; the original
        # hard-coded 384/2048/12, making those arguments dead. Defaults are
        # unchanged, so default construction is identical.
        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.quantizer = ResidualVQ(
            input_dim=self.hidden_size,
            num_quantizers=self.num_quantizers,
            codebook_size=self.codebook_size,
            codebook_dim=self.codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        if self.use_timbre_encoder:  # TODO: write encoder hidden (256) as a hyparam
            self.timbre_in = nn.Linear(self.hidden_size, 256)
            self.timbre_encoder = TransformerEncoder(
                enc_emb_tokens=None,
                encoder_layer=4,
                encoder_hidden=256,
                encoder_head=4,
                conv_filter_size=1024,
                conv_kernel_size=5,
                encoder_dropout=0.1,
                use_pe=False,
                cfg=None,
            )
            self.timbre_out = nn.Linear(256, self.hidden_size)
            # Projects the mean timbre embedding to a (gamma, beta) pair.
            self.timbre_linear = nn.Linear(self.hidden_size, self.hidden_size * 2)
            self.timbre_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False)
            self.enc_ln = nn.LayerNorm(self.hidden_size, elementwise_affine=False)

        self.reset_parameters()

        if self.use_timbre_encoder:
            # Must run AFTER reset_parameters(): init_weights zeroes every
            # Linear bias, which would clobber this identity-style init
            # (gamma starts at 1, beta at 0) as it did in the original.
            self.timbre_linear.bias.data[: self.hidden_size] = 1
            self.timbre_linear.bias.data[self.hidden_size :] = 0

    def forward(self, x):
        """Encode, quantize, (optionally) timbre-modulate, and decode.

        Args:
            x: input features; assumed (B, T, hidden_size) — the encoder is
               fed x.transpose(1, 2). TODO confirm against callers.

        Returns:
            (x_rec, codebook_loss, all_indices): reconstruction, mean of
            commitment + codebook losses, and quantizer code indices.
        """
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
        if self.use_timbre_encoder:
            # Keep the pre-LayerNorm features for the timbre branch.
            x_timbre = x
            x = x.transpose(1, 2)
            x = self.enc_ln(x)
            x = x.transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        if self.use_timbre_encoder:
            x_timbre = x_timbre.transpose(1, 2)
            x_timbre = self.timbre_in(x_timbre)
            x_timbre = self.timbre_encoder(x_timbre, None, None)
            x_timbre = self.timbre_out(x_timbre)
            x_timbre = x_timbre.transpose(1, 2)
            # Global (time-averaged) speaker/timbre embedding.
            spk_embs = torch.mean(x_timbre, dim=2)
            style = self.timbre_linear(spk_embs).unsqueeze(2)  # (B, 2d, 1)
            gamma, beta = style.chunk(2, 1)  # (B, d, 1)
            quantized_out = quantized_out.transpose(1, 2)
            quantized_out = self.timbre_norm(quantized_out)
            quantized_out = quantized_out.transpose(1, 2)
            # AdaIN-style scale/shift with the timbre statistics.
            quantized_out = quantized_out * gamma + beta

        x_rec = self.decoder(quantized_out)
        codebook_loss = (all_codebook_losses + all_commit_losses).mean()
        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        """Encode and quantize without decoding.

        Returns (indices, quantized) where indices drops the leading
        quantizer axis when there is a single quantizer.
        """
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
        if self.use_timbre_encoder:
            x = x.transpose(1, 2)
            x = self.enc_ln(x)
            x = x.transpose(1, 2)
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)
        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        """Re-initialize all Conv1d/Linear submodules via init_weights."""
        self.apply(init_weights)