# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torchvision
from einops import rearrange

from models.sigclip import SigLIPViTBackbone
from models.dino import DinoViTBackbone


class ShallowDeepSiglipDinoEncoder(nn.Module):
    """Encodes images with SigLIP and DINO ViT backbones and concatenates their per-layer features."""

    def __init__(self, siglip_config={}, dino_config={}):
        super().__init__()
        self.to_pil = torchvision.transforms.ToPILImage()
        self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
        self.image_encoder_dino = DinoViTBackbone(**dino_config)

    def forward(self, image_tensor, device="cpu"):
        bs = image_tensor.size(0)
        # Convert tensors to PIL images.
        pixel_values = []
        for image_tensor_i in image_tensor:
            pixel_values.append(self.to_pil(image_tensor_i))
        embeddings = []
        embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
        embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
        for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
            embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1)  # channel concat
            embeddings.append(embeddings_i)
        return embeddings
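
# Example usage (a minimal sketch; the backbone config contents and the exact
# shapes returned by SigLIPViTBackbone / DinoViTBackbone are assumptions, not
# taken from this file):
#
#     encoder = ShallowDeepSiglipDinoEncoder(siglip_config={}, dino_config={})
#     images = torch.rand(2, 3, 384, 384)          # batch of RGB images in [0, 1]
#     features = encoder(images, device="cpu")     # one tensor per backbone layer
#     # each entry is the channel-wise concat of SigLIP and DINO features for that layer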


# By default the input image is double the backbone image size, i.e. 768x768 for a 384 backbone.
class ShallowDeepPatchfySiglipDinoEncoder(nn.Module):
    """SigLIP + DINO encoder that adds deep features from patchified views of a high-resolution input."""

    def __init__(self, siglip_config={}, dino_config={}, patchfy_scale=2, default_image_size=384):
        super().__init__()
        self.to_pil = torchvision.transforms.ToPILImage()
        self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
        self.image_encoder_dino = DinoViTBackbone(**dino_config)
        self.patchfy = (patchfy_scale > 1)
        self.patchfy_scale = patchfy_scale
        self.default_image_size = default_image_size

    def forward(self, image_tensor, device="cpu", **kwargs):  # input image size = 768
        image_tensor = image_tensor["image_ref"]  # the input is a dict holding the reference image
        bs = image_tensor.size(0)
        if self.patchfy:
            # Split the high-resolution image into patchfy_scale**2 sub-views at the base resolution
            # and build one global view resized to the backbone's default resolution.
            image_local = rearrange(image_tensor, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=self.patchfy_scale, wp=self.patchfy_scale)
            image_global = torch.nn.functional.interpolate(image_tensor, size=(self.default_image_size, self.default_image_size), mode='bilinear', align_corners=True)
            # Convert tensors to PIL images.
            pixel_values_local, pixel_values_global = [], []
            for image_tensor_i in image_local:
                pixel_values_local.append(self.to_pil(image_tensor_i.to(torch.float)))
            for image_tensor_i in image_global:
                pixel_values_global.append(self.to_pil(image_tensor_i.to(torch.float)))
            # Per-layer SigLIP/DINO features of the global view, concatenated along channels.
            embeddings = []
            embeddings_siglip_list = self.image_encoder_siglip(pixel_values_global, device)
            embeddings_dino_list = self.image_encoder_dino(pixel_values_global, device)
            for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
                embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1)  # channel concat
                embeddings.append(embeddings_i)
            # Deepest-layer features of the local views, merged back into the token dimension per sample.
            embeddings_local_siglip_deep = self.image_encoder_siglip(pixel_values_local, device)[-1]
            embeddings_local_dino_deep = self.image_encoder_dino(pixel_values_local, device)[-1]
            embeddings_local_deep = torch.cat([embeddings_local_siglip_deep, embeddings_local_dino_deep], dim=-1)
            embeddings_local_deep = rearrange(embeddings_local_deep, "(b hp wp) l c -> b (l hp wp) c", hp=self.patchfy_scale, wp=self.patchfy_scale)
            embeddings.append(embeddings_local_deep)
        else:
            # Convert tensors to PIL images.
            pixel_values = []
            for image_tensor_i in image_tensor:
                pixel_values.append(self.to_pil(image_tensor_i))
            embeddings = []
            embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
            embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
            for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
                # Concatenate SigLIP and DINO features layer by layer.
                embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1)  # channel concat
                embeddings.append(embeddings_i)
        if len(embeddings) == 1:
            embeddings = embeddings[0]
        return embeddings


class ShallowDeepPatchfySiglipDinoEncoder_v2(nn.Module):
    """SigLIP + DINO patchfy encoder (v2); expects a dict containing an "image_ref" tensor."""

    def __init__(self, siglip_config={}, dino_config={}, patchfy_scale=2, default_image_size=384):
        super().__init__()
        self.to_pil = torchvision.transforms.ToPILImage()
        self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
        self.image_encoder_dino = DinoViTBackbone(**dino_config)
        self.patchfy = (patchfy_scale > 1)
        self.patchfy_scale = patchfy_scale
        self.default_image_size = default_image_size

    def forward(self, image_tensor_dict, device="cpu", **kwargs):  # input image size = 768
        image_tensor = image_tensor_dict["image_ref"]
        bs = image_tensor.size(0)
        if self.patchfy:
            image_local = rearrange(image_tensor, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=self.patchfy_scale, wp=self.patchfy_scale)
            image_global = torch.nn.functional.interpolate(image_tensor, size=(self.default_image_size, self.default_image_size), mode='bilinear', align_corners=True)
            pixel_values_local, pixel_values_global = [], []
            for image_tensor_i in image_local:
                pixel_values_local.append(self.to_pil(image_tensor_i.to(torch.float32)))
            for image_tensor_i in image_global:
                pixel_values_global.append(self.to_pil(image_tensor_i.to(torch.float32)))
            embeddings = []
            embeddings_siglip_list = self.image_encoder_siglip(pixel_values_global, device)
            embeddings_dino_list = self.image_encoder_dino(pixel_values_global, device)
            for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
                embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1)  # channel concat
                embeddings.append(embeddings_i)
            embeddings_local_siglip_deep = self.image_encoder_siglip(pixel_values_local, device)[-1]
            embeddings_local_dino_deep = self.image_encoder_dino(pixel_values_local, device)[-1]
            embeddings_local_deep = torch.cat([embeddings_local_siglip_deep, embeddings_local_dino_deep], dim=-1)
            embeddings_local_deep = rearrange(embeddings_local_deep, "(b hp wp) l c -> b (l hp wp) c", hp=self.patchfy_scale, wp=self.patchfy_scale)
            embeddings.append(embeddings_local_deep)
        else:
            # Convert tensors to PIL images.
            pixel_values = []
            for image_tensor_i in image_tensor:
                pixel_values.append(self.to_pil(image_tensor_i))
            embeddings = []
            embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
            embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
            for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
                embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1)  # channel concat
                embeddings.append(embeddings_i)
        if len(embeddings) == 1:
            embeddings = embeddings[0]
        return embeddings
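

# Example usage of the patchfy encoders (a minimal sketch; the "image_ref" key matches
# the lookup above, but the empty backbone configs and the layer count / token shapes
# returned by the backbones are assumptions, not taken from this file):
#
#     encoder = ShallowDeepPatchfySiglipDinoEncoder_v2(patchfy_scale=2, default_image_size=384)
#     batch = {"image_ref": torch.rand(2, 3, 768, 768)}   # high-resolution input, 2x the base size
#     features = encoder(batch, device="cpu")
#     # features is a list: per-layer global-view features followed by one entry of
#     # deepest-layer features from the patchfy_scale**2 local views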