# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrapper around existing models.
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
import math
from pathlib import Path
import typing as tp
import warnings

import numpy as np
import torch
from torch import nn
from torch import einsum
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm
from einops import rearrange, repeat
import omegaconf

# flashy is only used to synchronize buffers across workers during distributed
# training; keep inference-only usage importable without it.
try:
    import flashy  # type: ignore
except ImportError:
    flashy = None

CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_group_norm'])
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.
    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(dct, dict)
    return dct
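# Example (illustrative, not from the original file): `dict_from_config`
# resolves interpolations while flattening, e.g.
#   cfg = omegaconf.OmegaConf.create({'a': 1, 'b': '${a}'})
#   dict_from_config(cfg)  # -> {'a': 1, 'b': 1}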
@dataclass
class QuantizedResult:
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)
class BaseQuantizer(nn.Module):
    """Base class for quantizers.
    """
    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """
        Given input tensor x, returns first the quantized (or approximately quantized)
        representation along with quantized codes, bandwidth, and any penalty term for the loss.
        Finally, this returns a dict of metrics to update logging etc.
        Frame rate must be passed so that the bandwidth is properly computed.
        """
        raise NotImplementedError()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
        raise NotImplementedError()
    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError()
    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        raise NotImplementedError()
    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        raise NotImplementedError()
    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise NotImplementedError()
class CompressionModel(ABC, nn.Module):
    """Base API for all compression models that aim at being used as audio tokenizers
    with a language model.
    """
    @abstractmethod
    def forward(self, x: torch.Tensor) -> QuantizedResult:
        ...
    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`."""
        ...
    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`."""
        ...
    @abstractmethod
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        ...
    @property
    @abstractmethod
    def channels(self) -> int:
        ...
    @property
    @abstractmethod
    def frame_rate(self) -> float:
        ...
    @property
    @abstractmethod
    def sample_rate(self) -> int:
        ...
    @property
    @abstractmethod
    def cardinality(self) -> int:
        ...
    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        ...
    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        ...
    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that norm is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
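# Worked example (illustrative): for length=5, kernel_size=4, stride=2 and
# padding_total=2 we get n_frames = (5 - 4 + 2) / 2 + 1 = 2.5, so we need
# ceil(2.5) = 3 full frames, an ideal length of (3 - 1) * 2 + (4 - 2) = 6,
# and hence 1 extra sample of right padding.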
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4             # once you removed padding, we are missing one time step !
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
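# Why the guard above is needed (illustrative): F.pad in 'reflect' mode
# requires each padding to be strictly smaller than the input length, e.g.
#   F.pad(torch.ones(1, 1, 2), (3, 3), mode='reflect')  # raises RuntimeError
# so pad1d first right-pads with zeros until the reflection is legal, then
# trims the temporary zeros afterwards.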
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm
    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm
    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm
    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
class StreamableConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
                          f" (kernel_size={kernel_size}, stride={stride}, dilation={dilation}).")
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode
    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)
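# Usage sketch (hypothetical shapes): a causal StreamableConv1d with stride 2
# maps [B, C, T] to [B, C', T / 2] without looking at future samples:
#   conv = StreamableConv1d(1, 8, kernel_size=4, stride=2, causal=True)
#   conv(torch.randn(2, 1, 100)).shape  # -> torch.Size([2, 8, 50])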
class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride
        y = self.convtr(x)
        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
class StreamableLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.
    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                             causal=causal, pad_mode=pad_mode)
    def forward(self, x):
        return self.shortcut(x) + self.block(x)
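# Shape sketch (illustrative): the residual block preserves [B, dim, T], so
# blocks can be stacked freely, e.g.
#   block = SEANetResnetBlock(32)
#   block(torch.randn(2, 32, 50)).shape  # -> torch.Size([2, 32, 50])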
class SEANetEncoder(nn.Module):
    """SEANet encoder.
    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamableConv1d(channels, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2
        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)
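# Example (hedged): with the default ratios [8, 5, 4, 2] the encoder
# downsamples by hop_length = 8 * 5 * 4 * 2 = 320, so 16 kHz audio maps to a
# 50 Hz latent sequence:
#   enc = SEANetEncoder()
#   enc(torch.randn(1, 1, 16000)).shape  # -> torch.Size([1, 128, 50])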
class SEANetDecoder(nn.Module):
    """SEANet decoder.
    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the start of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
            mult //= 2
        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)
    def forward(self, z):
        y = self.model(z)
        return y
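# Example (hedged): the decoder mirrors the encoder above, upsampling by the
# product of the ratios:
#   dec = SEANetDecoder()
#   dec(torch.randn(1, 128, 50)).shape  # -> torch.Size([1, 1, 16000])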
def exists(val: tp.Optional[tp.Any]) -> bool:
    return val is not None
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if exists(val) else d
def l2norm(t):
    return F.normalize(t, p=2, dim=-1)
def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t
def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype
    means = sample_vectors(samples, num_clusters)
    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        dists = -(diffs ** 2).sum(dim=-1)
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins
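# Usage sketch (illustrative): cluster 100 random 8-d points into 4 centroids;
# empty clusters keep their previous centroid thanks to the `zero_mask` above:
#   means, bins = kmeans(torch.randn(100, 8), num_clusters=4)
#   means.shape, bins.shape  # -> torch.Size([4, 8]), torch.Size([4])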
def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    n = t.shape[0]
    normed_codes = l2norm(t)
    identity = torch.eye(n, device=t.device)
    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
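# Sanity check (illustrative): the loss is 0 for an orthonormal codebook,
# since the cosine-similarity matrix then equals the identity:
#   orthogonal_loss_fn(torch.eye(4))  # -> tensor(0.)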
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())
    def init_embed_(self, data):
        if self.inited:
            return
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # (only relevant, and only possible, when flashy is available).
        if flashy is not None:
            flashy.distrib.broadcast_tensors(self.buffers())
    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)
    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return
        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        if flashy is not None:
            flashy.distrib.broadcast_tensors(self.buffers())
    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x
    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind
    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])
    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize
    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind
    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize
    def forward(self, x):
        raise NotImplementedError()
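        # NOTE (added for clarity): the EMA codebook update below is kept from
        # the original EnCodec training code for reference; it is unreachable
        # here because this codebase uses the codec for inference only.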
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        self.init_embed_(x)
        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)
        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size
        self.channels_last = channels_last
    @property
    def codebook(self):
        return self._codebook.embed
    @property
    def inited(self):
        return self._codebook.inited
    def _preprocess(self, x):
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x
    def _postprocess(self, quantize):
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize
    def encode(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in
    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize
    def forward(self, x):
        device = x.device
        x = self._preprocess(x)
        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)
        if self.training:
            quantize = x + (quantize - x).detach()
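            # Straight-through estimator (explanatory note): the line above
            # returns the codebook vector in the forward pass while letting
            # gradients flow to `x` unchanged, since `(quantize - x).detach()`
            # contributes no gradient.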
        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight
            if self.orthogonal_reg_weight > 0:
                codebook = self.codebook
                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[unique_code_ids]
                num_codes = codebook.shape[0]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[rand_ids]
                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        codebook_size = kwargs.pop('codebook_size', None)
        if codebook_size is None:
            raise ValueError("codebook_size must be provided in kwargs")
        if not isinstance(codebook_size, list):
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VectorQuantization(codebook_size=cur_codebook_size, **kwargs)
             for _, cur_codebook_size in zip(range(num_quantizers), codebook_size)]
        )
        # self.layers = nn.ModuleList(
        #     [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        # )
    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x
        all_losses = []
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)
        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses
    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            # In the original EnCodec code, `quantize` carries the gradient of the
            # residual through the straight-through estimator
            # (`quantize = x + (quantize - x).detach()` in `VectorQuantization.forward`),
            # which makes the commitment loss zero for every codebook except the first,
            # see https://github.com/facebookresearch/encodec/issues/25
            # We therefore subtract the quantized value directly:
            residual = residual - quantized
            # residual = residual - quantized.detach()
            # Since the commitment loss is averaged, this does not change the scale
            # of the loss (contrary to what the issue above suggests).
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices
    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
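# Usage sketch (hypothetical sizes): each layer quantizes the residual left by
# the previous ones; inputs are channels-first, [B, D, T]:
#   rvq = ResidualVectorQuantization(num_quantizers=4, dim=32, codebook_size=256)
#   codes = rvq.encode(torch.randn(2, 32, 50))  # -> [4, 2, 50]
#   recon = rvq.decode(codes)                   # -> [2, 32, 50]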
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.
    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int or list of int): Codebook size, either shared or per quantizer.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: tp.Union[int, tp.List[int]] = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False
        )
    def forward(self, x: torch.Tensor, frame_rate: int):
        n_q = self.n_q
        if self.training and self.q_dropout:
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        if isinstance(self.bins, list):
            bins = self.bins
        else:
            bins = [self.bins] * self.n_q
        bw_per_q = [math.log2(b) * frame_rate / 1000 for b in bins]
        bw = torch.tensor(sum(bw_per_q)).to(x)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
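    # Bandwidth example (illustrative): with bins=1024 (10 bits per code),
    # frame_rate=50 and n_q=8, bw = 8 * log2(1024) * 50 / 1000 = 4.0 kb/s.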
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.n_q
        codes = self.vq.encode(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes
    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        return quantized
    @property
    def total_codebooks(self):
        return self.max_n_q
    @property
    def num_codebooks(self):
        return self.n_q
    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n
class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization.
    """
    def __init__(self):
        super().__init__()
    def forward(self, x: torch.Tensor, frame_rate: int):
        q = x.unsqueeze(1)
        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return x.unsqueeze(1)
    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return codes.squeeze(1)
    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1
    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        return self.total_codebooks
    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.
    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0
    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'
    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks
    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)
    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins
    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale
    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x
    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        else:
            raise NotImplementedError("model forward and training are not supported.")
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)
        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]
        q_res.x = self.postprocess(out, scale)
        return q_res
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.
        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]
        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep,
                scale, a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.
        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.
        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        # out contains extra padding added by the encoder and decoder
        return out
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)
class EncodecModel_encode_only(CompressionModel):
    """Encodec model operating on the raw waveform. Encode only, so no decoder.
    Args:
        encoder (nn.Module): Encoder network.
        quantizer (BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # we need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0
    def __init__(self,
                 encoder: nn.Module,
                 quantizer: BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'
    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks
    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)
    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins
    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale
    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x
    def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult:
        if encode:
            return self.encode(x)
        else:
            # This variant has no decoder, so the full autoencoding forward of
            # `EncodecModel` cannot be run here.
            raise NotImplementedError("model forward and training are not supported.")
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.
        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]
        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep,
                scale, a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation.
        Not available: this model is encode-only and carries no decoder weights.
        """
        raise NotImplementedError("Decode is not supported for encode only model")
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        raise NotImplementedError("Decode is not supported for encode only model")
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer:
    klass = {
        'no_quant': DummyQuantizer,
        'rvq': ResidualVectorQuantizer
    }[quantizer]
    kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        kwargs['dimension'] = dimension
    return klass(**kwargs)
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    if encoder_name == 'seanet':
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        encoder_override_kwargs = kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        encoder = SEANetEncoder(**encoder_kwargs)
        decoder = SEANetDecoder(**decoder_kwargs)
        return encoder, decoder
    else:
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel:
    """Instantiate a compression model from a checkpoint."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.load(ckpt_fn, map_location='cpu')
    cfg = state['xp.cfg']
    cfg.device = str(device)
    weights = state['best_state']['model']
    assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now."
    if encode_only:
        all_keys = list(weights.keys())
        for key in all_keys:
            if key.startswith('decoder'):
                del weights[key]
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, _ = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', False)
        # deprecated params
        kwargs.pop('renorm', None)
        compression_model = EncodecModel_encode_only(encoder, quantizer,
                                                     frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
        compression_model.load_state_dict(weights)
        compression_model.eval()
        return compression_model
    else:
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', False)
        # deprecated params
        kwargs.pop('renorm', None)
        compression_model = EncodecModel(encoder, decoder, quantizer,
                                         frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
        compression_model.load_state_dict(weights)
        compression_model.eval()
        return compression_model
if __name__ == "__main__":
    import torchaudio
    ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th"
    audio_in_fns = [
        "/home/pyp/BoostedVoiceEditor/demo/pam.wav",
        "/home/pyp/BoostedVoiceEditor/demo/ray.wav",
        "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav",
        "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav",
        "/home/pyp/BoostedVoiceEditor/demo/bible.wav",
        "/home/pyp/BoostedVoiceEditor/demo/miley.wav",
    ]
    audio_out_fns = [
        "/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav",
        "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav",
        "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav",
        "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav",
        "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav",
        "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav",
    ]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_compression_model(ckpt_fn, device=device)
    for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns):
        audio_in, sr = torchaudio.load(audio_in_fn)
        if sr != model.sample_rate:
            audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in)
        if audio_in.shape[0] == 2:
            audio_in = audio_in.mean(dim=0, keepdim=True)
        audio_in = audio_in.unsqueeze(0)
        audio_in = audio_in.to(torch.float32).to(device)
        codes = model.encode(audio_in)[0]
        audio_out = model.decode(codes)[0].cpu()
        torchaudio.save(audio_out_fn, audio_out, model.sample_rate)