# vfontech's picture
# Uploading the app
# 587665f verified
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
try:
from kornia.morphology import opening
except ImportError:
from kornia.morphology import open as opening
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from typing import Any
def exist(val: Any) -> bool:
    """Return ``True`` for every value except ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only the ``None`` singleton is treated as missing.
    """
    if val is None:
        return False
    return True
def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening with a ``k`` x ``k`` all-ones kernel.

    Args:
        x: Input tensor handed to kornia's opening operator.
        k: Side length of the square structuring element. ``0`` disables the
            operation and returns ``x`` itself, untouched.

    Returns:
        The opened tensor, computed with gradient tracking disabled, or the
        very same ``x`` object when ``k`` is ``0``.
    """
    # Guard clause: a zero-sized kernel means "no opening requested".
    if k == 0:
        return x

    with torch.no_grad():
        # The structuring element is created on the input's device.
        structuring_element = torch.ones(k, k, device=x.device)
        return opening(x, structuring_element)
def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
    """Join several image batches along dim 3 and tile them into one grid.

    Args:
        images: Tensors to concatenate along dim 3; they must therefore
            agree on every other dimension.
        **kwargs: Forwarded unchanged to ``torchvision.utils.make_grid``.

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    joined = torch.cat(images, dim=3)
    return make_grid(joined, **kwargs)
def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save a generated batch and a real batch side by side as one image file.

    The two batches are concatenated along dim 3, tiled with ``make_grid``,
    quantized to 8 bits per channel and written with ``save_image``.

    Args:
        images: ``(gen, real)`` pair of tensors that agree on every dimension
            except dim 3.
        path: Destination file passed to ``save_image``.
        **kwargs: Forwarded unchanged to ``torchvision.utils.make_grid``.
    """
    gen, real = images
    side_by_side = torch.cat((gen, real), dim=3)
    grid = make_grid(side_by_side, **kwargs)

    # detach(): ``.numpy()`` raises on tensors that still require grad.
    # clamp(0, 1): casting out-of-range floats to uint8 wraps around and
    # produces corrupted pixels, so limit the range first. In-range values are
    # unaffected, and the truncating quantization below is kept as it was.
    pixels = grid.detach().permute(1, 2, 0).to("cpu").clamp(0, 1).numpy()
    pixels = (pixels * 255).astype(np.uint8)

    # Back to channels-first floats in [0, 1], the layout save_image expects.
    save_image(torch.from_numpy(pixels).permute(2, 0, 1) / 255, path)
def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save any number of image batches side by side as one image file.

    The batches are concatenated along dim 3, tiled with ``make_grid``,
    quantized to 8 bits per channel and written with ``save_image``.

    Args:
        images: Tensors that agree on every dimension except dim 3.
        path: Destination file passed to ``save_image``.
        **kwargs: Forwarded unchanged to ``torchvision.utils.make_grid``.
    """
    side_by_side = torch.cat(images, dim=3)
    grid = make_grid(side_by_side, **kwargs)

    # detach(): ``.numpy()`` raises on tensors that still require grad.
    # clamp(0, 1): casting out-of-range floats to uint8 wraps around and
    # produces corrupted pixels, so limit the range first. In-range values are
    # unaffected, and the truncating quantization below is kept as it was.
    pixels = grid.detach().permute(1, 2, 0).to("cpu").clamp(0, 1).numpy()
    pixels = (pixels * 255).astype(np.uint8)

    # Back to channels-first floats in [0, 1], the layout save_image expects.
    save_image(torch.from_numpy(pixels).permute(2, 0, 1) / 255, path)
def plot_images(images: torch.Tensor) -> None:
    """Display a batch of images as a single horizontal strip.

    The batch is moved to the CPU, its items are laid next to each other along
    the last dimension, and the strip is shown channels-last in a 32x32-inch
    matplotlib figure.

    Args:
        images: Batch tensor; iterating it must yield 3-D image tensors
            (presumably ``(C, H, W)`` each; confirm against callers).
    """
    plt.figure(figsize=(32, 32))
    on_cpu = images.cpu()
    # Iterating a tensor walks its first dimension, i.e. one item per image.
    strip = torch.cat(list(on_cpu), dim=-1)
    plt.imshow(strip.permute(1, 2, 0))
    plt.show()
def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a metric history and save it as ``<path>/<metric_name>.png``.

    Args:
        metric_name: Used as the plot title, the y-axis label and the file
            name (with a ``.png`` suffix).
        metrics: One tensor per point on the x-axis, which is labelled
            "Epoch".
        path: Directory the PNG is written into.
    """
    # detach(): metric tensors that are still attached to an autograd graph
    # cannot be converted with ``.numpy()`` directly.
    values = [m.detach().cpu().numpy() for m in metrics]
    out_file = os.path.join(path, f"{metric_name}.png")

    fig = plt.figure(figsize=(32, 32))
    try:
        plt.plot(values)
        plt.title(metric_name)
        plt.xlabel("Epoch")
        plt.ylabel(metric_name)
        plt.savefig(out_file)
    finally:
        # Always release the figure, even when plotting or saving fails;
        # otherwise open figures accumulate across calls.
        plt.close(fig)
def norm(
    img: torch.Tensor,
    mean: list[float] = [0.5, 0.5, 0.5],
    std: list[float] = [0.5, 0.5, 0.5]
) -> torch.Tensor:
    """Normalize ``img`` with ``torchvision.transforms.Normalize``.

    Args:
        img: Image tensor to normalize.
        mean: Per-channel means handed to ``Normalize``.
        std: Per-channel standard deviations handed to ``Normalize``.

    Returns:
        The tensor returned by the ``Normalize`` transform.
    """
    return transforms.Normalize(mean, std)(img)
def denorm(
    img: torch.Tensor,
    mean: list[float] = [0.5, 0.5, 0.5],
    std: list[float] = [0.5, 0.5, 0.5]
) -> torch.Tensor:
    """Undo a mean/std normalization: ``img * std + mean`` per channel.

    Args:
        img: Normalized image tensor; the statistics are created on its device.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.

    Returns:
        The de-normalized tensor. The statistics are broadcast with shape
        ``(1, C, 1, 1)``, so a 3-D input comes back with a leading dimension
        of size 1.
    """
    def _as_broadcastable(stats: list[float]) -> torch.Tensor:
        # One leading and two trailing singleton axes: (C,) -> (1, C, 1, 1).
        t = torch.tensor(stats, device=img.device)
        return t.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    scale = _as_broadcastable(std)
    shift = _as_broadcastable(mean)
    return img * scale + shift