VicFonch commited on
Commit
d220e82
·
unverified ·
1 Parent(s): c604c51

raft/rfr_new.py: Fix device problems

Browse files
modules/flow_models/raft/rfr_new.py CHANGED
@@ -14,6 +14,8 @@ from .extractor import BasicEncoder, SmallEncoder
14
  from .corr import CorrBlock
15
  from .utils import bilinear_sampler, coords_grid, upflow8
16
 
 
 
17
  try:
18
  autocast = torch.amp.autocast
19
  except:
@@ -34,8 +36,8 @@ def backwarp(img, flow):
34
 
35
  gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
36
 
37
- gridX = torch.tensor(gridX, requires_grad=False,).cuda()
38
- gridY = torch.tensor(gridY, requires_grad=False,).cuda()
39
  x = gridX.unsqueeze(0).expand_as(u).float() + u
40
  y = gridY.unsqueeze(0).expand_as(v).float() + v
41
  # range -1 to 1
@@ -129,8 +131,8 @@ class RFR(nn.Module):
129
  f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
130
  else:
131
  flow_init_resize = None
132
- flow_init = torch.zeros(image1.size()[0], 2, image1.size()[2]//8, image1.size()[3]//8).cuda()
133
- error21 = torch.zeros(image1.size()[0], 1, image1.size()[2]//8, image1.size()[3]//8).cuda()
134
 
135
  f12_init = flow_init
136
  # print('None inital flow!')
@@ -169,14 +171,14 @@ class RFR(nn.Module):
169
  cdim = self.context_dim
170
 
171
  # run the feature network
172
- with autocast("cuda", enabled=self.args.mixed_precision):
173
  fmap1, fmap2 = self.fnet([image1, image2])
174
  fmap1 = fmap1.float()
175
  fmap2 = fmap2.float()
176
  corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
177
 
178
  # run the context network
179
- with autocast("cuda", enabled=self.args.mixed_precision):
180
  cnet = self.fnet(image1)
181
  net, inp = torch.split(cnet, [hdim, cdim], dim=1)
182
  net = torch.tanh(net)
@@ -196,7 +198,7 @@ class RFR(nn.Module):
196
  corr = corr_fn(coords1) # index correlation volume
197
 
198
  flow = coords1 - coords0
199
- with autocast("cuda", enabled=self.args.mixed_precision):
200
  net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
201
 
202
  # F(t+1) = F(t) + \Delta(t)
 
14
  from .corr import CorrBlock
15
  from .utils import bilinear_sampler, coords_grid, upflow8
16
 
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+
19
  try:
20
  autocast = torch.amp.autocast
21
  except:
 
36
 
37
  gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
38
 
39
+ gridX = torch.tensor(gridX, requires_grad=False,).to(device)
40
+ gridY = torch.tensor(gridY, requires_grad=False,).to(device)
41
  x = gridX.unsqueeze(0).expand_as(u).float() + u
42
  y = gridY.unsqueeze(0).expand_as(v).float() + v
43
  # range -1 to 1
 
131
  f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
132
  else:
133
  flow_init_resize = None
134
+ flow_init = torch.zeros(image1.size()[0], 2, image1.size()[2]//8, image1.size()[3]//8).to(device)
135
+ error21 = torch.zeros(image1.size()[0], 1, image1.size()[2]//8, image1.size()[3]//8).to(device)
136
 
137
  f12_init = flow_init
138
  # print('None inital flow!')
 
171
  cdim = self.context_dim
172
 
173
  # run the feature network
174
+ with autocast(device, enabled=self.args.mixed_precision):
175
  fmap1, fmap2 = self.fnet([image1, image2])
176
  fmap1 = fmap1.float()
177
  fmap2 = fmap2.float()
178
  corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
179
 
180
  # run the context network
181
+ with autocast(device, enabled=self.args.mixed_precision):
182
  cnet = self.fnet(image1)
183
  net, inp = torch.split(cnet, [hdim, cdim], dim=1)
184
  net = torch.tanh(net)
 
198
  corr = corr_fn(coords1) # index correlation volume
199
 
200
  flow = coords1 - coords0
201
+ with autocast(device, enabled=self.args.mixed_precision):
202
  net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
203
 
204
  # F(t+1) = F(t) + \Delta(t)