import torch
import torch.nn as nn

from submodules.lang_seg.modules.models.lseg_net import LSegNet, clip


class LSegFeatureExtractor(LSegNet):
    """LSeg network reduced to its dense feature extractor.

    Exposes the per-pixel features that LSeg compares against CLIP text
    embeddings, so they can be cached or reused independently of a label set.
    """

    def __init__(self, half_res=True):
        super().__init__(
            labels='',
            backbone='clip_vitl16_384',
            features=256,
            crop_size=224,
            arch_option=0,
            block_depth=0,
            activation='lrelu',
        )

        # If True, extract_features() skips the final output_conv step and
        # returns features at half of the input resolution.
        self.half_res = half_res

    @torch.no_grad()
    def extract_features(self, x):
        """Return dense LSeg features for an image batch `x` of shape (B, 3, H, W)."""
        # Run the ViT backbone; the four hooked transformer stages are returned
        # already unflattened into spatial maps.
        layer_1, layer_2, layer_3, layer_4 = forward_layers(self.pretrained, x)

        # Finish the DPT postprocessing (the first readout steps and the
        # unflatten are already applied inside forward_layers).
        pretrained = self.pretrained
        layer_1 = pretrained.act_postprocess1[3:](layer_1)
        layer_2 = pretrained.act_postprocess2[3:](layer_2)
        layer_3 = pretrained.act_postprocess3[3:](layer_3)
        layer_4 = pretrained.act_postprocess4[3:](layer_4)

        # Project each stage to a common channel width.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse the stages top-down with the RefineNet-style decoder.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        # Project to the CLIP embedding dimension (self.out_c).
        image_features = self.scratch.head1(path_1)
        if self.half_res:
            return image_features

        # Otherwise run the output head, which brings the features to full resolution.
        image_features = self.scratch.output_conv(image_features)
        return image_features

    @torch.no_grad()
    def decode_feature(self, image_features, labelset=''):
        """Turn dense image features into per-label logits.

        `image_features` is the output of extract_features(). When `labelset`
        is empty, the tokens cached in `self.text` by the constructor are used;
        otherwise the given labels are tokenized and encoded with CLIP.
        """
        imshape = image_features.shape

        # Encode the label set with the CLIP text encoder.
        if labelset == '':
            text = self.text
        else:
            text = clip.tokenize(labelset)

        self.logit_scale = self.logit_scale.to(image_features.device)
        text = text.to(image_features.device)
        text_features = self.clip_pretrained.encode_text(text)

        # Flatten the spatial dimensions so every pixel becomes one embedding.
        image_features = image_features.permute(0, 2, 3, 1).reshape(-1, self.out_c)

        # Scaled cosine similarity between pixel and text embeddings.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logits_per_image = self.logit_scale * image_features.half() @ text_features.t()
        out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0, 3, 1, 2)

        # Optional refinement blocks used when arch_option is 1 or 2.
        if self.arch_option in [1, 2]:
            for _ in range(self.block_depth - 1):
                out = self.scratch.head_block(out)
            out = self.scratch.head_block(out, False)

        # If the features were kept at half resolution, upsample the logits here.
        if self.half_res:
            out = self.scratch.output_conv(out)

        return out

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build the extractor and load weights from an LSeg training checkpoint."""
        print(f"Loading checkpoint from: {pretrained_model_name_or_path}")
        ckpt = torch.load(pretrained_model_name_or_path, map_location='cpu')
        print(f"Checkpoint loaded. Keys in checkpoint: {ckpt.keys()}")

        # The checkpoint stores the network weights under a "net." prefix;
        # strip it so the keys match this module.
        print("Processing state dict...")
        new_state_dict = {k[len("net."):]: v for k, v in ckpt['state_dict'].items() if k.startswith("net.")}
        print(f"Processed state dict. Number of keys: {len(new_state_dict)}")

        print("Initializing model...")
        model = cls(*args, **kwargs)

        print("Loading state dict into model...")
        model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully.")

        print("Cleaning up...")
        del ckpt
        del new_state_dict

        print("Model loading complete.")
        return model


def forward_layers(pretrained, x):
    """Run the ViT backbone and return its four hooked intermediate activations.

    Only the first readout steps of each act_postprocess block are applied
    here, plus an unflatten back to spatial maps; the caller is expected to
    apply the remaining postprocessing.
    """
    b, c, h, w = x.shape

    # Forward pass; the hooks registered on the backbone populate
    # pretrained.activations with the intermediate layers.
    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Readout projection (first two modules of each postprocess block).
    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    # Unflatten the token dimension back into an (H/p, W/p) spatial grid.
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    return layer_1, layer_2, layer_3, layer_4
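

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API). The
# checkpoint path, input size, and label set below are placeholder values,
# and the half-precision matmul inside decode_feature assumes a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda")

    # Hypothetical checkpoint path; point this at a real LSeg checkpoint.
    extractor = LSegFeatureExtractor.from_pretrained("checkpoints/lseg.ckpt")
    extractor = extractor.to(device).eval()

    # A random tensor stands in for a normalized RGB image batch.
    image = torch.randn(1, 3, 384, 384, device=device)

    features = extractor.extract_features(image)    # (B, C, H', W') CLIP-aligned features
    logits = extractor.decode_feature(features, labelset=["chair", "table", "other"])
    labels = logits.argmax(dim=1)                   # per-pixel label indices
    print(features.shape, logits.shape, labels.shape)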