File size: 2,755 Bytes
587665f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from collections.abc import Sequence
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch

try:
    from kornia.morphology import opening
except ImportError:
    from kornia.morphology import open as opening

from torchvision import transforms
from torchvision.utils import make_grid, save_image

def exist(val: Any) -> bool:
    return val is not None

def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening to *x* with a k-by-k all-ones kernel.

    Args:
        x: Tensor passed straight to kornia's ``opening`` (presumably
            (B, C, H, W) as kornia expects - TODO confirm with callers).
        k: Side length of the square structuring element. ``0`` disables the
            operation and returns *x* itself, unchanged.

    Returns:
        The opened tensor, computed without autograd tracking.
    """
    if k == 0:
        return x
    # Gradient tracking is switched off for the whole operation.
    with torch.no_grad():
        structuring_element = torch.ones(k, k, device=x.device)
        return opening(x, structuring_element)

def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
    """Place several image batches side by side and tile the result into a grid.

    The tensors in *images* are concatenated along dim 3 (the width axis,
    assuming (B, C, H, W) layout - TODO confirm), so each grid cell shows the
    same sample index from every batch next to each other.

    Args:
        images: Batches with matching shapes except along dim 3.
        **kwargs: Forwarded unchanged to ``torchvision.utils.make_grid``.

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    return make_grid(torch.cat(images, dim=3), **kwargs)

def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save generated and real images side by side as one image file.

    Args:
        images: ``(gen, real)`` batches, concatenated along dim 3 (the width
            axis, assuming (B, C, H, W) layout - TODO confirm) so each grid
            cell shows a generated sample next to its real counterpart.
        path: Destination file; the format is inferred from its extension.
        **kwargs: Forwarded to ``torchvision.utils.make_grid`` (via
            ``save_image``), e.g. ``nrow``, ``padding``, ``normalize``.
    """
    gen, real = images
    side_by_side = torch.cat((gen, real), dim=3)
    # save_image builds the grid itself and quantizes with rounding plus
    # clamping to [0, 255] under no_grad. The previous manual
    # ``(x * 255).astype(np.uint8)`` truncated, overflowed for values outside
    # [0, 1], and failed on tensors requiring grad.
    save_image(side_by_side, path, **kwargs)

def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save any number of image batches side by side as one image file.

    Args:
        images: Batches concatenated along dim 3 (the width axis, assuming
            (B, C, H, W) layout - TODO confirm); shapes must match on the
            other dims.
        path: Destination file; the format is inferred from its extension.
        **kwargs: Forwarded to ``torchvision.utils.make_grid`` (via
            ``save_image``), e.g. ``nrow``, ``padding``, ``normalize``.
    """
    side_by_side = torch.cat(images, dim=3)
    # Let save_image do the grid + uint8 conversion: it rounds and clamps to
    # [0, 255], whereas a bare ``astype(np.uint8)`` truncates and overflows
    # for values slightly outside [0, 1].
    save_image(side_by_side, path, **kwargs)

def plot_images(images: torch.Tensor) -> None:
    """Show a batch of images as one horizontal strip in a matplotlib window.

    *images* is iterated along dim 0. The per-image tensors are joined along
    their last dim and then permuted to channels-last for ``imshow``.
    Presumably *images* is (B, C, H, W) - TODO confirm.
    """
    plt.figure(figsize=(32, 32))
    per_image = images.cpu().unbind(0)
    strip = torch.cat(per_image, dim=-1)
    plt.imshow(strip.permute(1, 2, 0))
    plt.show()

def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a metric history and save it as ``<path>/<metric_name>.png``.

    Args:
        metric_name: Used as the plot title, the y-axis label and the output
            file stem.
        metrics: One tensor per epoch (presumably 0-d scalars - TODO confirm).
            Tensors may live on any device and may still require grad.
        path: Directory the PNG is written into.
    """
    # detach(): metric/loss tensors collected during training often still
    # require grad, and Tensor.numpy() refuses to convert those.
    values = [m.detach().cpu().numpy() for m in metrics]

    fig = plt.figure(figsize=(32, 32))
    try:
        plt.plot(values)
        plt.title(metric_name)
        plt.xlabel("Epoch")
        plt.ylabel(metric_name)
        plt.savefig(os.path.join(path, f"{metric_name}.png"))
    finally:
        # Release the figure even if plotting/saving fails, so repeated calls
        # during training do not accumulate open figures.
        plt.close(fig)

def norm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Normalize *img* channel-wise as ``(img - mean) / std``.

    Thin wrapper around ``torchvision.transforms.Normalize``. With the default
    mean/std of 0.5, values in [0, 1] are mapped to [-1, 1].

    Args:
        img: Float image tensor with channels at dim -3 (torchvision's
            convention for ``Normalize``).
        mean: Per-channel means.
        std: Per-channel standard deviations.

    Returns:
        The normalized tensor (a new tensor; *img* is not modified).
    """
    # Tuple defaults instead of lists: immutable, so no state can be shared
    # or mutated across calls.
    return transforms.Normalize(mean, std)(img)

def denorm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Undo ``norm``: compute ``img * std + mean`` channel-wise.

    Channels are taken to be at dim -3, so both (C, H, W) and (B, C, H, W)
    inputs are supported and the output has the same shape as *img*.

    Args:
        img: Normalized image tensor.
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.

    Returns:
        The de-normalized tensor. Floating inputs keep their dtype; integer
        inputs are promoted to torch's default float dtype.
    """
    # Build the constants in the image's own float dtype so e.g. float16
    # images are not silently upcast to float32.
    dtype = img.dtype if img.is_floating_point() else None
    # view(-1, 1, 1) aligns the channel axis with dim -3 without adding a
    # leading batch dim (the previous [None][..., None, None] turned a
    # (C, H, W) input into (1, C, H, W)).
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t