VicFonch committed on
Commit c604c51 · unverified · 1 Parent(s): 0b58ffc

deleting unnecessary files/scripts

config/confg.yaml DELETED
@@ -1,64 +0,0 @@
- data_confg:
-   train_batch_size: 6
-   val_batch_size: 6
-   test_batch_size: 6
-   flow_method: raft
-   data_domain: animation
-   datamodule_confg:
-     mean: [0.5, 0.5, 0.5]
-     sd: [0.5, 0.5, 0.5]
-     size: [256, 448]
-     amount_augmentations: 1
-     horizontal_flip: 0.5
-     time_flip: True
-     rotation: 0
-     brightness: 0.2
-     contrast: 0.2
-     saturation: 0.2
-     hue: 0.1
-
- trainer_confg:
-   accumulate_grad_batches: 5
-   gradient_clip_val: 1.0
-   max_epochs: 500
-   num_nodes: 1
-   devices: 2
-   accelerator: gpu
-   strategy: ddp_find_unused_parameters_true
-
- optim_confg:
-   optimizer_confg: # AdamW
-     lr: 1.0e-4
-     betas: [0.9, 0.999]
-     eps: 1.0e-8
-   scheduler_confg: # ReduceLROnPlateau
-     mode: min
-     factor: 0.5
-     patience: 3
-     verbose: True
-
- pretrained_model_path: null # Fine-tune model path
-
- model_confg:
-   kappa: 2.0
-   timesteps: 20
-   p: 0.3
-   etas_end: 0.99
-   min_noise_level: 0.04
-   flow_model: raft
-   flow_kwargs:
-     pretrained_path: null #_pretrain_models/anime_interp_full.ckpt
-   warping_kwargs:
-     in_channels: 3
-     channels: [128, 256, 384, 512]
-   synthesis_kwargs:
-     in_channels: 3
-     channels: [128, 256, 384, 512]
-     temb_channels: 512
-     heads: 1
-     window_size: 8
-     window_attn: True
-     grid_attn: True
-     expansion_rate: 1.5
-     num_conv_blocks: 1
-     dropout: 0.0
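For reference, a config like this would typically be consumed with PyYAML into a nested dict; a minimal sketch (the loading code below is illustrative and not part of this repository):

```python
# Illustrative only: load the (now deleted) config into a nested dict.
import yaml

with open("config/confg.yaml") as f:
    confg = yaml.safe_load(f)

# Top-level sections map to what the training pipeline below expects, e.g.:
print(confg["model_confg"]["timesteps"])              # 20
print(confg["optim_confg"]["optimizer_confg"]["lr"])  # 0.0001 (PyYAML parses 1.0e-4 as float)
```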
model/train_pipline.py DELETED
@@ -1,177 +0,0 @@
- import os
- import copy
- import matplotlib.pyplot as plt
- from typing import Any
-
- import torch
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from torch.optim import AdamW, Optimizer
- from torch.utils.data import DataLoader
- from lightning import LightningModule
-
- from torchmetrics import MetricCollection
- from torchmetrics.image import PeakSignalNoiseRatio as PSNR
- from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
- from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
-
- from model.model import MultiInputResShift
-
- from utils.utils import denorm, make_grid_images  #, save_triplet
- from utils.ema import EMA
- from utils.inter_frame_idx import get_inter_frame_temp_index
- from utils.raft import raft_flow
-
-
- class TrainPipline(LightningModule):
-     def __init__(self,
-                  confg: dict,
-                  test_dataloader: DataLoader):
-         super(TrainPipline, self).__init__()
-
-         self.test_dataloader = test_dataloader
-
-         self.confg = confg
-
-         self.mean, self.sd = confg["data_confg"]["mean"], confg["data_confg"]["sd"]
-
-         self.model = MultiInputResShift(**confg["model_confg"])
-         self.model.flow_model.requires_grad_(False).eval()
-
-         self.ema = EMA(beta=0.995)
-         self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
-
-         self.charbonnier_loss = lambda x, y: torch.mean(torch.sqrt((x - y)**2 + 1e-6))
-         self.lpips_loss = LPIPS(net_type='vgg')
-
-         self.train_metrics = MetricCollection({
-             "train_lpips": LPIPS(net_type='alex'),
-             "train_psnr": PSNR(),
-             "train_ssim": SSIM()
-         })
-         self.val_metrics = MetricCollection({
-             "val_lpips": LPIPS(net_type='alex'),
-             "val_psnr": PSNR(),
-             "val_ssim": SSIM()
-         })
-
-     def loss_fn(self,
-                 x: torch.Tensor,
-                 predicted_x: torch.Tensor) -> torch.Tensor:
-         percep_loss = 0.2 * self.lpips_loss(x, predicted_x.clamp(-1, 1))
-         pix2pix_loss = self.charbonnier_loss(x, predicted_x)
-         return percep_loss + pix2pix_loss
-
-     def sample_t(self,
-                  shape: tuple[int, ...],
-                  max_t: int,
-                  device: torch.device) -> torch.Tensor:
-         p = torch.linspace(1, max_t, steps=max_t, device=device) ** 2
-         p = p / p.sum()
-         t = torch.multinomial(p, num_samples=shape[0], replacement=True)
-         return t
-
-     def forward(self,
-                 I0: torch.Tensor,
-                 It: torch.Tensor,
-                 I1: torch.Tensor) -> torch.Tensor:
-         flow0tot = raft_flow(I0, It, 'animation')
-         flow1tot = raft_flow(I1, It, 'animation')
-         mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
-
-         tau = torch.stack([mid_idx, 1 - mid_idx], dim=1)
-
-         if self.current_epoch > 5:
-             t = torch.randint(low=1, high=self.model.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
-         else:
-             t = self.sample_t(shape=(It.shape[0],), max_t=self.model.timesteps, device=It.device)
-
-         predicted_It = self.model(I0, It, I1, tau=tau, t=t)
-         return predicted_It
-
-     def get_step_plt_images(self,
-                             It: torch.Tensor,
-                             predicted_It: torch.Tensor) -> plt.Figure:
-         fig, ax = plt.subplots(1, 2, figsize=(20, 10))
-         ax[0].imshow(denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy())
-         ax[0].axis("off")
-         ax[0].set_title("Predicted")
-         ax[1].imshow(denorm(It, self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy())
-         ax[1].axis("off")
-         ax[1].set_title("Ground Truth")
-         plt.tight_layout()
-         #img_path = "step_image.png"
-         #fig.savefig(img_path, dpi=300, bbox_inches='tight')
-         plt.close(fig)
-         return fig
-
-     def training_step(self, batch: tuple[torch.Tensor, ...], _) -> torch.Tensor:
-         I0, It, I1 = batch
-         predicted_It = self(I0, It, I1)
-         loss = self.loss_fn(It, predicted_It)
-
-         self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
-         self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
-
-         self.ema.step_ema(self.ema_model, self.model)
-         with torch.inference_mode():
-             fig = self.get_step_plt_images(It, predicted_It)
-             self.logger.experiment.add_figure("Train Predictions", fig, self.global_step)
-             mets = self.train_metrics(It, predicted_It.clamp(-1, 1))
-             self.log_dict(mets, prog_bar=True, on_step=True, on_epoch=False)
-         return loss
-
-     @torch.no_grad()
-     def validation_step(self, batch: tuple[torch.Tensor, ...], _) -> None:
-         I0, It, I1 = batch
-         predicted_It = self(I0, It, I1)
-         loss = self.loss_fn(It, predicted_It)
-
-         self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
-
-         mets = self.val_metrics(It, predicted_It.clamp(-1, 1))
-         self.log_dict(mets, prog_bar=True, on_step=False, on_epoch=True)
-
-     @torch.inference_mode()
-     def on_train_epoch_end(self) -> None:
-         torch.save(self.ema_model.state_dict(),
-                    os.path.join("_checkpoint", f"resshift_diff_{self.current_epoch}.pth"))
-
-         batch = next(iter(self.test_dataloader))
-         I0, It, I1 = batch
-         I0, It, I1 = I0.to(self.device), It.to(self.device), I1.to(self.device)
-
-         flow0tot = raft_flow(I0, It, 'animation')
-         flow1tot = raft_flow(I1, It, 'animation')
-         mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
-         tau = torch.stack([mid_idx, 1 - mid_idx], dim=1)
-
-         predicted_It = self.ema_model.reverse_process([I0, I1], tau)
-
-         I0 = denorm(I0, self.mean, self.sd)
-         I1 = denorm(I1, self.mean, self.sd)
-         It = denorm(It, self.mean, self.sd)
-         predicted_It = denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)
-
-         #save_triplet([I0, It, predicted_It, I1], f"./_output/target_{self.current_epoch}.png", nrow=1)
-         grid = make_grid_images([I0, It, predicted_It, I1], nrow=1)
-         self.logger.experiment.add_image("Predicted Images", grid, self.global_step)
-
-     def configure_optimizers(self) -> tuple[list[Optimizer], list[dict[str, Any]]]:
-         optimizer = [AdamW(
-             self.model.parameters(),
-             **self.confg["optim_confg"]['optimizer_confg']
-         )]
-
-         scheduler = [{
-             'scheduler': ReduceLROnPlateau(
-                 optimizer[0],
-                 **self.confg["optim_confg"]['scheduler_confg']
-             ),
-             'monitor': 'val_loss',
-             'interval': 'epoch',
-             'frequency': 1,
-             'strict': True,
-         }]
-
-         return optimizer, scheduler
-
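For context, this LightningModule would have been driven roughly as follows. This is a sketch assuming the config above; `train_loader`, `val_loader`, and `test_loader` are placeholders for the project's dataloaders and are not defined in this commit:

```python
# Hypothetical entry point -- a sketch, not code from this repository.
import yaml
from lightning import Trainer

from model.train_pipline import TrainPipline

with open("config/confg.yaml") as f:
    confg = yaml.safe_load(f)

pipeline = TrainPipline(confg, test_dataloader=test_loader)  # test_loader: placeholder
trainer = Trainer(**confg["trainer_confg"])                  # ddp, 2 GPUs, grad clipping, etc.
trainer.fit(pipeline, train_loader, val_loader)              # placeholders as above
```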
utils/ema.py DELETED
@@ -1,32 +0,0 @@
- import torch
- import torch.nn as nn
-
- class EMA:
-     def __init__(self, beta: float):
-         super().__init__()
-         self.beta = beta
-         self.step = 0
-
-     def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
-         for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
-             old_weight, up_weight = ema_params.data, current_params.data
-             ema_params.data = self.update_average(old_weight, up_weight)
-
-     def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
-         if old is None:
-             return new
-         return old * self.beta + (1 - self.beta) * new
-
-     def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
-         if self.step < step_start_ema:
-             self.reset_parameters(ema_model, model)
-             self.step += 1
-             return
-         self.update_model_average(ema_model, model)
-         self.step += 1
-
-     def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
-         model.load_state_dict(ema_model.state_dict())
-
-     def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
-         ema_model.load_state_dict(model.state_dict())
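The EMA helper follows the common warm-up pattern: for the first `step_start_ema` steps the shadow model is simply reset to the online weights, after which each parameter is tracked as `ema = beta * ema + (1 - beta) * current`. A self-contained usage sketch (the `nn.Linear` stand-in is illustrative only):

```python
import copy
import torch.nn as nn

from utils.ema import EMA

model = nn.Linear(8, 8)  # stand-in for the real network
ema_model = copy.deepcopy(model).eval().requires_grad_(False)
ema = EMA(beta=0.995)

for step in range(3000):
    # ... optimizer step on `model` would go here ...
    ema.step_ema(ema_model, model)  # copies weights until step 2000, then averages
```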
utils/inter_frame_idx.py DELETED
@@ -1,123 +0,0 @@
- from utils.utils import morph_open
-
- import torch
- from kornia.color import rgb_to_grayscale
-
- import cv2
- import numpy as np
-
- class FlowEstimation:
-     def __init__(self, flow_estimator: str = "farneback"):
-         assert flow_estimator in ["farneback", "dualtvl1"], "Flow estimator must be one of [farneback, dualtvl1]"
-
-         if flow_estimator == "farneback":
-             self.flow_estimator = self.OptFlow_Farneback
-         elif flow_estimator == "dualtvl1":
-             self.flow_estimator = self.OptFlow_DualTVL1
-         else:
-             raise NotImplementedError
-
-     def OptFlow_Farneback(self, I0: torch.Tensor, I1: torch.Tensor) -> torch.Tensor:
-         device = I0.device
-
-         I0 = I0.cpu().clamp(0, 1) * 255
-         I1 = I1.cpu().clamp(0, 1) * 255
-
-         batch_size = I0.shape[0]
-         for i in range(batch_size):
-             I0_np = I0[i].permute(1, 2, 0).numpy().astype(np.uint8)
-             I1_np = I1[i].permute(1, 2, 0).numpy().astype(np.uint8)
-
-             I0_gray = cv2.cvtColor(I0_np, cv2.COLOR_BGR2GRAY)
-             I1_gray = cv2.cvtColor(I1_np, cv2.COLOR_BGR2GRAY)
-
-             flow = cv2.calcOpticalFlowFarneback(I0_gray, I1_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
-             flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).float()
-             if i == 0:
-                 flows = flow
-             else:
-                 flows = torch.cat((flows, flow), dim=0)
-
-         return flows.to(device)
-
-     def OptFlow_DualTVL1(
-         self,
-         I0: torch.Tensor,
-         I1: torch.Tensor,
-         tau: float = 0.25,
-         lambda_: float = 0.15,
-         theta: float = 0.3,
-         scales_number: int = 5,
-         warps: int = 5,
-         epsilon: float = 0.01,
-         inner_iterations: int = 30,
-         outer_iterations: int = 10,
-         scale_step: float = 0.8,
-         gamma: float = 0.0
-     ) -> torch.Tensor:
-         optical_flow = cv2.optflow.createOptFlow_DualTVL1()
-         optical_flow.setTau(tau)
-         optical_flow.setLambda(lambda_)
-         optical_flow.setTheta(theta)
-         optical_flow.setScalesNumber(scales_number)
-         optical_flow.setWarpingsNumber(warps)
-         optical_flow.setEpsilon(epsilon)
-         optical_flow.setInnerIterations(inner_iterations)
-         optical_flow.setOuterIterations(outer_iterations)
-         optical_flow.setScaleStep(scale_step)
-         optical_flow.setGamma(gamma)
-
-         device = I0.device
-
-         I0 = I0.cpu().clamp(0, 1) * 255
-         I1 = I1.cpu().clamp(0, 1) * 255
-
-         batch_size = I0.shape[0]
-         for i in range(batch_size):
-             I0_np = I0[i].permute(1, 2, 0).numpy().astype(np.uint8)
-             I1_np = I1[i].permute(1, 2, 0).numpy().astype(np.uint8)
-
-             I0_gray = cv2.cvtColor(I0_np, cv2.COLOR_BGR2GRAY)
-             I1_gray = cv2.cvtColor(I1_np, cv2.COLOR_BGR2GRAY)
-
-             flow = optical_flow.calc(I0_gray, I1_gray, None)
-             flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).float()
-             if i == 0:
-                 flows = flow
-             else:
-                 flows = torch.cat((flows, flow), dim=0)
-
-         return flows.to(device)
-
-     def __call__(self, I1: torch.Tensor, I0: torch.Tensor) -> torch.Tensor:
-         return self.flow_estimator(I1, I0)
-
- def get_inter_frame_temp_index(
-     I0: torch.Tensor,
-     It: torch.Tensor,
-     I1: torch.Tensor,
-     flow0tot: torch.Tensor,
-     flow1tot: torch.Tensor,
-     k: int = 5,
-     threshold: float = 2e-2
- ) -> torch.Tensor:
-
-     I0_gray = rgb_to_grayscale(I0)
-     It_gray = rgb_to_grayscale(It)
-     I1_gray = rgb_to_grayscale(I1)
-
-     mask0tot = morph_open(It_gray - I0_gray, k=k)
-     mask1tot = morph_open(I1_gray - It_gray, k=k)
-
-     mask0tot = (abs(mask0tot) > threshold).to(torch.uint8)
-     mask1tot = (abs(mask1tot) > threshold).to(torch.uint8)
-
-     flow_mag0tot = torch.sqrt(flow0tot[:, 0, :, :]**2 + flow0tot[:, 1, :, :]**2).unsqueeze(1)
-     flow_mag1tot = torch.sqrt(flow1tot[:, 0, :, :]**2 + flow1tot[:, 1, :, :]**2).unsqueeze(1)
-
-     norm0tot = (flow_mag0tot * mask0tot).squeeze(1)
-     norm1tot = (flow_mag1tot * mask1tot).squeeze(1)
-     d0tot = torch.sum(norm0tot, dim=(1, 2))
-     d1tot = torch.sum(norm1tot, dim=(1, 2))
-
-     return d0tot / (d0tot + d1tot + 1e-12)
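`get_inter_frame_temp_index` estimates where `It` sits temporally between `I0` and `I1`: it masks the regions that actually changed (thresholded morphological opening of grayscale differences), sums the masked flow magnitudes toward each endpoint, and returns `d0 / (d0 + d1)` per batch element. A shape-level sketch with random tensors (real inputs would come from the dataloader and RAFT):

```python
import torch

from utils.inter_frame_idx import get_inter_frame_temp_index

B, H, W = 2, 64, 64
I0, It, I1 = (torch.rand(B, 3, H, W) for _ in range(3))
flow0tot = torch.randn(B, 2, H, W)  # flow I0 -> It
flow1tot = torch.randn(B, 2, H, W)  # flow I1 -> It

tau = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot)
print(tau.shape)  # torch.Size([2]); each value lies in [0, 1]
```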
utils/raft.py DELETED
@@ -1,20 +0,0 @@
- import torch
- from torchvision.models.optical_flow import raft_large
- from modules.flow_models.raft.rfr_new import RAFT
-
- def raft_flow(
-     I0: torch.Tensor,
-     I1: torch.Tensor,
-     data_domain: str = "animation",
-     device: str = 'cuda'
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     if I0.dtype != torch.float32 or I1.dtype != torch.float32:
-         I0 = I0.to(torch.float32)
-         I1 = I1.to(torch.float32)
-     if data_domain == "animation":
-         raft = RAFT().requires_grad_(False).eval().to(device)
-     elif data_domain == "photorealism":
-         raft = raft_large().requires_grad_(False).eval().to(device)
-     else:
-         raise ValueError("data_domain must be either 'animation' or 'photorealism'")
-     return raft(I0, I1) if data_domain == "animation" else raft(I0, I1)[-1]
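`raft_flow` builds a frozen flow network on every call: the repo's animation RAFT (`rfr_new.RAFT`) or torchvision's `raft_large` for photorealistic footage, keeping only the final refinement iteration in the latter case. A hedged usage sketch (shapes are illustrative):

```python
import torch

from utils.raft import raft_flow

I0 = torch.rand(1, 3, 256, 448, device="cuda")
I1 = torch.rand(1, 3, 256, 448, device="cuda")

flow = raft_flow(I0, I1, data_domain="photorealism")  # expected shape (1, 2, 256, 448)
```

Note that the network is re-instantiated on every call; code that estimates flow repeatedly (as the training loop above does) may want to construct it once and reuse it.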
utils/uncertainty.py DELETED
@@ -1,49 +0,0 @@
- import torch
- import itertools
- from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
- from utils.utils import denorm
-
- def compute_lpips_variability(samples: torch.Tensor,
-                               net: str = 'alex',
-                               device: str = 'cuda'
-                               ) -> float:
-     loss_fn = LPIPS(net_type=net).to(device)
-     loss_fn.eval()
-
-     if samples.min() >= 0.0:
-         samples = samples * 2 - 1  # Convert [0, 1] → [-1, 1]
-
-     N = samples.size(0)
-     scores = []
-     for i, j in itertools.combinations(range(N), 2):
-         x = samples[i:i+1].to(device)
-         y = samples[j:j+1].to(device)
-         dist = loss_fn(denorm(x.clamp(-1, 1)), denorm(y.clamp(-1, 1)))
-         scores.append(dist.item())
-
-     return sum(scores) / len(scores)
-
- def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
-     N, C, H, W = samples.shape
-     samples_flat = samples.view(N, C, -1)  # (N, C, H*W)
-
-     corrs = []
-     for i, j in itertools.combinations(range(N), 2):
-         x = samples_flat[i]  # (C, HW)
-         y = samples_flat[j]  # (C, HW)
-         mean_x = x.mean(dim=1, keepdim=True)
-         mean_y = y.mean(dim=1, keepdim=True)
-         x_centered = x - mean_x
-         y_centered = y - mean_y
-         numerator = (x_centered * y_centered).sum(dim=1)
-         denominator = (x_centered.norm(dim=1) * y_centered.norm(dim=1)) + 1e-8
-         corr = numerator / denominator  # (C,)
-         corrs.append(corr.mean().item())
-     return sum(corrs) / len(corrs)
-
- def compute_dynamic_range(samples: torch.Tensor) -> float:
-     max_vals, _ = samples.max(dim=0)  # (C, H, W)
-     min_vals, _ = samples.min(dim=0)  # (C, H, W)
-
-     dynamic_range = max_vals - min_vals  # (C, H, W)
-     return dynamic_range.mean().item()
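All three metrics reduce a stack of `N` stochastic samples of the same frame to one scalar: mean pairwise LPIPS (higher means more diverse), mean pairwise pixel correlation (lower means more diverse), and mean per-pixel max-min spread. A usage sketch with random data:

```python
import torch

from utils.uncertainty import (compute_dynamic_range,
                               compute_lpips_variability,
                               compute_pixelwise_correlation)

samples = torch.rand(5, 3, 64, 64)  # N=5 samples in [0, 1]; rescaled internally

print(compute_lpips_variability(samples, device="cpu"))
print(compute_pixelwise_correlation(samples))
print(compute_dynamic_range(samples))
```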