VicFonch
commited on
raft/rfr_new.py: Fix device problems
Browse files
modules/flow_models/raft/rfr_new.py
CHANGED
@@ -14,6 +14,8 @@ from .extractor import BasicEncoder, SmallEncoder
|
|
14 |
from .corr import CorrBlock
|
15 |
from .utils import bilinear_sampler, coords_grid, upflow8
|
16 |
|
|
|
|
|
17 |
try:
|
18 |
autocast = torch.amp.autocast
|
19 |
except:
|
@@ -34,8 +36,8 @@ def backwarp(img, flow):
|
|
34 |
|
35 |
gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
|
36 |
|
37 |
-
gridX = torch.tensor(gridX, requires_grad=False,).
|
38 |
-
gridY = torch.tensor(gridY, requires_grad=False,).
|
39 |
x = gridX.unsqueeze(0).expand_as(u).float() + u
|
40 |
y = gridY.unsqueeze(0).expand_as(v).float() + v
|
41 |
# range -1 to 1
|
@@ -129,8 +131,8 @@ class RFR(nn.Module):
|
|
129 |
f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
|
130 |
else:
|
131 |
flow_init_resize = None
|
132 |
-
flow_init = torch.zeros(image1.size()[0], 2, image1.size()[2]//8, image1.size()[3]//8).
|
133 |
-
error21 = torch.zeros(image1.size()[0], 1, image1.size()[2]//8, image1.size()[3]//8).
|
134 |
|
135 |
f12_init = flow_init
|
136 |
# print('None inital flow!')
|
@@ -169,14 +171,14 @@ class RFR(nn.Module):
|
|
169 |
cdim = self.context_dim
|
170 |
|
171 |
# run the feature network
|
172 |
-
with autocast(
|
173 |
fmap1, fmap2 = self.fnet([image1, image2])
|
174 |
fmap1 = fmap1.float()
|
175 |
fmap2 = fmap2.float()
|
176 |
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
177 |
|
178 |
# run the context network
|
179 |
-
with autocast(
|
180 |
cnet = self.fnet(image1)
|
181 |
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
|
182 |
net = torch.tanh(net)
|
@@ -196,7 +198,7 @@ class RFR(nn.Module):
|
|
196 |
corr = corr_fn(coords1) # index correlation volume
|
197 |
|
198 |
flow = coords1 - coords0
|
199 |
-
with autocast(
|
200 |
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
|
201 |
|
202 |
# F(t+1) = F(t) + \Delta(t)
|
|
|
14 |
from .corr import CorrBlock
|
15 |
from .utils import bilinear_sampler, coords_grid, upflow8
|
16 |
|
17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
18 |
+
|
19 |
try:
|
20 |
autocast = torch.amp.autocast
|
21 |
except:
|
|
|
36 |
|
37 |
gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
|
38 |
|
39 |
+
gridX = torch.tensor(gridX, requires_grad=False,).to(device)
|
40 |
+
gridY = torch.tensor(gridY, requires_grad=False,).to(device)
|
41 |
x = gridX.unsqueeze(0).expand_as(u).float() + u
|
42 |
y = gridY.unsqueeze(0).expand_as(v).float() + v
|
43 |
# range -1 to 1
|
|
|
131 |
f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
|
132 |
else:
|
133 |
flow_init_resize = None
|
134 |
+
flow_init = torch.zeros(image1.size()[0], 2, image1.size()[2]//8, image1.size()[3]//8).to(device)
|
135 |
+
error21 = torch.zeros(image1.size()[0], 1, image1.size()[2]//8, image1.size()[3]//8).to(device)
|
136 |
|
137 |
f12_init = flow_init
|
138 |
# print('None inital flow!')
|
|
|
171 |
cdim = self.context_dim
|
172 |
|
173 |
# run the feature network
|
174 |
+
with autocast(device, enabled=self.args.mixed_precision):
|
175 |
fmap1, fmap2 = self.fnet([image1, image2])
|
176 |
fmap1 = fmap1.float()
|
177 |
fmap2 = fmap2.float()
|
178 |
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
179 |
|
180 |
# run the context network
|
181 |
+
with autocast(device, enabled=self.args.mixed_precision):
|
182 |
cnet = self.fnet(image1)
|
183 |
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
|
184 |
net = torch.tanh(net)
|
|
|
198 |
corr = corr_fn(coords1) # index correlation volume
|
199 |
|
200 |
flow = coords1 - coords0
|
201 |
+
with autocast(device, enabled=self.args.mixed_precision):
|
202 |
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
|
203 |
|
204 |
# F(t+1) = F(t) + \Delta(t)
|