import sys
from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from cube3d.model.transformers.norm import RMSNorm


class SphericalVectorQuantizer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_codes: int,
        width: Optional[int] = None,
        codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
    ):
        """
        Initializes the SphericalVQ module.

        Args:
            embed_dim (int): The dimensionality of the codebook embeddings.
            num_codes (int): The number of codes in the codebook.
            width (Optional[int], optional): The width of the input features. If it
                differs from embed_dim, linear projections are added on the input
                and output sides. Defaults to None, i.e. width == embed_dim.
            codebook_regularization (Literal["batch_norm", "kl"], optional): How the
                codebook weights are regularized before RMS normalization.
                Defaults to "batch_norm".
        """
        super().__init__()

        self.num_codes = num_codes

        self.codebook = nn.Embedding(num_codes, embed_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

        width = width or embed_dim
        if width != embed_dim:
            # Project between the input width and the codebook dimensionality.
            self.c_in = nn.Linear(width, embed_dim)
            self.c_x = nn.Linear(width, embed_dim)
            self.c_out = nn.Linear(embed_dim, width)
        else:
            self.c_in = self.c_out = self.c_x = nn.Identity()

        self.norm = RMSNorm(embed_dim, elementwise_affine=False)
        self.cb_reg = codebook_regularization
        if self.cb_reg == "batch_norm":
            self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
        else:
            # Learnable per-channel affine rescaling of the codebook weights.
            self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
            self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
            self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)

    def get_codebook(self):
        """
        Retrieves the normalized codebook weights.

        The raw codebook weights are passed through the codebook regularizer
        (batch norm or the learned affine rescaling) and then RMS-normalized,
        ensuring they are properly scaled before being returned.

        Returns:
            torch.Tensor: The normalized codebook weights, of shape
                (num_codes, embed_dim).
        """
        return self.norm(self.cb_norm(self.codebook.weight))

    @torch.no_grad()
    def lookup_codebook(self, q: torch.Tensor):
        """
        Performs a lookup in the codebook and projects the result.

        Takes a tensor of indices, retrieves the corresponding normalized codebook
        embeddings, and projects them back to the output width.

        Args:
            q (torch.Tensor): A tensor of indices into the codebook.

        Returns:
            torch.Tensor: The projected embeddings retrieved from the codebook.
        """
        z_q = F.embedding(q, self.get_codebook())
        z_q = self.c_out(z_q)
        return z_q

    @torch.no_grad()
    def lookup_codebook_latents(self, q: torch.Tensor):
        """
        Retrieves the latent representations from the codebook for the given indices.

        Args:
            q (torch.Tensor): A tensor of integer indices corresponding to rows of
                the codebook.

        Returns:
            torch.Tensor: The latent representations retrieved from the codebook; the
                shape follows the input indices, with a trailing dimension equal to
                the dimensionality of the codebook entries.
        """
        z_q = F.embedding(q, self.get_codebook())
        return z_q

    def quantize(self, z: torch.Tensor):
        """
        Quantizes the latent codes z with the codebook.

        Args:
            z (torch.Tensor): B x ... x F

        Returns:
            Tuple[torch.Tensor, dict]: The quantized latents and a dict containing
                the detached inputs ("z") and the selected code indices ("q").
        """
        codebook = self.get_codebook()

        with torch.no_grad():
            z_flat = z.view(-1, z.shape[-1])

            # Nearest-neighbor assignment by Euclidean distance.
            d = torch.cdist(z_flat, codebook)
            q = torch.argmin(d, dim=1)

        # Look up the selected codes and restore the input shape.
        z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
        q = q.view(*z.shape[:-1])

        return z_q, {"z": z.detach(), "q": q}

    def straight_through_approximation(self, z, z_q):
        """Passes gradients from z_q straight through to z (straight-through estimator)."""
        z_q = z + (z_q - z).detach()
        return z_q

    def forward(self, z: torch.Tensor):
        """
        Forward pass of the spherical vector quantizer.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]:
                - z_q (torch.Tensor): The quantized output tensor after applying the
                  straight-through approximation and the output projection.
                - ret_dict (Dict[str, Any]): A dictionary containing additional
                  information:
                    - "z" (torch.Tensor): Detached, normalized encoder features.
                    - "q" (torch.Tensor): Indices of the selected codebook vectors.
                    - "z_q" (torch.Tensor): Detached quantized tensor, before the
                      straight-through approximation and output projection.
        """
        with torch.autocast(device_type=z.device.type, enabled=False):
            # Quantization is performed in full precision.
            z = z.float()

            z_e = self.norm(self.c_in(z))
            z_q, ret_dict = self.quantize(z_e)

            ret_dict["z_q"] = z_q.detach()
            z_q = self.straight_through_approximation(z_e, z_q)
            z_q = self.c_out(z_q)

        return z_q, ret_dict
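

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hyperparameters below
# (embed_dim=32, num_codes=16384, width=768) and the input shape are assumptions
# chosen for demonstration, not values taken from the cube3d configuration;
# running this requires the cube3d package to be importable for RMSNorm.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vq = SphericalVectorQuantizer(embed_dim=32, num_codes=16384, width=768)

    z = torch.randn(2, 16, 768)  # (batch, tokens, width)
    z_q, ret = vq(z)

    # z_q matches the input width; "q" holds one codebook index per token.
    print(z_q.shape)       # torch.Size([2, 16, 768])
    print(ret["q"].shape)  # torch.Size([2, 16])

    # Decoding from indices alone, e.g. after autoregressive generation.
    decoded = vq.lookup_codebook(ret["q"])
    print(decoded.shape)   # torch.Size([2, 16, 768])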