# LSM/src/model.py
import torch
import torch.nn as nn
import yaml
import sys
sys.path.append(".")
sys.path.append("submodules")
sys.path.append("submodules/mast3r")
from mast3r.model import AsymmetricMASt3R
from src.ptv3 import PTV3
from src.gaussian_head import GaussianHead
from src.utils.points_process import merge_points
from src.losses import GaussianLoss
from src.lseg import LSegFeatureExtractor
import argparse


class LSM_MASt3R(nn.Module):
    def __init__(self,
                 mast3r_config,
                 point_transformer_config,
                 gaussian_head_config,
                 lseg_config,
                 ):
        super().__init__()
        # Store the sub-module configurations
        self.config = {
            'mast3r_config': mast3r_config,
            'point_transformer_config': point_transformer_config,
            'gaussian_head_config': gaussian_head_config,
            'lseg_config': lseg_config
        }
        # Initialize AsymmetricMASt3R and freeze its parameters
        self.mast3r = AsymmetricMASt3R.from_pretrained(**mast3r_config)
        for param in self.mast3r.parameters():
            param.requires_grad = False
        self.mast3r.eval()
        # Initialize PointTransformerV3
        self.point_transformer = PTV3(**point_transformer_config)
        # Initialize the Gaussian head
        self.gaussian_head = GaussianHead(**gaussian_head_config)
        # Initialize the LSeg feature extractor and freeze its parameters
        self.lseg_feature_extractor = LSegFeatureExtractor.from_pretrained(**lseg_config)
        for param in self.lseg_feature_extractor.parameters():
            param.requires_grad = False
        self.lseg_feature_extractor.eval()
        # Two 1x1 convolution + resampling blocks mapping between the LSeg
        # feature channels and the Gaussian feature channels
        d_gs_feats = gaussian_head_config.get('d_gs_feats', 32)
        self.feature_reduction = nn.Sequential(
            nn.Conv2d(512, d_gs_feats, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )  # (b, 512, h//2, w//2) -> (b, d_gs_feats, h, w)
        self.feature_expansion = nn.Sequential(
            nn.Conv2d(d_gs_feats, 512, kernel_size=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        )  # (b, d_gs_feats, h, w) -> (b, 512, h//2, w//2)
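    # Shape sketch for the two projection blocks above, assuming the 512x384
    # inputs used in the test block below and the default d_gs_feats = 32:
    #   LSeg features     : (v*b, 512, 192, 256)
    #   feature_reduction : (v*b, 512, 192, 256) -> (v*b, 32, 384, 512)
    #   feature_expansion : (v*b, 32, 384, 512)  -> (v*b, 512, 192, 256)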

    def forward(self, view1, view2):
        # AsymmetricMASt3R forward pass
        mast3r_output = self.mast3r(view1, view2)
        # Merge the points from the two views
        data_dict = merge_points(mast3r_output, view1, view2)
        # PointTransformerV3 forward pass
        point_transformer_output = self.point_transformer(data_dict)
        # Extract LSeg features
        lseg_features = self.extract_lseg_features(view1, view2)
        # Gaussian head forward pass
        final_output = self.gaussian_head(point_transformer_output, lseg_features)
        return final_output

    def extract_lseg_features(self, view1, view2):
        # Concatenate view1 and view2 along the batch dimension
        img = torch.cat([view1['img'], view2['img']], dim=0)  # (v*b, 3, h, w)
        # Extract features
        lseg_features = self.lseg_feature_extractor.extract_features(img)  # (v*b, 512, h//2, w//2)
        # Reduce dimensions and upsample to full resolution
        lseg_features = self.feature_reduction(lseg_features)  # (v*b, d_gs_feats, h, w)
        return lseg_features
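    # extract_lseg_features stacks the two views along the batch dimension, so the
    # first b feature maps correspond to view1 and the last b to view2.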

    @staticmethod
    def from_pretrained(checkpoint_path, device='cuda'):
        # Load the checkpoint
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        # Extract the configuration from the checkpoint
        config = ckpt['args']
        # Create a new instance of LSM_MASt3R
        model = eval(config.model)
        # Load the state dict
        model.load_state_dict(ckpt['model'])
        # Move the model to the specified device
        model = model.to(device)
        return model
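    # Note: `config.model` above is expected to hold a constructor expression string
    # (e.g. "LSM_MASt3R(...)") stored in the checkpoint's args, which eval() runs to
    # rebuild the model before the trainable weights are loaded onto it.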

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Get the state_dict of all parameters
        full_state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # Keep only the trainable parameters (drop the frozen MASt3R and LSeg backbones)
        trainable_state_dict = {
            k: v for k, v in full_state_dict.items()
            if not (k.startswith('mast3r.') or k.startswith('lseg_feature_extractor.'))
        }
        return trainable_state_dict

    def load_state_dict(self, state_dict, strict=True):
        # Get the full state_dict of the current model
        model_state = super().state_dict()
        # Update only the trainable parameters; the frozen backbones keep their weights
        for k in list(state_dict.keys()):
            if k in model_state and not (k.startswith('mast3r.') or k.startswith('lseg_feature_extractor.')):
                model_state[k] = state_dict[k]
        # Load the model with the updated state_dict
        super().load_state_dict(model_state, strict=False)
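    # Note: the `strict` argument is accepted for interface compatibility, but the
    # call above always loads with strict=False, since checkpoints produced by
    # state_dict() above only contain the trainable (non-frozen) parameters.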


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str)
    args = parser.parse_args()

    # Load config
    with open("configs/model_config.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Initialize model
    if args.checkpoint is not None:
        model = LSM_MASt3R.from_pretrained(args.checkpoint, device='cuda')
    else:
        model = LSM_MASt3R(**config).to('cuda')
    model.eval()
    # Print model
    print(model)

    # Load dataset
    from src.datasets.scannet import Scannet
    dataset = Scannet(split='train', ROOT="data/scannet_processed", resolution=[(512, 384)])
    # Print dataset
    print(dataset)

    # Test the model on one batch
    data_loader = DataLoader(dataset, batch_size=3, shuffle=True)
    data = next(iter(data_loader))
    # Move data to cuda
    for view in data:
        view['img'] = view['img'].to('cuda')
        view['depthmap'] = view['depthmap'].to('cuda')
        view['camera_pose'] = view['camera_pose'].to('cuda')
        view['camera_intrinsics'] = view['camera_intrinsics'].to('cuda')
    # Forward pass on the first two views
    output = model(*data[:2])
    # Loss
    loss = GaussianLoss()
    loss_value = loss(*data, *output, model)
    print(loss_value)
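
    # Example invocation from the repo root (hypothetical checkpoint path):
    #   python src/model.py --checkpoint checkpoints/lsm_mast3r.pth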