|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from argparse import Namespace
|
|
import numpy as np
|
|
|
|
from .update import BasicUpdateBlock, SmallUpdateBlock
|
|
from .extractor import BasicEncoder, SmallEncoder
|
|
from .corr import CorrBlock
|
|
from .utils import bilinear_sampler, coords_grid, upflow8
|
|
|
|
# Default compute device string for this module: CUDA when available, else CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class _NullAutocast:
    """No-op stand-in for ``torch.amp.autocast``.

    Accepts the call shape used in this file,
    ``autocast(device_type, enabled=...)``, as well as the keyword-only form
    ``autocast(enabled=...)``. Every argument is ignored and entering or
    leaving the context does nothing.
    """

    def __init__(self, device_type=None, enabled=True, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # A falsy return value lets exceptions raised in the block propagate.
        return False


try:
    autocast = torch.amp.autocast
except AttributeError:
    # This torch build does not expose ``torch.amp.autocast``; fall back to
    # the no-op context manager so the ``with autocast(...)`` sites still run.
    autocast = _NullAutocast
|
|
|
|
def backwarp(img, flow):
    """Sample ``img`` at pixel positions displaced by ``flow``.

    Output pixel ``(row, col)`` is read from ``img`` at
    ``(row + flow[:, 1], col + flow[:, 0])`` via ``grid_sample`` with
    ``align_corners=True`` and its default interpolation and padding.

    Args:
        img: 4-D tensor ``(N, C, H, W)`` to sample from.
        flow: tensor indexed as ``flow[:, 0]`` (horizontal displacement) and
            ``flow[:, 1]`` (vertical displacement), both in pixels. Each
            slice must be expandable to ``(N, H, W)``.

    Returns:
        The tensor produced by ``grid_sample(img, grid)``.
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    # Build the base pixel grid directly on the flow's device so the function
    # works wherever its inputs live, rather than on a module-global device.
    cols = torch.arange(W, device=flow.device).view(1, 1, W)
    rows = torch.arange(H, device=flow.device).view(1, H, 1)
    x = cols.expand_as(u).float() + u
    y = rows.expand_as(v).float() + v

    # Normalise to [-1, 1] as grid_sample expects. The max(..., 1) guard
    # avoids 0/0 for single-row or single-column inputs.
    x = 2 * (x / max(W - 1, 1) - 0.5)
    y = 2 * (y / max(H - 1, 1) - 0.5)

    grid = torch.stack((x, y), dim=3)
    return F.grid_sample(img, grid, align_corners=True)
|
|
class ErrorAttention(nn.Module):
    """A three-layer network for predicting mask.

    The input is passed through a 5x5 convolution and a 3x3 convolution; the
    result is concatenated with the raw input along the channel axis and fed
    to a final 3x3 convolution. All convolutions are padded so the spatial
    size is unchanged.

    Args:
        input: number of channels of the input tensor.
        output: number of channels of the predicted map.
    """

    def __init__(self, input, output):
        super(ErrorAttention, self).__init__()
        hidden = 32
        self.conv1 = nn.Conv2d(input, hidden, 5, padding=2)
        self.conv2 = nn.Conv2d(hidden, hidden, 3, padding=1)
        # conv3 sees conv2's features concatenated with the raw input, so its
        # width depends on ``input`` (38 for the 6-channel case).
        self.conv3 = nn.Conv2d(hidden + input, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        """Return the ``output``-channel map predicted from ``x1``."""
        x = self.prelu1(self.conv1(x1))
        # Skip connection: re-attach the raw input before the last layer.
        x = self.prelu2(torch.cat([self.conv2(x), x1], dim=1))
        x = self.conv3(x)
        return x
|
|
|
|
class RFR(nn.Module):
    """Iterative optical-flow estimator with an optional initial flow.

    The network combines a ``BasicEncoder`` (used both for the matching
    features and for the context/hidden state), a ``CorrBlock`` correlation
    lookup and a ``BasicUpdateBlock`` that refines the flow for a fixed
    number of iterations at 1/8 resolution.

    Args:
        args: configuration namespace. ``corr_levels``, ``corr_radius`` and
            ``dropout`` are overwritten on it here. ``mixed_precision`` is
            read during the forward pass. The optional flags
            ``not_use_rfr_mask`` and ``requires_sq_flow`` are read when
            present.
    """

    def __init__(self, args):
        super(RFR, self).__init__()
        self.attention2 = ErrorAttention(6, 1)
        self.hidden_dim = hdim = 128
        self.context_dim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0
        self.args = args

        # One encoder with 256 output channels; forward_pred splits its
        # output into hidden_dim + context_dim channels.
        self.fnet = BasicEncoder(output_dim=256, norm_fn='none', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, _, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
        coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # 9 softmax-normalised weights per fine pixel of every 8x8 cell.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # 3x3 neighbourhood of the coarse flow, scaled by 8 for the finer grid.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """Estimate flow from ``image1`` to ``image2`` at the input resolution.

        The images are resized to the nearest lower multiple of 8 in each
        dimension, flow is predicted by ``forward_pred``, and the result is
        resized and rescaled back to the original ``(H, W)``.

        Args:
            image1, image2: 4-D image tensors ``(N, C, H, W)``.
            iters: number of refinement iterations.
            flow_init: optional initial flow. It is resized to the 1/8 grid
                and rescaled by the size ratio before being used.
            upsample, test_mode: forwarded to ``forward_pred``.

        Returns:
            * ``args.requires_sq_flow`` truthy and module in training mode:
              the list of per-iteration flows, each resized to ``(H, W)``.
            * ``args.requires_sq_flow`` truthy and module in eval mode:
              ``([last_flow], f12_init)``.
            * otherwise: ``(f12, f12_init, error21)``. ``error21`` is the
              per-pixel L1 warping error of the initial flow on the 1/8
              grid, or zeros when it was not computed.
        """
        H, W = image1.size()[2:4]
        H8 = H // 8 * 8
        W8 = W // 8 * 8

        # Default when no warping error is computed (no initial flow, or the
        # mask branch is disabled). Allocated on the input's device so the
        # model also works when it does not live on the module-global device.
        error21 = torch.zeros(image1.size()[0], 1, H // 8, W // 8, device=image1.device)

        if flow_init is None:
            flow_init_resize = None
        else:
            flow_init_resize = F.interpolate(flow_init, size=(H8 // 8, W8 // 8), mode='nearest')

            # Rescale displacements to the coarse grid's pixel units.
            flow_init_resize[:, :1] = flow_init_resize[:, :1].clone() * (W8 // 8 * 1.0) / flow_init.size()[3]
            flow_init_resize[:, 1:] = flow_init_resize[:, 1:].clone() * (H8 // 8 * 1.0) / flow_init.size()[2]

            if not getattr(self.args, 'not_use_rfr_mask', False):
                im18 = F.interpolate(image1, size=(H8 // 8, W8 // 8), mode='bilinear')
                im28 = F.interpolate(image2, size=(H8 // 8, W8 // 8), mode='bilinear')

                warp21 = backwarp(im28, flow_init_resize)
                error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)

                # NOTE(review): this attention-weighted initial flow is
                # computed but never consumed; forward_pred below receives
                # the unmasked flow_init_resize. Left as-is so outputs are
                # unchanged; confirm whether it was meant to be passed on.
                f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize

        image1 = F.interpolate(image1, size=(H8, W8), mode='bilinear')
        image2 = F.interpolate(image2, size=(H8, W8), mode='bilinear')

        f12s, f12, f12_init = self.forward_pred(image1, image2, iters, flow_init_resize, upsample, test_mode)

        if getattr(self.args, 'requires_sq_flow', False):
            for ii, pred in enumerate(f12s):
                pred = F.interpolate(pred, size=(H, W), mode='bilinear')
                pred[:, :1] = pred[:, :1].clone() / (1.0 * W8) * W
                pred[:, 1:] = pred[:, 1:].clone() / (1.0 * H8) * H
                f12s[ii] = pred
            if self.training:
                return f12s
            return [f12s[-1]], f12_init

        f12[:, :1] = f12[:, :1].clone() / (1.0 * W8) * W
        f12[:, 1:] = f12[:, 1:].clone() / (1.0 * H8) * H

        f12 = F.interpolate(f12, size=(H, W), mode='bilinear')

        return f12, f12_init, error21,

    def forward_pred(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames

        Args:
            image1, image2: image tensors; ``image1``'s height and width are
                integer-divided by 8 to size the flow grid.
            iters: number of update iterations.
            flow_init: optional initial flow on the 1/8 grid.
            upsample, test_mode: accepted for interface compatibility; not
                read by this method.

        Returns:
            ``(flow_predictions, flow_up, flow_init)``: the list of upsampled
            flows (one per iteration), the last of them, and the
            ``flow_init`` argument unchanged.
        """
        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # Autocast on the device the data actually lives on.
        amp_device = image1.device.type
        use_amp = self.args.mixed_precision

        # Matching features and the correlation lookup built from them.
        with autocast(amp_device, enabled=use_amp):
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # Hidden state and context come from the same encoder applied to image1.
        with autocast(amp_device, enabled=use_amp):
            cnet = self.fnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            # NOTE(review): flow_init is added a second time on the first
            # iteration, on top of the addition before the loop. Preserved
            # because changing it alters the network's output; confirm this
            # is intended.
            if itr == 0 and flow_init is not None:
                coords1 = coords1 + flow_init
            corr = corr_fn(coords1)

            flow = coords1 - coords0
            with autocast(amp_device, enabled=use_amp):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            coords1 = coords1 + delta_flow

            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        return flow_predictions, flow_up, flow_init
|
|
|
|
class RAFT(nn.Module):
    """Wrapper that builds an ``RFR`` network and optionally loads weights.

    Args:
        path: checkpoint file to load, or ``None`` to keep the freshly
            initialised weights. The checkpoint must contain a
            ``'model_state_dict'`` mapping; only entries whose key starts
            with ``'module.flownet.'`` are used, with that prefix stripped.
    """

    def __init__(self, path='./_pretrain_models/anime_interp_full.ckpt'):
        super().__init__()
        self.raft = RFR(Namespace(
            small=False,
            mixed_precision=False,
        ))
        if path is not None:
            # SECURITY: torch.load unpickles the file, so only load
            # checkpoints from a trusted source.
            # map_location='cpu' lets checkpoints saved from CUDA tensors load
            # on CPU-only hosts; load_state_dict copies the values into the
            # existing parameters.
            checkpoint = torch.load(path, map_location='cpu')
            prefix = 'module.flownet.'
            flownet_state = {
                k[len(prefix):]: v
                for k, v in checkpoint['model_state_dict'].items()
                if k.startswith(prefix)
            }
            # strict=False: keys missing from or unknown to the model are
            # tolerated rather than raising.
            self.raft.load_state_dict(flownet_state, strict=False)

    def forward(self, img0, img1, flow0=None, iters=12, return_more=False):
        """Run the wrapped network and return its first output.

        Args:
            img0, img1: the two frames passed to the wrapped network.
            flow0: optional initial flow. Its two channels are swapped
                before it is handed to the wrapped network.
            iters: number of refinement iterations.
            return_more: accepted for interface compatibility; not used.

        Returns:
            The first element of the wrapped network's output with its
            channel order (dim 1) reversed.
        """
        if flow0 is not None:
            flow0 = flow0.flip(dims=(1,))
        out = self.raft(img0, img1, iters=iters, flow_init=flow0)
        return out[0].flip(dims=(1,))
|
|
|
|
|