Spaces:
Running
on
Zero
Running
on
Zero
File size: 23,313 Bytes
5f9d349 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch
from argparse import ArgumentParser
from pathlib import Path
import math
import numpy as np
from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json
import wandb
from PIL import Image
from torchvision.transforms import ToTensor
from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.losses.robust_loss_tiny_roma import RobustLosses
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
from romatch.train.train import train_k_steps
from romatch.checkpointing import CheckPoint
# Named training/eval resolutions as (height, width) pairs.
# "medium"/"high" are written as multiples of 14 * 8 to keep the original derivation visible.
resolutions = dict(
    low=(448, 448),
    medium=(14 * 8 * 5, 14 * 8 * 5),
    high=(14 * 8 * 6, 14 * 8 * 6),
    xfeat=(600, 800),
    big=(768, 1024),
)
def kde(x, std = 0.1, half = True):
    """
    Estimate the density at each row of ``x`` with a Gaussian kernel.

    For every point x_i this returns sum_j exp(-||x_i - x_j||^2 / (2 * std^2)),
    i.e. the unnormalised kernel density at x_i (the point itself contributes 1).

    Args:
        x: 2-D tensor with one point per row (``XFeatModel.sample`` passes
            (N, 4) match coordinates).
        std: bandwidth of the Gaussian kernel.
        half: compute the N x N distance matrix in float16 to halve its memory
            footprint. Only honoured off-CPU; on CPU float32 is used because
            half-precision ``torch.cdist`` is not reliably supported there.

    Returns:
        1-D tensor of length N with the density at each point.
    """
    on_cpu = x.device.type == "cpu"
    if half and not on_cpu:
        x = x.half()
    elif on_cpu and x.dtype in (torch.half, torch.bfloat16):
        # Upcast reduced-precision CPU input so cdist has a supported dtype.
        x = x.float()
    scores = (-torch.cdist(x, x) ** 2 / (2 * std ** 2)).exp()
    density = scores.sum(dim=-1)
    return density
