import os
from typing import Any, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import make_grid, save_image

try:
    from kornia.morphology import opening
except ImportError:
    # Fallback for kornia releases that export the op under the name `open`.
    from kornia.morphology import open as opening


def exist(val: Any) -> bool:
    """Return True when ``val`` is not None."""
    return val is not None


def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening to ``x`` with a ``k`` x ``k`` all-ones kernel.

    Args:
        x: Tensor passed straight to ``kornia.morphology.opening``
            (presumably (B, C, H, W) — TODO confirm against callers).
        k: Kernel side length. ``0`` disables the operation and returns ``x``
            unchanged; a negative value makes ``torch.ones`` raise.

    Returns:
        The opened tensor. It is computed under ``torch.no_grad()``, so no
        gradient flows back through it.
    """
    if k == 0:
        return x
    # Build the structuring element on x's device *and* dtype so the morphology
    # op does not mix e.g. float16 inputs with a float32 kernel.
    kernel = torch.ones(k, k, device=x.device, dtype=x.dtype)
    with torch.no_grad():
        return opening(x, kernel)


def make_grid_images(images: Sequence[torch.Tensor], **kwargs) -> torch.Tensor:
    """Concatenate ``images`` along dim 3 and tile the result with ``make_grid``.

    Args:
        images: Tensors that agree on every dim except 3 (for (B, C, H, W)
            batches dim 3 is the width, so the images end up side by side).
        **kwargs: Forwarded to ``torchvision.utils.make_grid``.

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    return make_grid(torch.cat(images, dim=3), **kwargs)


def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save a (generated, real) pair side by side as one grid image at ``path``.

    Args:
        images: Exactly two tensors ``(gen, real)``; unpacking raises
            ``ValueError`` for any other length.
        path: Destination file passed to ``save_image``.
        **kwargs: Forwarded to ``torchvision.utils.make_grid``.
    """
    gen, real = images
    save_triplet((gen, real), path, **kwargs)


def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save any number of tensors side by side (concatenated on dim 3) as one image.

    The grid is quantized to 8 bits before saving, as in the original code:
    each value is scaled by 255 and truncated.

    Args:
        images: Tensors concatenated along dim 3. Values are expected in
            [0, 1]; anything outside that range is clamped.
        path: Destination file passed to ``save_image``.
        **kwargs: Forwarded to ``torchvision.utils.make_grid``.
    """
    grid = make_grid(torch.cat(images, dim=3), **kwargs)
    # Clamp before the uint8 cast: an unclamped float -> uint8 conversion wraps
    # out-of-range values (e.g. 1.2 -> 50) instead of saturating them.
    # detach() lets tensors that require grad be saved directly.
    quantized = grid.detach().clamp(0, 1).mul(255).to(torch.uint8)
    save_image(quantized.cpu().float() / 255, path)


def plot_images(images: torch.Tensor) -> None:
    """Show every image of a batch in one row in a matplotlib window.

    Args:
        images: Presumably a (B, C, H, W) batch (each element must be 3-D for
            the channel permute below). Single-channel batches are drawn in
            grayscale.
    """
    # Lay the batch out left to right: (B, C, H, W) -> (C, H, B*W) -> (H, B*W, C).
    strip = torch.cat(list(images.detach().cpu()), dim=-1).permute(1, 2, 0)
    cmap = None
    if strip.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; drop the channel axis instead.
        strip = strip.squeeze(-1)
        cmap = "gray"
    plt.figure(figsize=(32, 32))
    plt.imshow(strip.numpy(), cmap=cmap)
    plt.show()


def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a per-epoch metric curve and save it as ``<path>/<metric_name>.png``.

    Args:
        metric_name: Used as the title, the y-axis label and the file stem.
        metrics: One value per epoch (tensors, or anything
            ``torch.as_tensor`` accepts).
        path: Existing directory in which the PNG is written.
    """
    values = [torch.as_tensor(m).detach().cpu().numpy() for m in metrics]
    fig = plt.figure(figsize=(32, 32))
    try:
        plt.plot(values)
        plt.title(metric_name)
        plt.xlabel("Epoch")
        plt.ylabel(metric_name)
        fig.savefig(os.path.join(path, f"{metric_name}.png"))
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def norm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Normalize ``img`` per channel: ``(img - mean) / std``.

    Args:
        img: Image tensor accepted by ``transforms.Normalize`` (..., C, H, W).
        mean: Per-channel means (immutable default, so nothing is shared
            across calls).
        std: Per-channel standard deviations.

    Returns:
        The normalized tensor, same shape as ``img``.
    """
    return transforms.Normalize(mean, std)(img)


def denorm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Invert ``norm``: compute ``img * std + mean`` per channel.

    Args:
        img: (..., C, H, W) tensor; both (C, H, W) and (B, C, H, W) keep
            their shape.
        mean: Per-channel means used by ``norm``.
        std: Per-channel standard deviations used by ``norm``.

    Returns:
        The de-normalized tensor. Floating-point inputs keep their dtype;
        integer inputs are promoted to the default float dtype.
    """
    # Match img's floating dtype so half-precision inputs are not upcast.
    # Integer images keep the default float dtype; casting 0.5 to int would
    # silently turn it into 0.
    dtype = img.dtype if img.is_floating_point() else None
    # Shape (C, 1, 1) broadcasts over the trailing (C, H, W) dims without adding
    # a leading batch dim. The old (1, C, 1, 1) form turned a 3-D input into 4-D.
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t