|
import torch |
|
import torch.nn as nn |
|
import yaml |
|
import sys |
|
sys.path.append(".") |
|
sys.path.append("submodules") |
|
sys.path.append("submodules/mast3r") |
|
from mast3r.model import AsymmetricMASt3R |
|
from src.ptv3 import PTV3 |
|
from src.gaussian_head import GaussianHead |
|
from src.utils.points_process import merge_points |
|
from src.losses import GaussianLoss |
|
from src.lseg import LSegFeatureExtractor |
|
import argparse |
|
|
|
class LSM_MASt3R(nn.Module):
    """Scene model built on a frozen MASt3R backbone and a frozen LSeg encoder.

    ``forward`` runs MASt3R on an image pair, converts its output into a
    point-cloud dict with ``merge_points``, processes that with a PTV3 point
    transformer, extracts LSeg features for both views, and combines both in a
    ``GaussianHead``.

    ``mast3r`` and ``lseg_feature_extractor`` are frozen. They have
    ``requires_grad=False``, are kept in eval mode (also across ``train()``),
    and are excluded from ``state_dict`` / ``load_state_dict``. They are
    re-created from their own pretrained weights in ``__init__``.
    """

    # Attribute names of the frozen submodules (see class docstring).
    _FROZEN_MODULES = ('mast3r', 'lseg_feature_extractor')

    def __init__(self,
                 mast3r_config,
                 point_transformer_config,
                 gaussian_head_config,
                 lseg_config,
                 ):
        """Build all submodules.

        Args:
            mast3r_config: keyword arguments for ``AsymmetricMASt3R.from_pretrained``.
            point_transformer_config: keyword arguments for ``PTV3``.
            gaussian_head_config: keyword arguments for ``GaussianHead``. Its
                optional ``'d_gs_feats'`` entry (default 32) also sets the
                channel count of ``feature_reduction`` / ``feature_expansion``.
            lseg_config: keyword arguments for ``LSegFeatureExtractor.from_pretrained``.
        """
        super().__init__()

        # Record the constructor arguments on the instance
        # (presumably for checkpoint bookkeeping — confirm with the trainer).
        self.config = {
            'mast3r_config': mast3r_config,
            'point_transformer_config': point_transformer_config,
            'gaussian_head_config': gaussian_head_config,
            'lseg_config': lseg_config
        }

        # Frozen pretrained MASt3R backbone.
        self.mast3r = AsymmetricMASt3R.from_pretrained(**mast3r_config)
        self._freeze(self.mast3r)

        # Trainable components.
        self.point_transformer = PTV3(**point_transformer_config)
        self.gaussian_head = GaussianHead(**gaussian_head_config)

        # Frozen pretrained LSeg feature extractor.
        self.lseg_feature_extractor = LSegFeatureExtractor.from_pretrained(**lseg_config)
        self._freeze(self.lseg_feature_extractor)

        d_gs_feats = gaussian_head_config.get('d_gs_feats', 32)

        # 1x1 conv from 512 channels down to d_gs_feats, then 2x bilinear upsampling.
        # NOTE(review): assumes LSeg features have 512 channels — confirm.
        self.feature_reduction = nn.Sequential(
            nn.Conv2d(512, d_gs_feats, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )

        # Reverse mapping: d_gs_feats back to 512 channels at half resolution.
        # Not used by this class's forward; NOTE(review): presumably consumed
        # elsewhere (the loss receives the model) — confirm.
        self.feature_expansion = nn.Sequential(
            nn.Conv2d(d_gs_feats, 512, kernel_size=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        )

    @staticmethod
    def _freeze(module):
        """Disable gradients for all parameters of ``module`` and put it in eval mode."""
        module.requires_grad_(False)
        module.eval()

    def train(self, mode=True):
        """Set training mode on this model while keeping frozen submodules in eval mode.

        ``nn.Module.train`` recurses into every child. Without this override,
        ``model.train()`` would switch the frozen backbones back into training
        mode, which affects e.g. dropout or batch-norm statistics if they
        contain such layers.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        for name in self._FROZEN_MODULES:
            module = getattr(self, name, None)
            if module is not None:
                module.eval()
        return self

    def forward(self, view1, view2):
        """Run the full pipeline on an image pair.

        Args:
            view1: first view dict, passed to MASt3R and ``merge_points``;
                its ``'img'`` entry is also fed to LSeg.
            view2: second view dict, used the same way as ``view1``.

        Returns:
            Whatever ``self.gaussian_head`` returns for the point-transformer
            output and the reduced LSeg features.
        """
        mast3r_output = self.mast3r(view1, view2)

        # Convert the MASt3R predictions into the input dict of the point transformer.
        data_dict = merge_points(mast3r_output, view1, view2)

        point_transformer_output = self.point_transformer(data_dict)

        lseg_features = self.extract_lseg_features(view1, view2)

        final_output = self.gaussian_head(point_transformer_output, lseg_features)

        return final_output

    def extract_lseg_features(self, view1, view2):
        """Extract reduced LSeg features for both views in one batch.

        The two image batches are concatenated along dim 0 (``view1`` first,
        then ``view2``). The result goes through the frozen LSeg extractor and
        then ``feature_reduction``.
        NOTE(review): assumes both ``'img'`` tensors match in all dims but 0.
        """
        img = torch.cat([view1['img'], view2['img']], dim=0)

        lseg_features = self.lseg_feature_extractor.extract_features(img)

        lseg_features = self.feature_reduction(lseg_features)

        return lseg_features

    @staticmethod
    def from_pretrained(checkpoint_path, device='cuda'):
        """Rebuild a model from a training checkpoint and move it to ``device``.

        The checkpoint must be a dict with two entries:

        - ``'args'``: an object whose ``model`` attribute is a Python
          expression string that constructs the model.
        - ``'model'``: the state dict to load (see ``load_state_dict``).

        SECURITY: ``torch.load`` (pickle) and ``eval`` both execute code
        stored in the checkpoint file. Only load checkpoints you trust.
        """
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
        # which may reject the pickled 'args' object. Confirm the torch version
        # or pass weights_only=False for trusted files.
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        config = ckpt['args']

        # Builtin eval (this is a staticmethod, not nn.Module.eval). The string is
        # evaluated in this module's globals, so it can name LSM_MASt3R directly.
        # SECURITY: executes code from the checkpoint.
        model = eval(config.model)

        model.load_state_dict(ckpt['model'])

        model = model.to(device)

        return model

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict without entries belonging to the frozen submodules.

        Frozen keys are matched relative to ``prefix`` and removed in place.
        This makes the filtering work when the model is nested in another
        module or a DDP wrapper: there the parent passes a shared
        ``destination`` and a non-empty ``prefix``, and does not use the
        return value. Removing in place also keeps the ``OrderedDict`` type and
        its ``_metadata``.
        """
        state = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        frozen_prefixes = tuple(f'{prefix}{name}.' for name in self._FROZEN_MODULES)
        for key in [k for k in state if k.startswith(frozen_prefixes)]:
            del state[key]

        return state

    def load_state_dict(self, state_dict, strict=True):
        """Load trainable weights, ignoring any entries for the frozen submodules.

        Args:
            state_dict: mapping of parameter/buffer names to tensors, e.g. as
                produced by ``state_dict()``. Entries under ``mast3r.`` or
                ``lseg_feature_extractor.`` are skipped; those submodules keep
                their pretrained weights.
            strict: if True, raise ``RuntimeError`` when a trainable key is
                missing from ``state_dict`` or ``state_dict`` has keys the model
                does not know. Frozen-module keys never count as missing.

        Returns:
            The missing/unexpected-keys named tuple from
            ``nn.Module.load_state_dict``, with frozen-module keys removed from
            ``missing_keys``.

        Raises:
            RuntimeError: on key mismatches when ``strict`` is True, or on
                tensor size mismatches (raised by ``nn.Module.load_state_dict``).
        """
        from collections import OrderedDict

        frozen_prefixes = tuple(f'{name}.' for name in self._FROZEN_MODULES)

        filtered = OrderedDict(
            (k, v) for k, v in state_dict.items() if not k.startswith(frozen_prefixes)
        )
        # Preserve the per-module version metadata torch consults while loading.
        metadata = getattr(state_dict, '_metadata', None)
        if metadata is not None:
            filtered._metadata = metadata

        # Load non-strictly, then apply strictness ourselves so that the frozen
        # submodules' (intentionally absent) keys are not reported as missing.
        result = super().load_state_dict(filtered, strict=False)
        missing = [k for k in result.missing_keys if not k.startswith(frozen_prefixes)]
        unexpected = list(result.unexpected_keys)

        if strict and (missing or unexpected):
            errors = []
            if missing:
                errors.append('Missing key(s) in state_dict: '
                              + ', '.join(f'"{k}"' for k in missing) + '.')
            if unexpected:
                errors.append('Unexpected key(s) in state_dict: '
                              + ', '.join(f'"{k}"' for k in unexpected) + '.')
            raise RuntimeError(
                f'Error(s) in loading state_dict for {type(self).__name__}:\n\t'
                + '\n\t'.join(errors)
            )

        return result._replace(missing_keys=missing)
|
|
|
if __name__ == "__main__":
    # Smoke test: build (or load) the model, run it on one ScanNet batch and
    # evaluate the Gaussian loss on the result.
    from torch.utils.data import DataLoader

    from src.datasets.scannet import Scannet

    parser = argparse.ArgumentParser(description="Smoke-test LSM_MASt3R on a single ScanNet batch.")
    parser.add_argument('--checkpoint', type=str, default=None,
                        help="checkpoint to load with LSM_MASt3R.from_pretrained; "
                             "if omitted, the model is built from --config")
    parser.add_argument('--config', type=str, default="configs/model_config.yaml",
                        help="YAML file with the LSM_MASt3R constructor arguments")
    parser.add_argument('--device', type=str, default='cuda',
                        help="torch device to run on")
    args = parser.parse_args()

    if args.checkpoint is not None:
        model = LSM_MASt3R.from_pretrained(args.checkpoint, device=args.device)
    else:
        # The YAML config is only needed when building the model from scratch.
        with open(args.config, "r") as f:
            config = yaml.safe_load(f)
        model = LSM_MASt3R(**config).to(args.device)

    model.eval()

    print(model)

    dataset = Scannet(split='train', ROOT="data/scannet_processed", resolution=[(512, 384)])

    print(dataset)

    data_loader = DataLoader(dataset, batch_size=3, shuffle=True)
    data = next(iter(data_loader))

    # Move the per-view tensors onto the target device.
    for view in data:
        for key in ('img', 'depthmap', 'camera_pose', 'camera_intrinsics'):
            view[key] = view[key].to(args.device)

    # The first two views form the model's input pair.
    output = model(*data[:2])

    loss = GaussianLoss()
    loss_value = loss(*data, *output, model)
    print(loss_value)
|
|