import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import torchvision.transforms as tr
class ResidualBlock(nn.Module):
    """Residual block with reflection padding and instance normalization."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        # Identity shortcut around two convolutional layers
        return x + self.block(x)
def generator(num_residual_blocks=9):
    """Build the CycleGAN generator: an initial 7x7 convolution block, two
    downsampling blocks, `num_residual_blocks` residual blocks, two
    upsampling blocks, and a 7x7 output convolution with Tanh."""
    channels = 3
    out_features = 64
    # Initial convolution block
    model = [
        nn.ReflectionPad2d(channels),
        nn.Conv2d(channels, out_features, 7),
        nn.InstanceNorm2d(out_features),
        nn.ReLU(inplace=True),
    ]
    in_features = out_features
    # Downsampling: 64 -> 128 -> 256 channels, halving spatial size each step
    for _ in range(2):
        out_features *= 2
        model += [
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
    # Residual blocks at the bottleneck resolution
    for _ in range(num_residual_blocks):
        model += [ResidualBlock(out_features)]
    # Upsampling: nearest-neighbor upsample followed by a convolution
    for _ in range(2):
        out_features //= 2
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
    # Output layer maps back to 3 channels with values in [-1, 1]
    model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
    return nn.Sequential(*model)
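
# Illustrative sanity check (not part of the original listing): the generator
# is fully convolutional and shape-preserving, and the final Tanh keeps
# outputs in [-1, 1].
#
# >>> g = generator()
# >>> g(torch.randn(1, 3, 256, 256)).shape
# torch.Size([1, 3, 256, 256])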
class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
    def __init__(self, channel_mean_a=None, channel_std_a=None, channel_mean_b=None, channel_std_b=None):
        super(CycleGAN, self).__init__()
        self.generator_ab = generator()  # translates domain A -> domain B
        self.generator_ba = generator()  # translates domain B -> domain A
        # Store the per-channel normalization statistics as buffers: they are
        # saved in the state_dict and moved with .to(device), but not trained.
        self.register_buffer('channel_mean_a', torch.tensor(channel_mean_a if channel_mean_a is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
        self.register_buffer('channel_std_a', torch.tensor(channel_std_a if channel_std_a is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
        self.register_buffer('channel_mean_b', torch.tensor(channel_mean_b if channel_mean_b is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
        self.register_buffer('channel_std_b', torch.tensor(channel_std_b if channel_std_b is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
def get_val_transform(model, direction="a_to_b", size=256):
    """Inference-time transform that resizes, center-crops, and normalizes an
    input image with the source domain's statistics ("a" stats for a_to_b,
    "b" stats for b_to_a)."""
    mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
    std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
    return tr.Compose([
        tr.ToPILImage(),
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])
def de_normalize(tensor, model, direction="a_to_b"):
    """Undo normalization for a single (C, H, W) output tensor and return an
    (H, W, C) uint8 image, using the same statistics that get_val_transform
    applied for the given direction."""
    mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
    std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
    img_tensor = tensor * std[:, None, None] + mean[:, None, None]
    return torch.clamp(img_tensor.permute(1, 2, 0) * 255.0, 0.0, 255.0).to(torch.uint8)
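
A minimal end-to-end inference sketch follows. It assumes the weights are published on the Hub (the repo id below is a placeholder), that Pillow and NumPy are installed, and that a 3-channel RGB image is used as input; PyTorchModelHubMixin provides the from_pretrained classmethod used here.

import numpy as np
from PIL import Image

# Placeholder repo id; substitute the actual checkpoint on the Hub.
model = CycleGAN.from_pretrained("your-username/cyclegan-model")
model.eval()

transform = get_val_transform(model, direction="a_to_b")
image = np.array(Image.open("input.jpg").convert("RGB"))

with torch.no_grad():
    x = transform(image).unsqueeze(0)          # (1, 3, 256, 256)
    fake_b = model.generator_ab(x).squeeze(0)  # (3, 256, 256), values in [-1, 1]

out = de_normalize(fake_b, model, direction="a_to_b")
Image.fromarray(out.numpy()).save("output.jpg")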