|
|
|
|
|
|
|
|
|
from math import log2, ceil
from functools import partial
from typing import Any, Optional, List, Iterable

import torch
from torchvision import transforms
from PIL import Image
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
|
|
|
|
|
def get_same_padding(size, kernel, dilation, stride): |
|
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 |
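
# with stride=1 and dilation=1 this reduces to (kernel - 1) // 2, e.g.
# get_same_padding(32, 3, 1, 1) == 1, which keeps a 32x32 feature map at
# 32x32 under a 3x3 convolution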
|
|
|
|
|
class AdaptiveConv2DMod(nn.Module): |
|
def __init__( |
|
self, |
|
dim, |
|
dim_out, |
|
kernel, |
|
*, |
|
demod=True, |
|
stride=1, |
|
dilation=1, |
|
eps=1e-8, |
|
num_conv_kernels=1, |
|
): |
|
super().__init__() |
|
self.eps = eps |
|
|
|
self.dim_out = dim_out |
|
|
|
self.kernel = kernel |
|
self.stride = stride |
|
self.dilation = dilation |
|
self.adaptive = num_conv_kernels > 1 |
|
|
|
self.weights = nn.Parameter( |
|
torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)) |
|
) |
|
|
|
self.demod = demod |
|
|
|
nn.init.kaiming_normal_( |
|
self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu" |
|
) |
|
|
|
def forward( |
|
self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None |
|
): |
|
""" |
|
notation |
|
|
|
b - batch |
|
n - convs |
|
o - output |
|
i - input |
|
k - kernel |
|
""" |
|
|
|
        b, h = fmap.shape[0], fmap.shape[-2]

        # the latent modulation is required; if the feature map batch was
        # expanded along the batch dimension, repeat mod to match
        assert exists(mod)

        if mod.shape[0] != b:
            mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])
|
|
|
if exists(kernel_mod): |
|
kernel_mod_has_el = kernel_mod.numel() > 0 |
|
|
|
assert self.adaptive or not kernel_mod_has_el |
|
|
|
if kernel_mod_has_el and kernel_mod.shape[0] != b: |
|
kernel_mod = repeat( |
|
kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0] |
|
) |
|
|
|
|
|
|
|
weights = self.weights |
|
|
|
        if self.adaptive:
            weights = repeat(weights, "... -> b ...", b=b)

            # softmax over the kernel modulation logits selects a per-sample
            # mixture of the adaptive kernels
            assert exists(kernel_mod) and kernel_mod.numel() > 0

            kernel_attn = kernel_mod.softmax(dim=-1)
            kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")

            weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")
|
|
|
|
|
|
|
        # modulate the weights per-sample, as in stylegan2
        mod = rearrange(mod, "b i -> b 1 i 1 1")

        weights = weights * (mod + 1)
|
|
|
if self.demod: |
|
inv_norm = ( |
|
reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum") |
|
.clamp(min=self.eps) |
|
.rsqrt() |
|
) |
|
weights = weights * inv_norm |
|
|
|
        # fold the batch into channels and use grouped convolution, so each
        # sample is convolved with its own modulated kernel
        fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")

        weights = rearrange(weights, "b o ... -> (b o) ...")

        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
        fmap = F.conv2d(fmap, weights, padding=padding, groups=b)
|
|
|
return rearrange(fmap, "1 (b o) ... -> b o ...", b=b) |
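
# shape sketch (illustrative dims, not values fixed elsewhere in this file):
#
#   conv = AdaptiveConv2DMod(dim=64, dim_out=128, kernel=3, num_conv_kernels=4)
#   fmap = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
#   mod = torch.randn(2, 64)            # per-channel modulation, (batch, dim)
#   kernel_mod = torch.randn(2, 4)      # kernel logits, (batch, num_conv_kernels)
#   out = conv(fmap, mod=mod, kernel_mod=kernel_mod)   # -> (2, 128, 32, 32)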
|
|
|
|
|
class Attend(nn.Module): |
|
    def __init__(self, dropout=0.0, flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        # note: this learned scale is not used in forward (the standard
        # 1/sqrt(dim_head) scaling is applied there); it is likely kept as a
        # parameter so existing checkpoints still load with strict=True
        self.scale = nn.Parameter(torch.randn(1))
        self.flash = flash
|
|
|
def flash_attn(self, q, k, v): |
|
q, k, v = map(lambda t: t.contiguous(), (q, k, v)) |
|
out = F.scaled_dot_product_attention( |
|
q, k, v, dropout_p=self.dropout if self.training else 0.0 |
|
) |
|
return out |
|
|
|
def forward(self, q, k, v): |
|
if self.flash: |
|
return self.flash_attn(q, k, v) |
|
|
|
scale = q.shape[-1] ** -0.5 |
|
|
|
|
|
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale |
|
|
|
|
|
attn = sim.softmax(dim=-1) |
|
attn = self.attn_dropout(attn) |
|
|
|
|
|
out = einsum("b h i j, b h j d -> b h i d", attn, v) |
|
|
|
return out |
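
# both paths take (batch, heads, seq, dim_head) tensors and return the same
# shape; the flash path delegates scaling and dropout to
# F.scaled_dot_product_attention, which applies the same 1/sqrt(dim_head)
# scale by default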
|
|
|
|
|
def exists(x): |
|
return x is not None |
|
|
|
|
|
def default(val, d): |
|
if exists(val): |
|
return val |
|
return d() if callable(d) else d |
|
|
|
|
|
def cast_tuple(t, length=1): |
|
if isinstance(t, tuple): |
|
return t |
|
return (t,) * length |
|
|
|
|
|
def identity(t, *args, **kwargs): |
|
return t |
|
|
|
|
|
def is_power_of_two(n): |
|
return log2(n).is_integer() |
|
|
|
|
|
def null_iterator(): |
|
while True: |
|
yield None |
|
|
|
|
|
def Downsample(dim, dim_out=None): |
|
return nn.Sequential( |
|
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), |
|
nn.Conv2d(dim * 4, default(dim_out, dim), 1), |
|
) |
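
# e.g. Downsample(64) maps (1, 64, 32, 32) -> (1, 64, 16, 16): the Rearrange
# folds each 2x2 patch into channels (64 -> 256) and the 1x1 conv projects
# back down, trading spatial resolution for channels losslessly before the
# projection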
|
|
|
|
|
class RMSNorm(nn.Module): |
|
def __init__(self, dim): |
|
super().__init__() |
|
self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) |
|
self.eps = 1e-4 |
|
|
|
def forward(self, x): |
|
return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5) |
|
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
        super().__init__()
        self.proj = AdaptiveConv2DMod(
            dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
        )
        # note: `groups` and the attributes below are unused here; padding is
        # computed inside AdaptiveConv2DMod itself
        self.kernel = 3
        self.dilation = 1
        self.stride = 1

        self.act = nn.SiLU()
|
|
|
def forward(self, x, conv_mods_iter: Optional[Iterable] = None): |
|
conv_mods_iter = default(conv_mods_iter, null_iterator()) |
|
|
|
x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter)) |
|
|
|
x = self.act(x) |
|
return x |
|
|
|
|
|
class ResnetBlock(nn.Module): |
|
    def __init__(
        self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
    ):
        super().__init__()
        # record the modulation dims this block will consume; the shared
        # `style_dims` list is mutated in place so the caller can later split
        # the style vector into per-block chunks
        style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])
|
|
|
self.block1 = Block( |
|
dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels |
|
) |
|
self.block2 = Block( |
|
dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels |
|
) |
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() |
|
|
|
def forward(self, x, conv_mods_iter: Optional[Iterable] = None): |
|
h = self.block1(x, conv_mods_iter=conv_mods_iter) |
|
h = self.block2(h, conv_mods_iter=conv_mods_iter) |
|
|
|
return h + self.res_conv(x) |
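
# forward sketch (illustrative dims): a ResnetBlock(64, 128, num_conv_kernels=4)
# consumes four style chunks per call, matching the four dims it registers:
#
#   style_dims = []
#   block = ResnetBlock(64, 128, num_conv_kernels=4, style_dims=style_dims)
#   # style_dims is now [64, 4, 128, 4]
#   mods = iter([torch.randn(2, d) for d in style_dims])
#   out = block(torch.randn(2, 64, 32, 32), conv_mods_iter=mods)  # (2, 128, 32, 32)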
|
|
|
|
|
class LinearAttention(nn.Module): |
|
def __init__(self, dim, heads=4, dim_head=32): |
|
super().__init__() |
|
self.scale = dim_head**-0.5 |
|
self.heads = heads |
|
hidden_dim = dim_head * heads |
|
|
|
self.norm = RMSNorm(dim) |
|
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) |
|
|
|
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim)) |
|
|
|
def forward(self, x): |
|
b, c, h, w = x.shape |
|
|
|
x = self.norm(x) |
|
|
|
qkv = self.to_qkv(x).chunk(3, dim=1) |
|
q, k, v = map( |
|
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv |
|
) |
|
|
|
q = q.softmax(dim=-2) |
|
k = k.softmax(dim=-1) |
|
|
|
q = q * self.scale |
|
|
|
context = torch.einsum("b h d n, b h e n -> b h d e", k, v) |
|
|
|
out = torch.einsum("b h d e, b h d n -> b h e n", context, q) |
|
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) |
|
return self.to_out(out) |
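
# here q, k, v are (batch, heads, dim_head, n) with n = h * w pixels; the
# (dim_head x dim_head) context computed above keeps the cost linear in n,
# unlike full attention's (n x n) similarity matrix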
|
|
|
|
|
class Attention(nn.Module): |
|
def __init__(self, dim, heads=4, dim_head=32, flash=False): |
|
super().__init__() |
|
self.heads = heads |
|
hidden_dim = dim_head * heads |
|
|
|
self.norm = RMSNorm(dim) |
|
|
|
self.attend = Attend(flash=flash) |
|
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) |
|
self.to_out = nn.Conv2d(hidden_dim, dim, 1) |
|
|
|
def forward(self, x): |
|
b, c, h, w = x.shape |
|
x = self.norm(x) |
|
qkv = self.to_qkv(x).chunk(3, dim=1) |
|
|
|
q, k, v = map( |
|
lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv |
|
) |
|
|
|
out = self.attend(q, k, v) |
|
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) |
|
|
|
return self.to_out(out) |
|
|
|
|
|
|
|
def FeedForward(dim, mult=4): |
|
return nn.Sequential( |
|
RMSNorm(dim), |
|
nn.Conv2d(dim, dim * mult, 1), |
|
nn.GELU(), |
|
nn.Conv2d(dim * mult, dim, 1), |
|
) |
|
|
|
|
|
|
|
class Transformer(nn.Module): |
|
def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4): |
|
super().__init__() |
|
self.layers = nn.ModuleList([]) |
|
|
|
for _ in range(depth): |
|
self.layers.append( |
|
nn.ModuleList( |
|
[ |
|
Attention( |
|
dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn |
|
), |
|
FeedForward(dim=dim, mult=ff_mult), |
|
] |
|
) |
|
) |
|
|
|
def forward(self, x): |
|
for attn, ff in self.layers: |
|
x = attn(x) + x |
|
x = ff(x) + x |
|
|
|
return x |
|
|
|
|
|
class LinearTransformer(nn.Module): |
|
def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4): |
|
super().__init__() |
|
self.layers = nn.ModuleList([]) |
|
|
|
for _ in range(depth): |
|
self.layers.append( |
|
nn.ModuleList( |
|
[ |
|
LinearAttention(dim=dim, dim_head=dim_head, heads=heads), |
|
FeedForward(dim=dim, mult=ff_mult), |
|
] |
|
) |
|
) |
|
|
|
def forward(self, x): |
|
for attn, ff in self.layers: |
|
x = attn(x) + x |
|
x = ff(x) + x |
|
|
|
return x |
|
|
|
|
|
class NearestNeighborhoodUpsample(nn.Module): |
|
def __init__(self, dim, dim_out=None): |
|
super().__init__() |
|
dim_out = default(dim_out, dim) |
|
self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1) |
|
|
|
    def forward(self, x):
        # for larger batches, make the tensor contiguous first; this appears
        # to be a workaround for a PyTorch nearest-neighbor interpolation
        # issue at large batch sizes
        if x.shape[0] >= 64:
            x = x.contiguous()

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x
|
|
|
|
|
class EqualLinear(nn.Module):
    def __init__(self, dim, dim_out, lr_mul=1, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim))
        self.bias = nn.Parameter(torch.zeros(dim_out)) if bias else None
        self.lr_mul = lr_mul

    def forward(self, input):
        bias = self.bias * self.lr_mul if exists(self.bias) else None
        return F.linear(input, self.weight * self.lr_mul, bias=bias)
|
|
|
|
|
class StyleGanNetwork(nn.Module): |
|
def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0): |
|
super().__init__() |
|
self.dim_in = dim_in |
|
self.dim_out = dim_out |
|
self.dim_text_latent = dim_text_latent |
|
|
|
layers = [] |
|
for i in range(depth): |
|
is_first = i == 0 |
|
|
|
if is_first: |
|
dim_in_layer = dim_in + dim_text_latent |
|
else: |
|
dim_in_layer = dim_out |
|
|
|
dim_out_layer = dim_out |
|
|
|
layers.extend( |
|
[EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)] |
|
) |
|
|
|
self.net = nn.Sequential(*layers) |
|
|
|
def forward(self, x, text_latent=None): |
|
x = F.normalize(x, dim=1) |
|
if self.dim_text_latent > 0: |
|
assert exists(text_latent) |
|
x = torch.cat((x, text_latent), dim=-1) |
|
return self.net(x) |
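
# mapping-network sketch (defaults as above; dims are illustrative):
#
#   style_net = StyleGanNetwork(dim_in=128, dim_out=512, depth=8)
#   z = torch.randn(4, 128)   # latent noise, (batch, dim_in)
#   w = style_net(z)          # style vector, (batch, 512)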
|
|
|
|
|
class UnetUpsampler(torch.nn.Module): |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
*, |
|
image_size: int, |
|
input_image_size: int, |
|
init_dim: Optional[int] = None, |
|
out_dim: Optional[int] = None, |
|
style_network: Optional[dict] = None, |
|
up_dim_mults: tuple = (1, 2, 4, 8, 16), |
|
down_dim_mults: tuple = (4, 8, 16), |
|
channels: int = 3, |
|
resnet_block_groups: int = 8, |
|
full_attn: tuple = (False, False, False, True, True), |
|
flash_attn: bool = True, |
|
self_attn_dim_head: int = 64, |
|
self_attn_heads: int = 8, |
|
attn_depths: tuple = (2, 2, 2, 2, 4), |
|
mid_attn_depth: int = 4, |
|
num_conv_kernels: int = 4, |
|
resize_mode: str = "bilinear", |
|
unconditional: bool = True, |
|
skip_connect_scale: Optional[float] = None, |
|
): |
|
        super().__init__()
        # note: despite the Optional annotation, a style network config dict
        # is required here
        self.style_network = style_network = StyleGanNetwork(**style_network)
        self.unconditional = unconditional
|
assert not ( |
|
unconditional |
|
and exists(style_network) |
|
and style_network.dim_text_latent > 0 |
|
) |
|
|
|
assert is_power_of_two(image_size) and is_power_of_two( |
|
input_image_size |
|
), "both output image size and input image size must be power of 2" |
|
assert ( |
|
input_image_size < image_size |
|
), "input image size must be smaller than the output image size, thus upsampling" |
|
|
|
self.image_size = image_size |
|
self.input_image_size = input_image_size |
|
|
|
style_embed_split_dims = [] |
|
|
|
self.channels = channels |
|
input_channels = channels |
|
|
|
init_dim = default(init_dim, dim) |
|
|
|
up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)] |
|
init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)] |
|
down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)] |
|
self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3) |
|
|
|
up_in_out = list(zip(up_dims[:-1], up_dims[1:])) |
|
down_in_out = list(zip(down_dims[:-1], down_dims[1:])) |
|
|
|
block_klass = partial( |
|
ResnetBlock, |
|
groups=resnet_block_groups, |
|
num_conv_kernels=num_conv_kernels, |
|
style_dims=style_embed_split_dims, |
|
) |
|
|
|
FullAttention = partial(Transformer, flash_attn=flash_attn) |
|
*_, mid_dim = up_dims |
|
|
|
self.skip_connect_scale = default(skip_connect_scale, 2**-0.5) |
|
|
|
self.downs = nn.ModuleList([]) |
|
self.ups = nn.ModuleList([]) |
|
|
|
block_count = 6 |
|
|
|
for ind, ( |
|
(dim_in, dim_out), |
|
layer_full_attn, |
|
layer_attn_depth, |
|
) in enumerate(zip(down_in_out, full_attn, attn_depths)): |
|
attn_klass = FullAttention if layer_full_attn else LinearTransformer |
|
|
|
blocks = [] |
|
for i in range(block_count): |
|
blocks.append(block_klass(dim_in, dim_in)) |
|
|
|
self.downs.append( |
|
nn.ModuleList( |
|
[ |
|
nn.ModuleList(blocks), |
|
nn.ModuleList( |
|
[ |
|
( |
|
attn_klass( |
|
dim_in, |
|
dim_head=self_attn_dim_head, |
|
heads=self_attn_heads, |
|
depth=layer_attn_depth, |
|
) |
|
if layer_full_attn |
|
else None |
|
), |
|
nn.Conv2d( |
|
dim_in, dim_out, kernel_size=3, stride=2, padding=1 |
|
), |
|
] |
|
), |
|
] |
|
) |
|
) |
|
|
|
self.mid_block1 = block_klass(mid_dim, mid_dim) |
|
self.mid_attn = FullAttention( |
|
mid_dim, |
|
dim_head=self_attn_dim_head, |
|
heads=self_attn_heads, |
|
depth=mid_attn_depth, |
|
) |
|
self.mid_block2 = block_klass(mid_dim, mid_dim) |
|
|
|
*_, last_dim = up_dims |
|
|
|
for ind, ( |
|
(dim_in, dim_out), |
|
layer_full_attn, |
|
layer_attn_depth, |
|
) in enumerate( |
|
zip( |
|
reversed(up_in_out), |
|
reversed(full_attn), |
|
reversed(attn_depths), |
|
) |
|
): |
|
attn_klass = FullAttention if layer_full_attn else LinearTransformer |
|
|
|
blocks = [] |
|
input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in |
|
for i in range(block_count): |
|
blocks.append(block_klass(input_dim, dim_in)) |
|
|
|
self.ups.append( |
|
nn.ModuleList( |
|
[ |
|
nn.ModuleList(blocks), |
|
nn.ModuleList( |
|
[ |
|
NearestNeighborhoodUpsample( |
|
last_dim if ind == 0 else dim_out, |
|
dim_in, |
|
), |
|
( |
|
attn_klass( |
|
dim_in, |
|
dim_head=self_attn_dim_head, |
|
heads=self_attn_heads, |
|
depth=layer_attn_depth, |
|
) |
|
if layer_full_attn |
|
else None |
|
), |
|
] |
|
), |
|
] |
|
) |
|
) |
|
|
|
self.out_dim = default(out_dim, channels) |
|
self.final_res_block = block_klass(dim, dim) |
|
self.final_to_rgb = nn.Conv2d(dim, channels, 1) |
|
self.resize_mode = resize_mode |
|
self.style_to_conv_modulations = nn.Linear( |
|
style_network.dim_out, sum(style_embed_split_dims) |
|
) |
|
self.style_embed_split_dims = style_embed_split_dims |
|
|
|
@property |
|
def allowable_rgb_resolutions(self): |
|
input_res_base = int(log2(self.input_image_size)) |
|
output_res_base = int(log2(self.image_size)) |
|
allowed_rgb_res_base = list(range(input_res_base, output_res_base)) |
|
return [*map(lambda p: 2**p, allowed_rgb_res_base)] |
|
|
|
@property |
|
def device(self): |
|
return next(self.parameters()).device |
|
|
|
@property |
|
def total_params(self): |
|
return sum([p.numel() for p in self.parameters()]) |
|
|
|
def resize_image_to(self, x, size): |
|
return F.interpolate(x, (size, size), mode=self.resize_mode) |
|
|
|
def forward( |
|
self, |
|
lowres_image: torch.Tensor, |
|
styles: Optional[torch.Tensor] = None, |
|
noise: Optional[torch.Tensor] = None, |
|
global_text_tokens: Optional[torch.Tensor] = None, |
|
return_all_rgbs: bool = False, |
|
): |
|
        x = lowres_image

        # lightly noise-augment the input, then clamp back to the valid
        # [0, 1] image range
        noise_scale = 0.001
        noise_aug = torch.randn_like(x) * noise_scale
        x = x + noise_aug
        x = x.clamp(0, 1)
|
|
|
shape = x.shape |
|
batch_size = shape[0] |
|
|
|
assert shape[-2:] == ((self.input_image_size,) * 2) |
|
|
|
|
|
if not exists(styles): |
|
assert exists(self.style_network) |
|
|
|
noise = default( |
|
noise, |
|
torch.randn( |
|
(batch_size, self.style_network.dim_in), device=self.device |
|
), |
|
) |
|
styles = self.style_network(noise, global_text_tokens) |
|
|
|
|
|
        # project the style vector to every conv modulation at once, then
        # split into one chunk per dim recorded in style_embed_split_dims
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
        conv_mods = iter(conv_mods)
|
|
|
x = self.init_conv(x) |
|
|
|
        h = []

        # downsampling stages: resnet blocks (saving skip connections),
        # optional attention, then a strided-conv downsample
        for blocks, (attn, downsample) in self.downs:
            for block in blocks:
                x = block(x, conv_mods_iter=conv_mods)
                h.append(x)

            if attn is not None:
                x = attn(x)

            x = downsample(x)
|
|
|
x = self.mid_block1(x, conv_mods_iter=conv_mods) |
|
x = self.mid_attn(x) |
|
x = self.mid_block2(x, conv_mods_iter=conv_mods) |
|
|
|
        # upsampling stages: upsample, resnet blocks consuming the saved
        # skip connections, then optional attention
        for blocks, (upsample, attn) in self.ups:
            x = upsample(x)

            for block in blocks:
                if h:
                    res = h.pop()
                    res = res * self.skip_connect_scale
                    x = torch.cat((x, res), dim=1)

                x = block(x, conv_mods_iter=conv_mods)

            if attn is not None:
                x = attn(x)
|
|
|
x = self.final_res_block(x, conv_mods_iter=conv_mods) |
|
rgb = self.final_to_rgb(x) |
|
|
|
        if not return_all_rgbs:
            return rgb

        # this variant collects no intermediate rgbs; an empty list keeps the
        # two-tuple return shape for callers that request it
        return rgb, []
|
|
|
|
|
def tile_image(image, chunk_size=64): |
|
c, h, w = image.shape |
|
h_chunks = ceil(h / chunk_size) |
|
w_chunks = ceil(w / chunk_size) |
|
tiles = [] |
|
for i in range(h_chunks): |
|
for j in range(w_chunks): |
|
tile = image[ |
|
:, |
|
i * chunk_size : (i + 1) * chunk_size, |
|
j * chunk_size : (j + 1) * chunk_size, |
|
] |
|
tiles.append(tile) |
|
return tiles, h_chunks, w_chunks |
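
# e.g. a (3, 100, 130) image with chunk_size=64 yields ceil(100/64) = 2 by
# ceil(130/64) = 3 tiles; edge tiles come out smaller than 64x64, which
# merge_tiles below handles by reading each tile's actual shape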
|
|
|
|
|
|
|
def create_checkerboard_weights(tile_size): |
|
x = torch.linspace(-1, 1, tile_size) |
|
y = torch.linspace(-1, 1, tile_size) |
|
|
|
x, y = torch.meshgrid(x, y, indexing="ij") |
|
d = torch.sqrt(x * x + y * y) |
|
sigma, mu = 0.5, 0.0 |
|
weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2))) |
|
|
|
|
|
weights = weights**8 |
|
|
|
return weights / weights.max() |
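
# the result is a radial profile: a gaussian of the distance from the tile
# center, raised to the 8th power to sharpen the falloff and normalized so
# the center weight is 1; blending overlapped tiles with these weights (see
# upscale_4x_overlapped) trusts each tile most near its center, hiding seams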
|
|
|
|
|
def repeat_weights(weights, image_size):
    tile_size = weights.shape[0]
    repeats = (
        ceil(image_size[0] / tile_size),
        ceil(image_size[1] / tile_size),
    )
    return weights.repeat(repeats)[: image_size[0], : image_size[1]]
|
|
|
|
|
def create_offset_weights(weights, image_size): |
|
tile_size = weights.shape[0] |
|
offset = tile_size // 2 |
|
full_weights = repeat_weights( |
|
weights, (image_size[0] + offset, image_size[1] + offset) |
|
) |
|
return full_weights[offset:, offset:] |
|
|
|
|
|
def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64): |
|
|
|
c = tiles[0].shape[0] |
|
h = h_chunks * chunk_size |
|
w = w_chunks * chunk_size |
|
|
|
|
|
merged = torch.zeros((c, h, w), dtype=tiles[0].dtype) |
|
|
|
|
|
for idx, tile in enumerate(tiles): |
|
i = idx // w_chunks |
|
j = idx % w_chunks |
|
|
|
h_start = i * chunk_size |
|
w_start = j * chunk_size |
|
|
|
tile_h, tile_w = tile.shape[1:] |
|
merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile |
|
|
|
return merged |
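
# roundtrip sketch: tiles, hc, wc = tile_image(img, 64) followed by
# merge_tiles(tiles, hc, wc, 64) reproduces img, zero-padded up to the next
# multiple of 64 on each side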
|
|
|
|
|
class AuraSR: |
|
def __init__(self, config: dict[str, Any], device: str = "cuda"): |
|
self.upsampler = UnetUpsampler(**config).to(device) |
|
self.input_image_size = config["input_image_size"] |
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
model_id: str = "fal-ai/AuraSR", |
|
use_safetensors: bool = True, |
|
device: str = "cuda", |
|
): |
|
        import json
        from pathlib import Path
        from huggingface_hub import snapshot_download
|
|
|
|
|
if Path(model_id).is_file(): |
|
local_file = Path(model_id) |
|
if local_file.suffix == ".safetensors": |
|
use_safetensors = True |
|
elif local_file.suffix == ".ckpt": |
|
use_safetensors = False |
|
else: |
|
raise ValueError( |
|
f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files." |
|
) |
|
|
|
|
|
config_path = local_file.with_name("config.json") |
|
if not config_path.exists(): |
|
raise FileNotFoundError( |
|
f"Config file not found: {config_path}. " |
|
f"When loading from a local file, ensure that 'config.json' " |
|
f"is present in the same directory as '{local_file.name}'. " |
|
f"If you're trying to load a model from Hugging Face, " |
|
f"please provide the model ID instead of a file path." |
|
) |
|
|
|
config = json.loads(config_path.read_text()) |
|
hf_model_path = local_file.parent |
|
else: |
|
hf_model_path = Path( |
|
snapshot_download(model_id, ignore_patterns=["*.ckpt"]) |
|
) |
|
config = json.loads((hf_model_path / "config.json").read_text()) |
|
|
|
model = cls(config, device) |
|
|
|
if use_safetensors: |
|
try: |
|
from safetensors.torch import load_file |
|
|
|
checkpoint = load_file( |
|
hf_model_path / "model.safetensors" |
|
if not Path(model_id).is_file() |
|
else model_id |
|
) |
|
except ImportError: |
|
raise ImportError( |
|
"The safetensors library is not installed. " |
|
"Please install it with `pip install safetensors` " |
|
"or use `use_safetensors=False` to load the model with PyTorch." |
|
) |
|
else: |
|
checkpoint = torch.load( |
|
hf_model_path / "model.ckpt" |
|
if not Path(model_id).is_file() |
|
else model_id |
|
) |
|
|
|
model.upsampler.load_state_dict(checkpoint, strict=True) |
|
return model |
|
|
|
@torch.no_grad() |
|
def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image: |
|
tensor_transform = transforms.ToTensor() |
|
device = self.upsampler.device |
|
|
|
        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape

        # reflect-pad so height and width are exact multiples of the model's
        # input tile size, then split into tiles
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)
        tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)
|
|
|
|
|
num_tiles = len(tiles) |
|
batches = [ |
|
tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size) |
|
] |
|
reconstructed_tiles = [] |
|
|
|
for batch in batches: |
|
model_input = torch.stack(batch).to(device) |
|
generator_output = self.upsampler( |
|
lowres_image=model_input, |
|
noise=torch.randn(model_input.shape[0], 128, device=device), |
|
) |
|
reconstructed_tiles.extend( |
|
list(generator_output.clamp_(0, 1).detach().cpu()) |
|
) |
|
|
|
        # each upscaled tile is 4x the input tile size; merge on the 4x grid
        # and crop away the padding
        merged_tensor = merge_tiles(
            reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
        )
        unpadded = merged_tensor[:, : h * 4, : w * 4]
|
|
|
to_pil = transforms.ToPILImage() |
|
return to_pil(unpadded) |
|
|
|
|
|
|
|
@torch.no_grad() |
|
def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"): |
|
tensor_transform = transforms.ToTensor() |
|
device = self.upsampler.device |
|
|
|
image_tensor = tensor_transform(image).unsqueeze(0) |
|
_, _, h, w = image_tensor.shape |
|
|
|
|
|
pad_h = ( |
|
self.input_image_size - h % self.input_image_size |
|
) % self.input_image_size |
|
pad_w = ( |
|
self.input_image_size - w % self.input_image_size |
|
) % self.input_image_size |
|
|
|
|
|
image_tensor = torch.nn.functional.pad( |
|
image_tensor, (0, pad_w, 0, pad_h), mode="reflect" |
|
).squeeze(0) |
|
|
|
|
|
def process_tiles(tiles, h_chunks, w_chunks): |
|
num_tiles = len(tiles) |
|
batches = [ |
|
tiles[i : i + max_batch_size] |
|
for i in range(0, num_tiles, max_batch_size) |
|
] |
|
reconstructed_tiles = [] |
|
|
|
for batch in batches: |
|
model_input = torch.stack(batch).to(device) |
|
generator_output = self.upsampler( |
|
lowres_image=model_input, |
|
noise=torch.randn(model_input.shape[0], 128, device=device), |
|
) |
|
reconstructed_tiles.extend( |
|
list(generator_output.clamp_(0, 1).detach().cpu()) |
|
) |
|
|
|
return merge_tiles( |
|
reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4 |
|
) |
|
|
|
|
|
        # first pass: upscale tiles on the regular grid
        tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
        result1 = process_tiles(tiles1, h_chunks1, w_chunks1)

        # second pass: shift the grid by half a tile so the tile borders of
        # the first pass land at tile centers of the second
        offset = self.input_image_size // 2
        image_tensor_offset = torch.nn.functional.pad(
            image_tensor, (offset, offset, offset, offset), mode="reflect"
        ).squeeze(0)

        tiles2, h_chunks2, w_chunks2 = tile_image(
            image_tensor_offset, self.input_image_size
        )
        result2 = process_tiles(tiles2, h_chunks2, w_chunks2)

        # crop the shifted result back to the unshifted frame (the offset is
        # 4x in output space)
        offset_4x = offset * 4
        result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]
|
|
|
if weight_type == "checkboard": |
|
weight_tile = create_checkerboard_weights(self.input_image_size * 4) |
|
|
|
weight_shape = result2_interior.shape[1:] |
|
weights_1 = create_offset_weights(weight_tile, weight_shape) |
|
weights_2 = repeat_weights(weight_tile, weight_shape) |
|
|
|
normalizer = weights_1 + weights_2 |
|
weights_1 = weights_1 / normalizer |
|
weights_2 = weights_2 / normalizer |
|
|
|
weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1) |
|
weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1) |
|
elif weight_type == "constant": |
|
weights_1 = torch.ones_like(result2_interior) * 0.5 |
|
weights_2 = weights_1 |
|
        else:
            raise ValueError(
                f"weight_type should be either 'checkboard' or 'constant' but got {weight_type}"
            )
|
|
|
        # blend the two passes with the normalized weights, then crop away
        # the initial padding
        result1 = result1 * weights_2
        result2 = result2_interior * weights_1

        result1 = result1 + result2

        unpadded = result1[:, : h * 4, : w * 4]
|
|
|
to_pil = transforms.ToPILImage() |
|
return to_pil(unpadded) |
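
# end-to-end usage (a sketch; the model id is the default above, the file
# names are illustrative):
#
#   from PIL import Image
#   model = AuraSR.from_pretrained("fal-ai/AuraSR")
#   image = Image.open("input.png").convert("RGB")
#   upscaled = model.upscale_4x_overlapped(image)   # 4x height and width
#   upscaled.save("output.png")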
|
|