import math

import torch
from torch import nn


class SinusPositionEmbedding(nn.Module):
    """Sinusoidal position/time embedding (Transformer-style sin/cos features)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        # Geometric progression of frequencies spanning 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        # Outer product of scaled positions and frequencies: (batch, half_dim).
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        # Concatenate the sin and cos halves: (batch, dim).
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
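
# Illustrative shape check (hypothetical usage, not part of the original module):
# a 1-D batch of positions maps to a (batch, dim) tensor, e.g.
#   SinusPositionEmbedding(256)(torch.tensor([0.0, 0.5, 1.0]))  # -> shape (3, 256)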


class TimestepEmbedding(nn.Module):
    """Embeds a scalar timestep into a `dim`-dimensional vector via sinusoidal features and an MLP."""

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep):
        # Sinusoidal features of the timestep: (batch, freq_embed_dim).
        time_hidden = self.time_embed(timestep)
        # Match the input dtype (e.g. under mixed precision) before the MLP.
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)
        return time
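

# Minimal sanity-check sketch (hypothetical usage; `dim=512` and the batch
# size are illustrative choices, not taken from the original file):
if __name__ == "__main__":
    t = torch.rand(4)  # a batch of 4 timesteps in [0, 1)
    embed = TimestepEmbedding(dim=512)
    print(embed(t).shape)  # expected: torch.Size([4, 512])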