VicFonch committed
deleting unnecessary files/scripts

- config/confg.yaml +0 -64
- model/train_pipline.py +0 -177
- utils/ema.py +0 -32
- utils/inter_frame_idx.py +0 -123
- utils/raft.py +0 -20
- utils/uncertainty.py +0 -49
config/confg.yaml DELETED
@@ -1,64 +0,0 @@
-data_confg:
-  train_batch_size: 6
-  val_batch_size: 6
-  test_batch_size: 6
-  flow_method: raft
-  data_domain: animation
-  datamodule_confg:
-    mean: [0.5, 0.5, 0.5]
-    sd: [0.5, 0.5, 0.5]
-    size: [256, 448]
-    amount_augmentations: 1
-    horizontal_flip: 0.5
-    time_flip: True
-    rotation: 0
-    brightness: 0.2
-    contrast: 0.2
-    saturation: 0.2
-    hue: 0.1
-
-trainer_confg:
-  accumulate_grad_batches: 5
-  gradient_clip_val: 1.0
-  max_epochs: 500
-  num_nodes: 1
-  devices: 2
-  accelerator: gpu
-  strategy: ddp_find_unused_parameters_true
-
-optim_confg:
-  optimizer_confg: # AdamW
-    lr: 1.0e-4
-    betas: [0.9, 0.999]
-    eps: 1.0e-8
-  scheduler_confg: # ReduceLROnPlateau
-    mode: min
-    factor: 0.5
-    patience: 3
-    verbose: True
-
-pretrained_model_path: null # Fine-tune model path
-
-model_confg:
-  kappa: 2.0
-  timesteps: 20
-  p: 0.3
-  etas_end: 0.99
-  min_noise_level: 0.04
-  flow_model: raft
-  flow_kwargs:
-    pretrained_path: null # _pretrain_models/anime_interp_full.ckpt
-  warping_kwargs:
-    in_channels: 3
-    channels: [128, 256, 384, 512]
-  synthesis_kwargs:
-    in_channels: 3
-    channels: [128, 256, 384, 512]
-    temb_channels: 512
-    heads: 1
-    window_size: 8
-    window_attn: True
-    grid_attn: True
-    expansion_rate: 1.5
-    num_conv_blocks: 1
-    dropout: 0.0
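For reference, a minimal sketch of how a nested config like this is typically consumed; the loader itself is not part of this commit, so the snippet below (PyYAML, file path, variable names) is an assumption for illustration.

import yaml

# Hypothetical loader: parse the nested YAML and pull out the per-component
# sub-configs by key, the way train_pipline.py indexes confg["optim_confg"].
with open("config/confg.yaml") as f:
    confg = yaml.safe_load(f)

model_kwargs = confg["model_confg"]                         # model kwargs
optimizer_kwargs = confg["optim_confg"]["optimizer_confg"]  # AdamW kwargs
print(model_kwargs["timesteps"], optimizer_kwargs["lr"])    # 20 0.0001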
model/train_pipline.py DELETED
@@ -1,177 +0,0 @@
-import os
-import copy
-import matplotlib.pyplot as plt
-from typing import Any
-
-import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.optim import AdamW, Optimizer
-from torch.utils.data import DataLoader
-from lightning import LightningModule
-
-from torchmetrics import MetricCollection
-from torchmetrics.image import PeakSignalNoiseRatio as PSNR
-from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
-from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
-
-from model.model import MultiInputResShift
-
-from utils.utils import denorm, make_grid_images  #, save_triplet
-from utils.ema import EMA
-from utils.inter_frame_idx import get_inter_frame_temp_index
-from utils.raft import raft_flow
-
-
-class TrainPipline(LightningModule):
-    def __init__(self,
-                 confg: dict,
-                 test_dataloader: DataLoader):
-        super(TrainPipline, self).__init__()
-
-        self.test_dataloader = test_dataloader
-
-        self.confg = confg
-
-        self.mean, self.sd = confg["data_confg"]["mean"], confg["data_confg"]["sd"]
-
-        self.model = MultiInputResShift(**confg["model_confg"])
-        self.model.flow_model.requires_grad_(False).eval()
-
-        self.ema = EMA(beta=0.995)
-        self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
-
-        self.charbonnier_loss = lambda x, y: torch.mean(torch.sqrt((x - y)**2 + 1e-6))
-        self.lpips_loss = LPIPS(net_type='vgg')
-
-        self.train_metrics = MetricCollection({
-            "train_lpips": LPIPS(net_type='alex'),
-            "train_psnr": PSNR(),
-            "train_ssim": SSIM()
-        })
-        self.val_metrics = MetricCollection({
-            "val_lpips": LPIPS(net_type='alex'),
-            "val_psnr": PSNR(),
-            "val_ssim": SSIM()
-        })
-
-    def loss_fn(self,
-                x: torch.Tensor,
-                predicted_x: torch.Tensor) -> torch.Tensor:
-        percep_loss = 0.2 * self.lpips_loss(x, predicted_x.clamp(-1, 1))
-        pix2pix_loss = self.charbonnier_loss(x, predicted_x)
-        return percep_loss + pix2pix_loss
-
-    def sample_t(self,
-                 shape: tuple[int, ...],
-                 max_t: int,
-                 device: torch.device) -> torch.Tensor:
-        p = torch.linspace(1, max_t, steps=max_t, device=device) ** 2
-        p = p / p.sum()
-        t = torch.multinomial(p, num_samples=shape[0], replacement=True)
-        return t
-
-    def forward(self,
-                I0: torch.Tensor,
-                It: torch.Tensor,
-                I1: torch.Tensor) -> torch.Tensor:
-        flow0tot = raft_flow(I0, It, 'animation')
-        flow1tot = raft_flow(I1, It, 'animation')
-        mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
-
-        tau = torch.stack([mid_idx, 1 - mid_idx], dim=1)
-
-        if self.current_epoch > 5:
-            t = torch.randint(low=1, high=self.model.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
-        else:
-            t = self.sample_t(shape=(It.shape[0],), max_t=self.model.timesteps, device=It.device)
-
-        predicted_It = self.model(I0, It, I1, tau=tau, t=t)
-        return predicted_It
-
-    def get_step_plt_images(self,
-                            It: torch.Tensor,
-                            predicted_It: torch.Tensor) -> plt.Figure:
-        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
-        ax[0].imshow(denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy())
-        ax[0].axis("off")
-        ax[0].set_title("Predicted")
-        ax[1].imshow(denorm(It, self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy())
-        ax[1].axis("off")
-        ax[1].set_title("Ground Truth")
-        plt.tight_layout()
-        #img_path = "step_image.png"
-        #fig.savefig(img_path, dpi=300, bbox_inches='tight')
-        plt.close(fig)
-        return fig
-
-    def training_step(self, batch: tuple[torch.Tensor, ...], _) -> torch.Tensor:
-        I0, It, I1 = batch
-        predicted_It = self(I0, It, I1)
-        loss = self.loss_fn(It, predicted_It)
-
-        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
-        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
-
-        self.ema.step_ema(self.ema_model, self.model)
-        with torch.inference_mode():
-            fig = self.get_step_plt_images(It, predicted_It)
-            self.logger.experiment.add_figure("Train Predictions", fig, self.global_step)
-            mets = self.train_metrics(It, predicted_It.clamp(-1, 1))
-            self.log_dict(mets, prog_bar=True, on_step=True, on_epoch=False)
-        return loss
-
-    @torch.no_grad()
-    def validation_step(self, batch: tuple[torch.Tensor, ...], _) -> None:
-        I0, It, I1 = batch
-        predicted_It = self(I0, It, I1)
-        loss = self.loss_fn(It, predicted_It)
-
-        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
-
-        mets = self.val_metrics(It, predicted_It.clamp(-1, 1))
-        self.log_dict(mets, prog_bar=True, on_step=False, on_epoch=True)
-
-    @torch.inference_mode()
-    def on_train_epoch_end(self) -> None:
-        torch.save(self.ema_model.state_dict(),
-                   os.path.join("_checkpoint", f"resshift_diff_{self.current_epoch}.pth"))
-
-        batch = next(iter(self.test_dataloader))
-        I0, It, I1 = batch
-        I0, It, I1 = I0.to(self.device), It.to(self.device), I1.to(self.device)
-
-        flow0tot = raft_flow(I0, It, 'animation')
-        flow1tot = raft_flow(I1, It, 'animation')
-        mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
-        tau = torch.stack([mid_idx, 1 - mid_idx], dim=1)
-
-        predicted_It = self.ema_model.reverse_process([I0, I1], tau)
-
-        I0 = denorm(I0, self.mean, self.sd)
-        I1 = denorm(I1, self.mean, self.sd)
-        It = denorm(It, self.mean, self.sd)
-        predicted_It = denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)
-
-        #save_triplet([I0, It, predicted_It, I1], f"./_output/target_{self.current_epoch}.png", nrow=1)
-        grid = make_grid_images([I0, It, predicted_It, I1], nrow=1)
-        self.logger.experiment.add_image("Predicted Images", grid, self.global_step)
-
-    def configure_optimizers(self) -> tuple[list[Optimizer], list[dict[str, Any]]]:
-        optimizer = [AdamW(
-            self.model.parameters(),
-            **self.confg["optim_confg"]['optimizer_confg']
-        )]
-
-        scheduler = [{
-            'scheduler': ReduceLROnPlateau(
-                optimizer[0],
-                **self.confg["optim_confg"]['scheduler_confg']
-            ),
-            'monitor': 'val_loss',
-            'interval': 'epoch',
-            'frequency': 1,
-            'strict': True,
-        }]
-
-        return optimizer, scheduler
-
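The deleted pipeline draws diffusion timesteps from a quadratic distribution for the first five epochs (sample_t) and uniformly afterwards. A self-contained restatement of that warm-up sampler, runnable on its own:

import torch

def sample_t(batch_size: int, max_t: int, device: torch.device) -> torch.Tensor:
    # Probability mass grows as t**2, so larger timestep indices are drawn
    # more often during warm-up (same logic as TrainPipline.sample_t above).
    p = torch.linspace(1, max_t, steps=max_t, device=device) ** 2
    p = p / p.sum()
    return torch.multinomial(p, num_samples=batch_size, replacement=True)

t = sample_t(batch_size=6, max_t=20, device=torch.device("cpu"))
print(t)  # indices in [0, 20), biased toward the high end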
utils/ema.py DELETED
@@ -1,32 +0,0 @@
-import torch
-import torch.nn as nn
-
-class EMA:
-    def __init__(self, beta: float):
-        super().__init__()
-        self.beta = beta
-        self.step = 0
-
-    def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
-        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
-            old_weight, up_weight = ema_params.data, current_params.data
-            ema_params.data = self.update_average(old_weight, up_weight)
-
-    def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
-        if old is None:
-            return new
-        return old * self.beta + (1 - self.beta) * new
-
-    def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
-        if self.step < step_start_ema:
-            self.reset_parameters(ema_model, model)
-            self.step += 1
-            return
-        self.update_model_average(ema_model, model)
-        self.step += 1
-
-    def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
-        model.load_state_dict(ema_model.state_dict())
-
-    def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
-        ema_model.load_state_dict(model.state_dict())
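A minimal usage sketch for the EMA helper above, mirroring how train_pipline.py wired it up; the nn.Linear stand-in model is hypothetical:

import copy
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the real network
ema_model = copy.deepcopy(model).eval().requires_grad_(False)
ema = EMA(beta=0.995)

for _ in range(3):                  # one call per training step in practice
    ema.step_ema(ema_model, model)  # copies weights until step 2000, then averages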
utils/inter_frame_idx.py DELETED
@@ -1,123 +0,0 @@
-from utils.utils import morph_open
-
-import torch
-from kornia.color import rgb_to_grayscale
-
-import cv2
-import numpy as np
-
-class FlowEstimation:
-    def __init__(self, flow_estimator: str = "farneback"):
-        assert flow_estimator in ["farneback", "dualtvl1"], "Flow estimator must be one of [farneback, dualtvl1]"
-
-        if flow_estimator == "farneback":
-            self.flow_estimator = self.OptFlow_Farneback
-        elif flow_estimator == "dualtvl1":
-            self.flow_estimator = self.OptFlow_DualTVL1
-        else:
-            raise NotImplementedError
-
-    def OptFlow_Farneback(self, I0: torch.Tensor, I1: torch.Tensor) -> torch.Tensor:
-        device = I0.device
-
-        I0 = I0.cpu().clamp(0, 1) * 255
-        I1 = I1.cpu().clamp(0, 1) * 255
-
-        batch_size = I0.shape[0]
-        for i in range(batch_size):
-            I0_np = I0[i].permute(1, 2, 0).numpy().astype(np.uint8)
-            I1_np = I1[i].permute(1, 2, 0).numpy().astype(np.uint8)
-
-            I0_gray = cv2.cvtColor(I0_np, cv2.COLOR_BGR2GRAY)
-            I1_gray = cv2.cvtColor(I1_np, cv2.COLOR_BGR2GRAY)
-
-            flow = cv2.calcOpticalFlowFarneback(I0_gray, I1_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
-            flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).float()
-            if i == 0:
-                flows = flow
-            else:
-                flows = torch.cat((flows, flow), dim=0)
-
-        return flows.to(device)
-
-    def OptFlow_DualTVL1(
-        self,
-        I0: torch.Tensor,
-        I1: torch.Tensor,
-        tau: float = 0.25,
-        lambda_: float = 0.15,
-        theta: float = 0.3,
-        scales_number: int = 5,
-        warps: int = 5,
-        epsilon: float = 0.01,
-        inner_iterations: int = 30,
-        outer_iterations: int = 10,
-        scale_step: float = 0.8,
-        gamma: float = 0.0
-    ) -> torch.Tensor:
-        optical_flow = cv2.optflow.createOptFlow_DualTVL1()
-        optical_flow.setTau(tau)
-        optical_flow.setLambda(lambda_)
-        optical_flow.setTheta(theta)
-        optical_flow.setScalesNumber(scales_number)
-        optical_flow.setWarpingsNumber(warps)
-        optical_flow.setEpsilon(epsilon)
-        optical_flow.setInnerIterations(inner_iterations)
-        optical_flow.setOuterIterations(outer_iterations)
-        optical_flow.setScaleStep(scale_step)
-        optical_flow.setGamma(gamma)
-
-        device = I0.device
-
-        I0 = I0.cpu().clamp(0, 1) * 255
-        I1 = I1.cpu().clamp(0, 1) * 255
-
-        batch_size = I0.shape[0]
-        for i in range(batch_size):
-            I0_np = I0[i].permute(1, 2, 0).numpy().astype(np.uint8)
-            I1_np = I1[i].permute(1, 2, 0).numpy().astype(np.uint8)
-
-            I0_gray = cv2.cvtColor(I0_np, cv2.COLOR_BGR2GRAY)
-            I1_gray = cv2.cvtColor(I1_np, cv2.COLOR_BGR2GRAY)
-
-            flow = optical_flow.calc(I0_gray, I1_gray, None)
-            flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).float()
-            if i == 0:
-                flows = flow
-            else:
-                flows = torch.cat((flows, flow), dim=0)
-
-        return flows.to(device)
-
-    def __call__(self, I1: torch.Tensor, I0: torch.Tensor) -> torch.Tensor:
-        return self.flow_estimator(I1, I0)
-
-def get_inter_frame_temp_index(
-    I0: torch.Tensor,
-    It: torch.Tensor,
-    I1: torch.Tensor,
-    flow0tot: torch.Tensor,
-    flow1tot: torch.Tensor,
-    k: int = 5,
-    threshold: float = 2e-2
-) -> torch.Tensor:
-
-    I0_gray = rgb_to_grayscale(I0)
-    It_gray = rgb_to_grayscale(It)
-    I1_gray = rgb_to_grayscale(I1)
-
-    mask0tot = morph_open(It_gray - I0_gray, k=k)
-    mask1tot = morph_open(I1_gray - It_gray, k=k)
-
-    mask0tot = (abs(mask0tot) > threshold).to(torch.uint8)
-    mask1tot = (abs(mask1tot) > threshold).to(torch.uint8)
-
-    flow_mag0tot = torch.sqrt(flow0tot[:, 0, :, :]**2 + flow0tot[:, 1, :, :]**2).unsqueeze(1)
-    flow_mag1tot = torch.sqrt(flow1tot[:, 0, :, :]**2 + flow1tot[:, 1, :, :]**2).unsqueeze(1)
-
-    norm0tot = (flow_mag0tot * mask0tot).squeeze(1)
-    norm1tot = (flow_mag1tot * mask1tot).squeeze(1)
-    d0tot = torch.sum(norm0tot, dim=(1, 2))
-    d1tot = torch.sum(norm1tot, dim=(1, 2))
-
-    return d0tot / (d0tot + d1tot + 1e-12)
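get_inter_frame_temp_index returns the share of masked flow magnitude attributed to the I0→It motion, d0 / (d0 + d1). A toy check of just that final formula:

import torch

# If the motion mass from I0 to It is three times the mass from It to I1,
# the intermediate frame is placed at tau = 0.75 between the endpoints.
d0tot = torch.tensor([3.0])
d1tot = torch.tensor([1.0])
tau = d0tot / (d0tot + d1tot + 1e-12)
print(tau)  # tensor([0.7500])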
utils/raft.py DELETED
@@ -1,20 +0,0 @@
-import torch
-from torchvision.models.optical_flow import raft_large
-from modules.flow_models.raft.rfr_new import RAFT
-
-def raft_flow(
-    I0: torch.Tensor,
-    I1: torch.Tensor,
-    data_domain: str = "animation",
-    device: str = 'cuda'
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if I0.dtype != torch.float32 or I1.dtype != torch.float32:
-        I0 = I0.to(torch.float32)
-        I1 = I1.to(torch.float32)
-    if data_domain == "animation":
-        raft = RAFT().requires_grad_(False).eval().to(device)
-    elif data_domain == "photorealism":
-        raft = raft_large().requires_grad_(False).eval().to(device)
-    else:
-        raise ValueError("data_domain must be either 'animation' or 'photorealism'")
-    return raft(I0, I1) if data_domain == "animation" else raft(I0, I1)[-1]
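A hypothetical call sketch for raft_flow (the batch shape and CUDA device are assumptions). torchvision's raft_large returns a list of iterative flow refinements, which is why the photorealism branch indexes [-1]:

import torch

I0 = torch.rand(2, 3, 256, 448, device="cuda")  # batch of frames in [0, 1]
I1 = torch.rand(2, 3, 256, 448, device="cuda")
flow = raft_flow(I0, I1, data_domain="photorealism")  # flow field of shape (2, 2, 256, 448)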
utils/uncertainty.py DELETED
@@ -1,49 +0,0 @@
-import torch
-import itertools
-from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
-from utils.utils import denorm
-
-def compute_lpips_variability(samples: torch.Tensor,
-                              net: str = 'alex',
-                              device: str = 'cuda'
-                              ) -> float:
-    loss_fn = LPIPS(net_type=net).to(device)
-    loss_fn.eval()
-
-    if samples.min() >= 0.0:
-        samples = samples * 2 - 1  # Convert [0, 1] → [-1, 1]
-
-    N = samples.size(0)
-    scores = []
-    for i, j in itertools.combinations(range(N), 2):
-        x = samples[i:i+1].to(device)
-        y = samples[j:j+1].to(device)
-        dist = loss_fn(denorm(x.clamp(-1, 1)), denorm(y.clamp(-1, 1)))
-        scores.append(dist.item())
-
-    return sum(scores) / len(scores)
-
-def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
-    N, C, H, W = samples.shape
-    samples_flat = samples.view(N, C, -1)  # (N, C, H*W)
-
-    corrs = []
-    for i, j in itertools.combinations(range(N), 2):
-        x = samples_flat[i]  # (C, HW)
-        y = samples_flat[j]  # (C, HW)
-        mean_x = x.mean(dim=1, keepdim=True)
-        mean_y = y.mean(dim=1, keepdim=True)
-        x_centered = x - mean_x
-        y_centered = y - mean_y
-        numerator = (x_centered * y_centered).sum(dim=1)
-        denominator = (x_centered.norm(dim=1) * y_centered.norm(dim=1)) + 1e-8
-        corr = numerator / denominator  # (C,)
-        corrs.append(corr.mean().item())
-    return sum(corrs) / len(corrs)
-
-def compute_dynamic_range(samples: torch.Tensor) -> float:
-    max_vals, _ = samples.max(dim=0)  # (C, H, W)
-    min_vals, _ = samples.min(dim=0)  # (C, H, W)
-
-    dynamic_range = max_vals - min_vals  # (C, H, W)
-    return dynamic_range.mean().item()
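A hypothetical usage sketch for the three diversity metrics, assuming a stack of four samples of the same frame in [-1, 1] and that the module's dependencies (torchmetrics, utils.utils.denorm) are available:

import torch

samples = torch.rand(4, 3, 64, 64) * 2 - 1  # four samples of one frame, in [-1, 1]
print(compute_lpips_variability(samples, net='alex', device='cpu'))  # mean pairwise LPIPS
print(compute_pixelwise_correlation(samples))                        # mean pairwise correlation
print(compute_dynamic_range(samples))                                # mean per-pixel max-min spread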