import torch
import torch.nn as nn
import torchvision.transforms as tr
from huggingface_hub import PyTorchModelHubMixin


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with a skip connection."""

    def __init__(self, in_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)


def generator(num_residual_blocks=9):
    """Build a ResNet-style CycleGAN generator: a 7x7 input convolution,
    two downsampling blocks, `num_residual_blocks` residual blocks, two
    upsampling blocks, and a 7x7 output convolution with tanh."""
    channels = 3
    out_features = 64
    model = [
        nn.ReflectionPad2d(3),  # the 7x7 kernel needs padding 3 to preserve size
        nn.Conv2d(channels, out_features, 7),
        nn.InstanceNorm2d(out_features),
        nn.ReLU(inplace=True),
    ]
    in_features = out_features

    # Downsampling: two stride-2 convolutions, doubling the channel count
    for _ in range(2):
        out_features *= 2
        model += [
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Residual blocks at the bottleneck resolution
    for _ in range(num_residual_blocks):
        model += [ResidualBlock(out_features)]

    # Upsampling: nearest-neighbor upsample + convolution, halving the channels
    for _ in range(2):
        out_features //= 2
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Output layer: map back to RGB in [-1, 1]
    model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
    return nn.Sequential(*model)


class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
    """A pair of generators (A->B and B->A) with per-domain normalization
    statistics, loadable from the Hugging Face Hub via the mixin."""

    def __init__(self, channel_mean_a=None, channel_std_a=None,
                 channel_mean_b=None, channel_std_b=None):
        super().__init__()
        self.generator_ab = generator()
        self.generator_ba = generator()
        # Store normalization statistics as non-trainable buffers so they are
        # saved with the checkpoint and follow the model across devices.
        self.register_buffer('channel_mean_a', torch.tensor(
            channel_mean_a if channel_mean_a is not None else [0.5, 0.5, 0.5],
            dtype=torch.float32))
        self.register_buffer('channel_std_a', torch.tensor(
            channel_std_a if channel_std_a is not None else [0.5, 0.5, 0.5],
            dtype=torch.float32))
        self.register_buffer('channel_mean_b', torch.tensor(
            channel_mean_b if channel_mean_b is not None else [0.5, 0.5, 0.5],
            dtype=torch.float32))
        self.register_buffer('channel_std_b', torch.tensor(
            channel_std_b if channel_std_b is not None else [0.5, 0.5, 0.5],
            dtype=torch.float32))


def get_val_transform(model, direction="a_to_b", size=256):
    """Validation transform for the *input* image: resize, center-crop, and
    normalize with the source domain's statistics (A for a_to_b, B for b_to_a)."""
    mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
    std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
    return tr.Compose([
        tr.ToPILImage(),
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])


def de_normalize(tensor, model, direction="a_to_b"):
    """Invert the normalization of a generator *output*. The a_to_b generator
    produces images in domain B's normalized space, so its output is
    de-normalized with B's statistics (and vice versa). Returns an HWC
    uint8 tensor."""
    mean = model.channel_mean_b if direction == "a_to_b" else model.channel_mean_a
    std = model.channel_std_b if direction == "a_to_b" else model.channel_std_a
    img_tensor = tensor * std[:, None, None] + mean[:, None, None]
    return torch.clamp(img_tensor.permute(1, 2, 0) * 255.0, 0.0, 255.0).to(torch.uint8)
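

if __name__ == "__main__":
    # Minimal end-to-end usage sketch for the pipeline above. The Hub repo id
    # and file paths below are placeholders, not real artifacts; substitute a
    # checkpoint actually trained with this architecture.
    import numpy as np
    from PIL import Image

    model = CycleGAN.from_pretrained("user/cyclegan-demo")  # placeholder repo id
    model.eval()

    image = np.array(Image.open("input.jpg").convert("RGB"))  # placeholder path
    transform = get_val_transform(model, direction="a_to_b")
    batch = transform(image).unsqueeze(0)  # add batch dim -> (1, 3, 256, 256)

    with torch.no_grad():
        fake_b = model.generator_ab(batch)[0]  # translate A -> B, drop batch dim

    # de_normalize returns an HWC uint8 tensor that PIL can consume directly
    out = de_normalize(fake_b, model, direction="a_to_b")
    Image.fromarray(out.numpy()).save("output.jpg")  # placeholder path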