File size: 3,607 Bytes
7cf938c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1121140
7cf938c
 
 
1121140
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import torchvision.transforms as tr
import functools

class ResidualBlock(nn.Module):
    """Identity-skip residual unit used at the generator bottleneck.

    ``self.block`` is pad -> 3x3 conv -> instance norm -> ReLU -> pad -> 3x3 conv
    -> instance norm. The forward pass returns ``x + self.block(x)``.

    Args:
        in_features: channel count of both the input and the output.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_unit():
            # A 1-pixel reflection pad keeps H and W unchanged through the 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # Indices 0-6 are kept exactly as before so state_dict keys stay the same.
        layers = conv_unit() + [nn.ReLU(inplace=True)] + conv_unit()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

def generator(num_residual_blocks=9, channels=3, base_features=64):
    """Build the ResNet-style CycleGAN generator as an ``nn.Sequential``.

    Layout:
    - a 7x7 stem conv;
    - two stride-2 downsampling convs;
    - ``num_residual_blocks`` ``ResidualBlock`` units;
    - two (``nn.Upsample(scale_factor=2)`` + 3x3 conv) stages;
    - a 7x7 output conv followed by ``tanh``.

    The layer order is fixed, so ``state_dict`` keys do not depend on
    ``channels`` or ``base_features``.

    Args:
        num_residual_blocks: number of residual units at the bottleneck.
        channels: image channels for both input and output (default 3).
        base_features: feature width after the stem conv. Each downsampling
            stage doubles it and each upsampling stage halves it (default 64).

    Returns:
        ``nn.Sequential`` mapping ``(N, channels, H, W)`` to
        ``(N, channels, H, W)`` with values in [-1, 1] (final ``Tanh``). The
        spatial size is preserved when H and W are divisible by 4.
    """
    # The pad for the 7x7 stem/output convs comes from the kernel size, not the
    # channel count, so non-RGB generators also keep their spatial size.
    edge_kernel = 7
    edge_pad = edge_kernel // 2

    out_features = base_features
    model = [
        nn.ReflectionPad2d(edge_pad),
        nn.Conv2d(channels, out_features, edge_kernel),
        nn.InstanceNorm2d(out_features),
        nn.ReLU(inplace=True),
    ]
    in_features = out_features

    # Downsampling: each stride-2 conv halves H/W and doubles the feature width.
    for _ in range(2):
        out_features *= 2
        model += [
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Residual blocks at the bottleneck width.
    model += [ResidualBlock(out_features) for _ in range(num_residual_blocks)]

    # Upsampling: 2x resize followed by a size-preserving 3x3 conv.
    for _ in range(2):
        out_features //= 2
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Output layer: project back to image channels and squash into [-1, 1].
    model += [
        nn.ReflectionPad2d(edge_pad),
        nn.Conv2d(in_features, channels, edge_kernel),
        nn.Tanh(),
    ]

    return nn.Sequential(*model)

class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
    """Pair of generators plus per-domain normalization statistics.

    ``generator_ab`` and ``generator_ba`` are independent ``generator()``
    networks (presumably domain A -> B and B -> A, judging by the names).

    The four ``channel_mean_*`` / ``channel_std_*`` arguments are per-channel
    sequences. Any argument left as ``None`` becomes ``[0.5, 0.5, 0.5]``.
    Each is stored as a float32 buffer, which is saved in the ``state_dict``
    but is not trainable.
    """

    def __init__(self, channel_mean_a=None, channel_std_a=None, channel_mean_b=None, channel_std_b=None):
        super().__init__()
        self.generator_ab = generator()
        self.generator_ba = generator()

        fallback = [0.5, 0.5, 0.5]
        # Registration order (mean_a, std_a, mean_b, std_b) matches the buffer key order.
        stats = (
            ("channel_mean_a", channel_mean_a),
            ("channel_std_a", channel_std_a),
            ("channel_mean_b", channel_mean_b),
            ("channel_std_b", channel_std_b),
        )
        for buffer_name, values in stats:
            chosen = fallback if values is None else values
            self.register_buffer(buffer_name, torch.tensor(chosen, dtype=torch.float32))

def get_val_transform(model, direction="a_to_b", size=256):
    """Build the evaluation preprocessing pipeline for one translation direction.

    The pipeline converts the input with ``ToPILImage``, resizes with
    ``Resize(size)``, center-crops to ``size`` x ``size``, converts back with
    ``ToTensor`` and then normalizes.

    Normalization uses the model's domain-A statistics when ``direction`` is
    ``"a_to_b"``. Any other value selects the domain-B statistics.
    """
    if direction == "a_to_b":
        mean, std = model.channel_mean_a, model.channel_std_a
    else:
        mean, std = model.channel_mean_b, model.channel_std_b

    steps = [
        tr.ToPILImage(),
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        # Normalize is given plain Python lists taken from the buffer tensors.
        tr.Normalize(mean=mean.tolist(), std=std.tolist()),
    ]
    return tr.Compose(steps)

def de_normalize(tensor, model, direction="a_to_b"):
    """Undo per-channel normalization and convert to a channels-last uint8 image.

    The function:
    1. inverts ``x_norm = (x - mean) / std`` using the statistics stored on
       ``model`` for ``direction``;
    2. scales to [0, 255];
    3. rounds to the nearest integer and clamps;
    4. moves the channel axis last.

    Args:
        tensor: normalized channels-first image, ``(C, H, W)`` or batched
            ``(N, C, H, W)``.
        model: object exposing ``channel_mean_a``, ``channel_std_a``,
            ``channel_mean_b`` and ``channel_std_b`` as 1-D tensors with one
            entry per channel.
        direction: ``"a_to_b"`` selects the domain-A statistics. Any other
            value selects domain B.

    Returns:
        ``torch.uint8`` tensor of shape ``(H, W, C)`` (or ``(N, H, W, C)``).
    """
    if direction == "a_to_b":
        mean, std = model.channel_mean_a, model.channel_std_a
    else:
        mean, std = model.channel_mean_b, model.channel_std_b

    # (C,) -> (C, 1, 1) broadcasts over H, W and any leading batch dimension.
    img = tensor.detach() * std[:, None, None] + mean[:, None, None]
    # Round before the cast. The uint8 cast truncates, so float error
    # (e.g. 126.99998) would otherwise lose a pixel level.
    img = torch.clamp(torch.round(img * 255.0), 0.0, 255.0)
    # movedim(-3, -1) equals permute(1, 2, 0) for 3-D input and also handles batches.
    return img.movedim(-3, -1).to(torch.uint8)