import torch
import torch.nn as nn
from submodules.lang_seg.modules.models.lseg_net import LSegNet, clip


class LSegFeatureExtractor(LSegNet):
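    """Dense per-pixel feature extractor built on LSeg.

    `extract_features` runs the CLIP ViT-L/16 backbone through the DPT head
    and RefineNet fusion to produce dense 512-d CLIP-aligned features;
    `decode_feature` correlates those features with CLIP text embeddings to
    yield per-label logits.
    """
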
def __init__(self, half_res=True):
super().__init__(
labels='',
backbone='clip_vitl16_384',
features=256,
crop_size=224,
arch_option=0,
block_depth=0,
activation='lrelu'
)
        self.half_res = half_res

@torch.no_grad()
def extract_features(self, x):
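        """Return dense CLIP-aligned features for an image batch.

        Args:
            x: (b, 3, h, w) input images.

        Returns:
            (b, 512, h//2, w//2) features when `half_res`, else (b, 512, h, w).
        """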
layer_1, layer_2, layer_3, layer_4 = forward_layers(self.pretrained, x)
        # each layer_i: (b, 1024, h//16, w//16)

        # DPT head: finish each stage's post-processing (1x1 projection and
        # up/down-sampling to its pyramid resolution)
pretrained = self.pretrained
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
        # RefineNet fusion: merge the four stages from coarse to fine
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
# (b, 512, h//2, w//2)
image_features = self.scratch.head1(path_1)
if self.half_res:
return image_features
# (b, 512, h, w)
image_features = self.scratch.output_conv(image_features)
        return image_features

@torch.no_grad()
def decode_feature(self, image_features, labelset=''):
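        """Correlate dense image features with CLIP text embeddings.

        Args:
            image_features: (b, 512, h', w') output of `extract_features`.
            labelset: label strings to tokenize; when empty, reuses the
                labels tokenized at construction time.

        Returns:
            (b, num_labels, h', w') logits; with `half_res`, `output_conv`
            upsamples them 2x toward the input resolution.
        """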
        # An earlier variant took the raw concatenated ViT features
        # (b, 4096, h//16, w//16), split them back into four stages, and
        # re-ran the DPT head and RefineNet fusion here; `extract_features`
        # now performs that decoding once.
imshape = image_features.shape
# encode text
if labelset == '':
text = self.text
else:
text = clip.tokenize(labelset)
self.logit_scale = self.logit_scale.to(image_features.device)
text = text.to(image_features.device)
text_features = self.clip_pretrained.encode_text(text)
        # flatten to one out_c-dim feature row per pixel
        image_features = image_features.permute(0, 2, 3, 1).reshape(-1, self.out_c)
# normalized features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # scaled cosine similarity between every pixel and every label embedding
        logits_per_image = self.logit_scale * image_features.half() @ text_features.t()
        out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0, 3, 1, 2)
        # dead branch here: __init__ fixes arch_option to 0 (kept from LSegNet)
        if self.arch_option in [1, 2]:
for _ in range(self.block_depth - 1):
out = self.scratch.head_block(out)
out = self.scratch.head_block(out, False)
        if self.half_res:
            # features came in at half resolution; upsample logits to full size
            out = self.scratch.output_conv(out)
        return out

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
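        """Instantiate the model and load weights from a Lightning-style
        checkpoint whose state-dict keys carry a "net." prefix."""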
print(f"Loading checkpoint from: {pretrained_model_name_or_path}")
ckpt = torch.load(pretrained_model_name_or_path, map_location='cpu')
print(f"Checkpoint loaded. Keys in checkpoint: {ckpt.keys()}")
print("Processing state dict...")
new_state_dict = {k[len("net."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net.")}
print(f"Processed state dict. Number of keys: {len(new_state_dict)}")
print("Initializing model...")
model = cls(*args, **kwargs)
print("Loading state dict into model...")
model.load_state_dict(new_state_dict, strict=True)
print("State dict loaded successfully.")
print("Cleaning up...")
del ckpt
del new_state_dict
print("Model loading complete.")
        return model


def forward_layers(pretrained, x):
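    """Run the ViT encoder and return the four hooked transformer stages,
    each as a (b, 1024, h//16, w//16) feature map."""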
b, c, h, w = x.shape
    # run the ViT encoder; forward hooks populate `pretrained.activations`,
    # and the global output itself is not needed here
    _ = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
    # first two post-process ops per stage: readout handling + transpose
    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
    # reshape token sequences back into 2D maps at the patch-grid resolution
    unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
return layer_1, layer_2, layer_3, layer_4
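

# A minimal usage sketch (the checkpoint path is hypothetical, and it assumes
# inputs are normalized the way the LSeg checkpoint expects):
#
#   model = LSegFeatureExtractor.from_pretrained("checkpoints/lseg.ckpt")
#   model = model.eval().cuda()
#   x = torch.randn(1, 3, 480, 640, device="cuda")   # (b, 3, h, w) image batch
#   feats = model.extract_features(x)                # (1, 512, 240, 320) with half_res
#   logits = model.decode_feature(feats, labelset=["chair", "table", "other"])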