class BasicLayer(nn.Module):
    """
    Conv2d -> BatchNorm2d (non-affine) -> optional ReLU.

    The three stages live in an ``nn.Sequential`` named ``layer`` (conv at
    index 0, norm at 1, activation at 2), so state_dict keys such as
    ``layer.0.weight`` and ``layer.1.running_mean`` match existing checkpoints.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu=True):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        norm = nn.BatchNorm2d(out_channels, affine=False)
        if relu:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.layer(x)
class XFeatModel(nn.Module):
    """
    Tiny dense matcher built on an XFeat backbone.

    Implementation of architecture described in
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."

    The backbone yields a fine feature map and a coarse fused feature map.
    A global correlation + softmax-weighted average of grid positions on the
    coarse map gives an initial warp, which two small conv heads
    (``coarse_matcher`` then ``fine_matcher``) refine. Each head outputs 3
    channels: a 2-channel flow delta (scaled by 2/W, 2/H of the preprocessed
    image B, i.e. pixel units) and a 1-channel certainty logit.
    """
    def __init__(self, xfeat = None,
                 freeze_xfeat = True,
                 sample_mode = "threshold_balanced",
                 symmetric = False,
                 exact_softmax = False):
        """
        Args:
            xfeat: XFeat backbone network. When None the pretrained network is
                fetched via ``torch.hub`` and its unused heads are deleted.
            freeze_xfeat: if True the backbone is put in eval mode and kept in a
                plain list so its parameters are not registered on this module;
                if False it is registered through ``nn.ModuleList``.
            sample_mode: strategy for :meth:`sample`; the substrings
                "threshold" and "balanced" enable the respective behaviours.
            symmetric: stored as ``self.symmetric``; not read elsewhere in this class.
            exact_softmax: if True, :meth:`pos_embed` always uses the full
                softmax; otherwise a cheaper approximation is used in eval mode.
        """
        super().__init__()
        if xfeat is None:
            xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
            # Only the backbone is used; the keypoint/heatmap/fine-matching heads are dropped.
            del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
        if freeze_xfeat:
            xfeat.train(False)
            self.xfeat = [xfeat]  # plain list: hide params from ddp
        else:
            self.xfeat = nn.ModuleList([xfeat])
        self.freeze_xfeat = freeze_xfeat
        match_dim = 256
        # Input: coarse feats of A (64) + warped coarse feats of B (64) + current warp (2).
        self.coarse_matcher = nn.Sequential(
            BasicLayer(64+64+2, match_dim,),
            BasicLayer(match_dim, match_dim,),
            BasicLayer(match_dim, match_dim,),
            BasicLayer(match_dim, match_dim,),
            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
        fine_match_dim = 64
        # Input: fine feats of A (24) + warped fine feats of B (24) + upsampled coarse warp (2).
        self.fine_matcher = nn.Sequential(
            BasicLayer(24+24+2, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
        self.sample_mode = sample_mode
        self.sample_thresh = 0.2
        self.symmetric = symmetric
        self.exact_softmax = exact_softmax

    @property
    def device(self):
        """Device of the model's trainable heads."""
        return self.fine_matcher[-1].weight.device

    def preprocess_tensor(self, x):
        """
        Resize ``x`` so height and width are multiples of 32 (rounding down),
        avoiding aliasing artifacts in the backbone.

        Returns:
            (resized x, H / new_H, W / new_W)
        """
        H, W = x.shape[-2:]
        _H, _W = (H//32) * 32, (W//32) * 32
        rh, rw = H/_H, W/_W
        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def forward_single(self, x):
        """
        Run the XFeat backbone on an image batch.

        Args:
            x: (B, C, H, W) images; channels are averaged to a single gray channel.

        Returns:
            (fine_feats, coarse_feats): the block2 output and the fused
            block3/4/5 features at block3's resolution.
        """
        with torch.inference_mode(self.freeze_xfeat or not self.training):
            xfeat = self.xfeat[0]
            with torch.no_grad():
                x = x.mean(dim=1, keepdim = True)
                x = xfeat.norm(x)
            # main backbone
            x1 = xfeat.block1(x)
            x2 = xfeat.block2(x1 + xfeat.skip1(x))
            x3 = xfeat.block3(x2)
            x4 = xfeat.block4(x3)
            x5 = xfeat.block5(x4)
            # Bring deeper maps to block3's resolution before fusing.
            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            feats = xfeat.block_fusion( x3 + x4 + x5 )
        if self.freeze_xfeat:
            # Cloning outside inference mode yields ordinary tensors that the
            # trainable heads can use under autograd.
            return x2.clone(), feats.clone()
        return x2, feats

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
        """
        Convert normalized [-1, 1] coordinates to pixel coordinates.

        Args:
            coords: either a tensor whose last dim is 2 (points in image A), a
                tensor whose last dim is 4 (A coords then B coords), or a
                (kpts_A, kpts_B) list/tuple of tensors.
            H_A, W_A: size of image A.
            H_B, W_B: size of image B (needed unless coords has last dim 2).

        Returns:
            A single tensor for 2-D input, otherwise (kpts_A, kpts_B).
        """
        # Check for list/tuple first: those have no ``.shape``.
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        elif coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)
        else:
            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)

    def _to_pixel_coordinates(self, coords, H, W):
        """Map normalized (x, y) in [-1, 1] to pixel (x, y) for an H x W image."""
        kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)), dim=-1)
        return kpts

    def pos_embed(self, corr_volume: torch.Tensor):
        """
        Turn a correlation volume into an expected target position per pixel.

        Args:
            corr_volume: (B, H1, W1, H0, W0) scores of every B-position
                (H1, W1) for every A-position (H0, W0).

        Returns:
            (B, 2, H0, W0) normalized (x, y) positions in image B: the
            softmax-weighted average of the B grid.

        In training mode, or with ``exact_softmax``, the full softmax over all
        H1*W1 positions is used. Otherwise the softmax is approximated over a
        stride-``down`` subsample of B positions plus the single best match.
        """
        B, H1, W1, H0, W0 = corr_volume.shape
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1+1/W1, 1-1/W1, W1),
                torch.linspace(-1+1/H1, 1-1/H1, H1),
                indexing = "xy"),
            dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
        down = 4
        if not self.training and not self.exact_softmax:
            grid_lr = torch.stack(
                torch.meshgrid(
                    torch.linspace(-1+down/W1, 1-down/W1, W1//down),
                    torch.linspace(-1+down/H1, 1-down/H1, H1//down),
                    indexing = "xy"),
                dim = -1).float().to(corr_volume).reshape(H1*W1 // down**2, 2)
            # Score of the best match (used as its logit) and its flat index
            # (used to look up its position). Indexing needs the integer index.
            best_score, best_idx = corr_volume.reshape(B, H1*W1, H0, W0).max(dim=1)  # each (B, H0, W0)
            lowres_cv = corr_volume[:, ::down, ::down].reshape(B, H1*W1 // down**2, H0, W0)
            P_lowres = torch.cat((lowres_cv, best_score[:, None]), dim=1).softmax(dim=1)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:, :-1], grid_lr)
            # Keep the channel dim (B, 1, H0, W0) so the weight broadcasts over
            # the x/y channels instead of across the batch.
            pos_embeddings = pos_embeddings + P_lowres[:, -1:] * grid[best_idx].permute(0, 3, 1, 2)
        else:
            P = corr_volume.reshape(B, H1*W1, H0, W0).softmax(dim=1)  # B, HW, H, W
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
        return pos_embeddings

    def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
                       im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
        """
        Render image B sampled through ``warp``, blended towards white where
        ``certainty`` is low.

        Args:
            warp: (H, W2, 4) warp; channels 2: hold B coords for the A half
                and channels :2 hold A coords for the B half. W2 = 2 * W when
                ``symmetric``, else W.
            certainty: per-pixel weight broadcastable against the (H, W2) image.
            im_A, im_B: PIL images or (3, H, W) tensors; loaded from
                ``im_A_path``/``im_B_path`` when ``im_A`` is None.
            save_path: if given, the result is also saved there.

        Returns:
            The blended (3, H, W2) visualisation tensor.
        """
        device = warp.device
        H, W2, _ = warp.shape
        W = W2//2 if symmetric else W2
        if im_A is None:
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2*W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device = device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im

    def corr_volume(self, feat0, feat1):
        """
        Scaled dot-product correlation between every pair of positions.

        input:
            feat0 -> torch.Tensor(B, C, H0, W0)
            feat1 -> torch.Tensor(B, C, H1, W1)
        return:
            corr_volume -> torch.Tensor(B, H1, W1, H0, W0), divided by sqrt(C)
        """
        B, C, H0, W0 = feat0.shape
        B, C, H1, W1 = feat1.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        feat0 = feat0.reshape(B, C, H0*W0)
        feat1 = feat1.reshape(B, C, H1*W1)
        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0, W0)/math.sqrt(C)
        return corr_volume

    @torch.inference_mode()
    def match_from_path(self, im0_path, im1_path):
        """Load two images from disk and :meth:`match` them (unbatched output)."""
        device = self.device
        # Convert to RGB (as visualize_warp does) so palette/alpha images are
        # decoded to colour values; ``with`` closes the file handles.
        with Image.open(im0_path) as im0_pil, Image.open(im1_path) as im1_pil:
            im0 = ToTensor()(im0_pil.convert("RGB"))[None].to(device)
            im1 = ToTensor()(im1_pil.convert("RGB"))[None].to(device)
        return self.match(im0, im1, batched = False)

    @torch.inference_mode()
    def match(self, im0, im1, *args, batched = True):
        """
        Compute a dense warp from image A (``im0``) to image B (``im1``).

        Args:
            im0, im1: paths, PIL images, or (B, C, H, W) tensors. Paths and PIL
                images imply ``batched=False``.
            batched: if False, the leading batch dimension is dropped.

        Returns:
            (warp, certainty): warp is (B, H0, W0, 4) holding normalized A
            coordinates followed by the matched normalized B coordinates;
            certainty is (B, H0, W0) after a sigmoid. Note this switches the
            module to eval mode via ``self.train(False)``.
        """
        if isinstance(im0, (str, Path)):
            return self.match_from_path(im0, im1)
        elif isinstance(im0, Image.Image):
            batched = False
            device = self.device
            im0 = ToTensor()(im0.convert("RGB"))[None].to(device)
            im1 = ToTensor()(im1.convert("RGB"))[None].to(device)
        B, C, H0, W0 = im0.shape
        self.train(False)
        corresps = self.forward({"im_A": im0, "im_B": im1})
        flow = F.interpolate(
            corresps[4]["flow"],
            size = (H0, W0),
            mode = "bilinear", align_corners = False).permute(0, 2, 3, 1).reshape(B, H0, W0, 2)
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1+1/W0, 1-1/W0, W0),
                torch.linspace(-1+1/H0, 1-1/H0, H0),
                indexing = "xy"),
            dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
        certainty = F.interpolate(corresps[4]["certainty"], size = (H0, W0), mode = "bilinear", align_corners = False)
        warp, cert = torch.cat((grid, flow), dim = -1), certainty[:, 0].sigmoid()
        if batched:
            return warp, cert
        return warp[0], cert[0]

    def sample(
        self,
        matches,
        certainty,
        num=10000,
    ):
        """
        Draw up to ``num`` matches from a dense warp.

        ``matches`` is flattened to (-1, 4) and ``certainty`` to (-1,). With
        "threshold" in ``sample_mode``, certainties above ``sample_thresh`` are
        set to 1 (on a copy). Matches are drawn without replacement with
        probability proportional to certainty (4x ``num`` when "balanced").
        With "balanced", those are resampled with weight 1 / (density + 1),
        where points with kde density < 10 get weight 1e-7.

        Returns:
            (sampled_matches, sampled_certainty)
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(certainty,
                                         num_samples = min(expansion_factor*num, len(certainty)),
                                         replacement=False)
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        density = kde(good_matches, std=0.1)
        p = 1 / (density+1)
        p[density < 10] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        balanced_samples = torch.multinomial(p,
                                             num_samples = min(num, len(good_certainty)),
                                             replacement=False)
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch):
        """
        input:
            batch: dict with "im_A" and "im_B" -> torch.Tensor(B, C, H, W)
                grayscale or rgb images
        return:
            dict keyed 8 (coarse) and 4 (fine), each {"flow": (B, 2, h, w)
            normalized coordinates in image B, "certainty": (B, 1, h, w)
            logits}. The integer keys presumably denote feature stride.
        """
        im0 = batch["im_A"]
        im1 = batch["im_B"]
        corresps = {}
        im0, _, _ = self.preprocess_tensor(im0)
        im1, _, _ = self.preprocess_tensor(im1)
        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        # Scales the heads' pixel-unit flow deltas to normalized coordinates; certainty unscaled.
        to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None, :, None, None]
        if im0.shape[-2:] == im1.shape[-2:]:
            # Same size: one backbone pass over the concatenated batch.
            x = torch.cat([im0, im1], dim=0)
            x = self.forward_single(x)
            feats_x0_c, feats_x1_c = x[1].chunk(2)
            feats_x0_f, feats_x1_f = x[0].chunk(2)
        else:
            feats_x0_f, feats_x0_c = self.forward_single(im0)
            feats_x1_f, feats_x1_c = self.forward_single(im1)
        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
        coarse_warp = self.pos_embed(corr_volume)
        # Append a zero certainty channel so the warp matches the heads' 3-channel output.
        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:, -1:])), dim=1)
        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[..., :2], mode = 'bilinear', align_corners = False)
        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
        corresps[8] = {"flow": coarse_matches[:, :2], "certainty": coarse_matches[:, 2:]}
        coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
        # Detached: the fine stage does not backpropagate into the coarse prediction.
        coarse_matches_up_detach = coarse_matches_up.detach()
        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[..., :2], mode = 'bilinear', align_corners = False)
        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:, :2]), dim=1))
        fine_matches = coarse_matches_up_detach + fine_matches_delta * to_normalized
        corresps[4] = {"flow": fine_matches[:, :2], "certainty": fine_matches[:, 2:]}
        return corresps
def train(args):
    """
    Train ``XFeatModel`` on MegaDepth on a single GPU, checkpointing and
    benchmarking periodically.

    Args:
        args: parsed command-line namespace; reads ``dont_log_wandb``,
            ``wandb_entity`` and ``gpu_batch_size``.

    Side effects: sets ``romatch.LOCAL_RANK``, ``romatch.STEP_SIZE`` and
    ``romatch.GLOBAL_STEP``; initialises and logs to wandb; saves checkpoints
    through ``CheckPoint`` with directory ``workspace/checkpoints/``.
    """
    # ``romatch`` is otherwise only imported under ``if __name__ == "__main__"``;
    # importing it here lets ``train`` run when this module is imported elsewhere.
    import romatch

    rank = 0
    gpus = 1
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = 0
    torch.cuda.set_device(device_id)
    # NOTE(review): fixed to "big"; ``args.train_resolution`` is parsed but unused.
    resolution = "big"
    wandb_log = not args.dont_log_wandb
    experiment_name = Path(__file__).stem
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h, w = resolutions[resolution]
    model = XFeatModel(freeze_xfeat = False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    romatch.STEP_SIZE = step_size
    N = 2_000_000  # 2M pairs in total
    # Number of steps per outer iteration (~25k pairs): checkpoint + benchmark each time.
    k = 25000 // romatch.STEP_SIZE
    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
    use_horizontal_flip_aug = True
    normalize = False  # don't imgnet normalize
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h, wt=w, normalize = normalize
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h, wt=w, normalize = normalize
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={4: 4},
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = {4: 0.15, 8: 0.15},
        c = 1e-4,
        epe_mask_prob_th = 0.001,
    )
    parameters = [
        {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    # Single LR drop at 90% of N / STEP_SIZE; integer milestone (same value as
    # the previous float floor-division) so it compares cleanly with step counts.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[9 * N // (10 * romatch.STEP_SIZE)])
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        # A fresh weighted sampler/loader per outer iteration, sized for k steps.
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples = batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size = batch_size,
                sampler = mega_sampler,
                num_workers = 8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
        )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
def test_mega_8_scenes(model, name):
    """
    Evaluate ``model`` on the MegaDepth "8 scenes" split (two overlap ranges).

    Results are printed and written to ``results/mega_8_scenes_{name}.json``.
    """
    # Same order as before: all eight scenes at 0.1-0.3 overlap, then at 0.3-0.5.
    scene_names = [
        f"mega_8_scenes_{scene}_{overlap}.npz"
        for overlap in ("0.1_0.3", "0.3_0.5")
        for scene in ("0019", "0025", "0021", "0008", "0032", "1589", "0063", "0024")
    ]
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth", scene_names=scene_names)
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    os.makedirs("results", exist_ok=True)
    with open(f"results/mega_8_scenes_{name}.json", "w") as f:
        json.dump(mega_8_scenes_results, f)
def test_mega1500(model, name):
    """Evaluate ``model`` on MegaDepth-1500 and write ``results/mega1500_{name}.json``."""
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    os.makedirs("results", exist_ok=True)
    with open(f"results/mega1500_{name}.json", "w") as f:
        json.dump(mega1500_results, f)
def test_mega1500_poselib(model, name):
    """
    Evaluate ``model`` on MegaDepth-1500 with the PoseLib benchmark and write
    ``results/mega1500_poselib_{name}.json``.
    """
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    os.makedirs("results", exist_ok=True)
    with open(f"results/mega1500_poselib_{name}.json", "w") as f:
        json.dump(mega1500_results, f)
def test_mega_8_scenes_poselib(model, name):
    """
    Evaluate ``model`` on the MegaDepth "8 scenes" split with the PoseLib
    benchmark and write ``results/mega_8_scenes_poselib_{name}.json``.
    """
    # Same order as before: all eight scenes at 0.1-0.3 overlap, then at 0.3-0.5.
    scene_names = [
        f"mega_8_scenes_{scene}_{overlap}.npz"
        for overlap in ("0.1_0.3", "0.3_0.5")
        for scene in ("0019", "0025", "0021", "0008", "0032", "1589", "0063", "0024")
    ]
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
                                                  scene_names=scene_names)
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    os.makedirs("results", exist_ok=True)
    with open(f"results/mega_8_scenes_poselib_{name}.json", "w") as f:
        json.dump(mega1500_results, f)
def test_scannet_poselib(model, name):
    """
    Evaluate ``model`` on ScanNet with the PoseLib benchmark and write
    ``results/scannet_poselib_{name}.json``.

    The ``poselib`` infix keeps this from overwriting the output of
    ``test_scannet`` (which writes ``results/scannet_{name}.json``), mirroring
    the ``mega1500`` / ``mega1500_poselib`` naming.
    """
    scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    os.makedirs("results", exist_ok=True)
    with open(f"results/scannet_poselib_{name}.json", "w") as f:
        json.dump(scannet_results, f)
def test_scannet(model, name):
    """Evaluate ``model`` on ScanNet and write ``results/scannet_{name}.json``."""
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    os.makedirs("results", exist_ok=True)
    with open(f"results/scannet_{name}.json", "w") as f:
        json.dump(scannet_results, f)
if __name__ == "__main__":
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    import romatch
    parser = ArgumentParser()
    parser.add_argument("--only_test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=8, type=int)
    parser.add_argument("--wandb_entity", required = False)
    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
    if not args.only_test:
        train(args)
    # Evaluate the released checkpoint name rather than Path(__file__).stem.
    experiment_name = "tiny_roma_v1_outdoor"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
    # map_location lets a checkpoint saved from GPU load on the CPU fallback device.
    model.load_state_dict(torch.load(f"{experiment_name}.pth", map_location=device))
    test_mega1500_poselib(model, experiment_name)
|