Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,043 Bytes
efa71f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import Any, Dict
import torch
import torch.nn as nn
class SplitVectorDiffusion(nn.Module):
    """Adapter that feeds flat vectors to a wrapped channel-first model.

    ``forward`` reshapes a flat ``(batch, n_ctx * channels)`` input into
    ``(batch, channels, n_ctx)``, calls ``wrapped`` on it, and flattens the
    two halves of the result (split along dim 1) back into one flat tensor.

    Args:
        device: Stored on the instance as ``self.device``; not otherwise
            used by the code in this class.
        wrapped: Module called as ``wrapped(h, t, **kwargs)``. It must return
            a tensor with twice as many entries along dim 1 as its input.
        n_ctx: Size of the second axis used when un-flattening the input.
        d_latent: Stored on the instance as ``self.d_latent``; ``forward``
            does not read it.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        wrapped: nn.Module,
        n_ctx: int,
        d_latent: int,
    ):
        super().__init__()
        self.device = device
        self.n_ctx = n_ctx
        self.d_latent = d_latent
        self.wrapped = wrapped

        # Re-export the wrapped model's ``cached_model_kwargs`` (when it has
        # one) so callers can reach it through this adapter.
        if hasattr(self.wrapped, "cached_model_kwargs"):
            self.cached_model_kwargs = self.wrapped.cached_model_kwargs

    def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Run the wrapped model on ``x`` and return flattened eps/var halves.

        Args:
            x: Tensor whose leading axis is the batch; the remaining elements
                are reshaped to ``(n_ctx, channels)`` per batch item.
            t: Passed through unchanged to the wrapped model.
            **kwargs: Passed through unchanged to the wrapped model.

        Returns:
            A 2-D tensor: the first half of the wrapped output (along dim 1)
            flattened per batch item, concatenated along dim 1 with the
            second half flattened the same way.

        Raises:
            AssertionError: If the wrapped model's output does not have
                exactly twice the input's size along dim 1.
        """
        batch = x.shape[0]
        # (batch, n_ctx, channels) -> (batch, channels, n_ctx)
        h = x.reshape(batch, self.n_ctx, -1).permute(0, 2, 1)
        pre_channels = h.shape[1]

        h = self.wrapped(h, t, **kwargs)

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; type and message are unchanged.
        if h.shape[1] != pre_channels * 2:
            raise AssertionError(
                "expected twice as many outputs for variance prediction"
            )

        eps, var = torch.chunk(h, 2, dim=1)
        # Undo the channel-first layout for each half, then flatten per item.
        flat_halves = [part.permute(0, 2, 1).flatten(1) for part in (eps, var)]
        return torch.cat(flat_halves, dim=1)
|