# cyclegan-streamlit / model.py
import torch
import torch.nn as nn
import torchvision.transforms as tr
from huggingface_hub import PyTorchModelHubMixin


class ResidualBlock(nn.Module):
    """Residual block with reflection padding and instance normalization."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        # Skip connection: add the block's output to its input.
        return x + self.block(x)
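
# A minimal sanity check (illustrative, not part of the original file): a
# residual block preserves both the channel count and the spatial size, so
# any number of blocks can be stacked at the bottleneck.
#
#   >>> block = ResidualBlock(64)
#   >>> block(torch.randn(1, 64, 64, 64)).shape
#   torch.Size([1, 64, 64, 64])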


def generator(num_residual_blocks=9):
    """Build a ResNet-based CycleGAN generator for 3-channel images."""
    channels = 3
    out_features = 64

    # Initial block: reflection-pad 3, 7x7 convolution, instance norm, ReLU.
    model = [
        nn.ReflectionPad2d(channels),
        nn.Conv2d(channels, out_features, 7),
        nn.InstanceNorm2d(out_features),
        nn.ReLU(inplace=True),
    ]
    in_features = out_features

    # Downsampling: two stride-2 convolutions, doubling the channel count.
    for _ in range(2):
        out_features *= 2
        model += [
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Residual blocks at the bottleneck resolution.
    for _ in range(num_residual_blocks):
        model += [ResidualBlock(out_features)]

    # Upsampling: nearest-neighbor upsample followed by a 3x3 convolution,
    # halving the channel count back down to 64.
    for _ in range(2):
        out_features //= 2
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Output layer: map back to 3 channels; Tanh keeps outputs in [-1, 1].
    model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

    return nn.Sequential(*model)
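
# A quick shape check (illustrative, not in the original file): the generator
# is fully convolutional, so it returns an image of the same size as its
# input, with values in [-1, 1] from the final Tanh.
#
#   >>> g = generator()
#   >>> g(torch.randn(1, 3, 256, 256)).shape
#   torch.Size([1, 3, 256, 256])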


class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
    """CycleGAN model holding both generators: A -> B and B -> A."""

    def __init__(self, channel_mean_a=None, channel_std_a=None, channel_mean_b=None, channel_std_b=None):
        super(CycleGAN, self).__init__()
        self.generator_ab = generator()
        self.generator_ba = generator()
        # Store per-domain normalization statistics as buffers: they are saved
        # with the state_dict but are not trainable parameters. The default
        # mean/std of 0.5 maps inputs in [0, 1] to [-1, 1].
        default = [0.5, 0.5, 0.5]
        self.register_buffer('channel_mean_a', torch.tensor(channel_mean_a if channel_mean_a is not None else default, dtype=torch.float32))
        self.register_buffer('channel_std_a', torch.tensor(channel_std_a if channel_std_a is not None else default, dtype=torch.float32))
        self.register_buffer('channel_mean_b', torch.tensor(channel_mean_b if channel_mean_b is not None else default, dtype=torch.float32))
        self.register_buffer('channel_std_b', torch.tensor(channel_std_b if channel_std_b is not None else default, dtype=torch.float32))


def get_val_transform(model, direction="a_to_b", size=256):
    """Return the preprocessing pipeline for inference in the given direction."""
    # Normalize with the source domain's statistics (domain A for "a_to_b").
    mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
    std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
    return tr.Compose([
        tr.ToPILImage(),
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])


def de_normalize(tensor, model, direction="a_to_b"):
    """Undo normalization and convert a CHW tensor to an HWC uint8 image."""
    # Use the same statistics that get_val_transform applied for this direction.
    mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
    std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
    img_tensor = tensor * std[:, None, None] + mean[:, None, None]
    # CHW -> HWC, scale to [0, 255], clamp, and cast to uint8 for display.
    return torch.clamp(img_tensor.permute(1, 2, 0) * 255.0, 0.0, 255.0).to(torch.uint8)
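

# A hedged end-to-end usage sketch (not part of the original file). The repo
# id "waleko/cyclegan-streamlit", the file names, and the Pillow/NumPy loading
# code are illustrative assumptions; only CycleGAN, get_val_transform, and
# de_normalize above come from this module.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    # PyTorchModelHubMixin provides from_pretrained(); the repo id is a placeholder.
    model = CycleGAN.from_pretrained("waleko/cyclegan-streamlit")
    model.eval()

    image = np.array(Image.open("input.jpg").convert("RGB"))
    transform = get_val_transform(model, direction="a_to_b")
    x = transform(image).unsqueeze(0)  # CHW tensor with a batch dimension

    with torch.no_grad():
        y = model.generator_ab(x)[0]  # translate A -> B, drop the batch dim

    out = de_normalize(y, model, direction="a_to_b")
    Image.fromarray(out.numpy()).save("output.jpg")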