from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
import einops


class QuantizationModule(nn.Module):
    def __init__(self, config):
        """
        Gumbel-softmax product quantizer: input features are projected to logits over
        `num_codebooks` codebooks of `num_codewords` codewords each; the selected
        codewords are concatenated into a `d_model`-dimensional output.

        Args:
            config: Configuration providing `in_features`, `d_model`,
                `num_codebooks`, and `num_codewords`.
        """
        super().__init__()

        assert (
            config.d_model % config.num_codebooks == 0
        ), "d_model must be divisible by num_codebooks"

        self.num_codebooks = config.num_codebooks
        self.num_codewords = config.num_codewords
        self.d_model = config.d_model
        self.codeword_dim = config.d_model // config.num_codebooks

        self.codebooks = self._init_codebooks()

        self.projection = nn.Linear(
            config.in_features, self.num_codebooks * self.num_codewords
        )

        self.tau = 1.0  # Gumbel-softmax temperature (typically annealed externally during training)

    def _init_codebooks(self):
        # Shape (1, 1, num_codebooks, num_codewords, codeword_dim) so the codebooks
        # broadcast against the per-frame codeword probabilities in `forward`.
        codebooks = torch.empty(
            1, 1, self.num_codebooks, self.num_codewords, self.codeword_dim
        )
        nn.init.xavier_uniform_(codebooks)

        return nn.Parameter(codebooks)

    @property
    def total_codewords(self):
        return self.num_codebooks * self.num_codewords

    @staticmethod
    def _compute_perplexity(probs: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Computes the perplexity of the marginal codeword distribution, summed over
        codebooks (used for the diversity loss): exp(-sum_w p_w * log p_w) per codebook.

        Args:
            probs (Tensor): The probability distribution over codewords. Shape: (batch, num_frames, num_codebooks, num_codewords)
            mask (BoolTensor): The mask for the frames. `True` marks invalid frames. Shape: (batch, num_frames)

        Returns:
            perplexity (Tensor): Scalar perplexity, summed over codebooks.
        """
        if mask is not None:
            # Zero out invalid frames, then average over the valid frames only.
            probs = probs * ~mask[..., None, None]
            num_valid = (~mask).sum()
            marginal_probs = (
                einops.reduce(probs, "b nf nb nw -> nb nw", "sum") / num_valid
            )
        else:
            marginal_probs = einops.reduce(probs, "b nf nb nw -> nb nw", "mean")

        perplexity = torch.exp(
            -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
        ).sum()
        return perplexity

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x (Tensor): The extracted features from waveforms. Shape: (batch, num_frames, in_features)
            mask (BoolTensor): The mask for the frames. `True` marks invalid frames. Shape: (batch, num_frames)

        Returns:
            quantized (Tensor): The quantized features. Shape: (batch, num_frames, d_model)
            perplexity (Tensor): The perplexity of the codeword distribution, summed over codebooks (scalar).
        """
        batch_size, num_frames, _ = x.shape

        logits = self.projection(x)
        logits = logits.view(
            batch_size, num_frames, self.num_codebooks, self.num_codewords
        )

        if self.training:
            word_probs = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=-1)
            word_soft_probs = F.softmax(logits, dim=-1)

            perplexity = self._compute_perplexity(word_soft_probs, mask=mask)
        else:
            word_ids = torch.argmax(logits, dim=-1, keepdim=True)
            word_probs = torch.zeros_like(logits).scatter_(-1, word_ids, 1.0)  # One-hot

            perplexity = self._compute_perplexity(word_probs, mask=mask)

        # (batch, num_frames, num_codebooks, num_codewords, 1) x (1, 1, num_codebooks, num_codewords, codeword_dim)
        # -> (batch, num_frames, num_codebooks x codeword_dim)
        quantized = einops.reduce(
            word_probs.unsqueeze(-1) * self.codebooks,
            "b nf nb nw d -> b nf (nb d)",
            reduction="sum",
        )

        return quantized, perplexity
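

# Minimal usage sketch. The config values below (in_features=512, d_model=768,
# num_codebooks=2, num_codewords=320) are illustrative assumptions, not values
# taken from this project's configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        in_features=512, d_model=768, num_codebooks=2, num_codewords=320
    )
    quantizer = QuantizationModule(config)

    x = torch.randn(4, 100, config.in_features)  # (batch, num_frames, in_features)
    mask = torch.zeros(4, 100, dtype=torch.bool)  # True marks invalid (padded) frames
    mask[:, 90:] = True  # pretend the last 10 frames of each sample are padding

    quantizer.train()
    quantized, perplexity = quantizer(x, mask)
    print(quantized.shape)  # torch.Size([4, 100, 768])
    print(perplexity.item())  # scalar, at most num_codebooks * num_codewords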