# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
dinosiglip_vit.py | |
Vision backbone that returns concatenated features from both DINOv2 and SigLIP. | |
""" | |
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Tuple, Union

import timm
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize

from models.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple, return_tuple


@dataclass
class DinoSigLIPImageTransform:
    dino_image_transform: ImageTransform
    siglip_image_transform: ImageTransform
    is_cobra: bool = True

    def __call__(self, img: Image.Image, **kwargs: str) -> Dict[str, torch.Tensor]:
        return {
            "dino": self.dino_image_transform(img, **kwargs).unsqueeze(0),
            "siglip": self.siglip_image_transform(img, **kwargs).unsqueeze(0),
        }


class SigLIPViTBackbone(VisionBackbone):
    def __init__(
        self,
        backbone_name_or_path: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        last_n: int = 2,
        feature_index: Union[int, Tuple[int, ...], List[int]] = 25,
    ) -> None:
        super().__init__(backbone_name_or_path, image_resize_strategy, default_image_size=default_image_size)

        # Load the pretrained weights from a local checkpoint rather than the HF / TIMM Hub
        siglip_pretrained_cfg = timm.models.create_model(backbone_name_or_path).default_cfg
        siglip_pretrained_cfg["file"] = "ckpts/vit_so400m_patch14_siglip_384/open_clip_pytorch_model.bin"

        # Initialize the SigLIP featurizer (ViT), overriding the default image size
        self.siglip_featurizer: VisionTransformer = timm.create_model(
            backbone_name_or_path,
            pretrained=True,
            num_classes=0,
            img_size=self.default_image_size,
            pretrained_cfg=siglip_pretrained_cfg,
        )
        self.siglip_featurizer.eval()

        # Monkey-patch the featurizer's `forward()` to ensure FSDP compatibility
        # =>> Note: by default, `get_intermediate_layers` would return the *SECOND-TO-LAST* layer patches!
        # =>> TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        print(f"SigLIP has {len(self.siglip_featurizer.blocks)} intermediate feature layers.")  # 27

        # self.siglip_featurizer.forward = unpack_tuple(
        #     partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - last_n})
        # )
        if isinstance(feature_index, (tuple, list)):
            feature_index = set(feature_index)
        else:
            feature_index = {feature_index}
        self.siglip_featurizer.forward = return_tuple(
            partial(self.siglip_featurizer.get_intermediate_layers, n=feature_index)
        )
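
        # After the monkey-patch, `self.siglip_featurizer(x)` returns a tuple holding one
        # [B, num_patches, embed_dim] patch-token tensor per index in `feature_index`
        # (assuming `return_tuple` simply forwards the tuple from `get_intermediate_layers`).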
        # Get the data config for the featurizer =>> Note :: override the default image size for larger-resolution models
        self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
        self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)

        # Initialize the default transform
        default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)

        # Fix =>> SigLIP's default transform resizes to *larger* than `self.default_image_size` (then crops)!
        assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!"
        assert isinstance(sl_resize_transform := default_siglip_transform.transforms[0], Resize)
        default_siglip_transform = Compose(
            [
                Resize(self.default_image_size, interpolation=sl_resize_transform.interpolation),
                *default_siglip_transform.transforms[1:],
            ]
        )
        if self.image_resize_strategy == "resize-naive":
            assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!"
            assert isinstance(siglip_resize_transform := default_siglip_transform.transforms[0], Resize)

            # Resize to an exact (H, W) square instead of resizing the short side and cropping
            target_size = (self.default_image_size, self.default_image_size)
            self.siglip_transform = Compose(
                [
                    Resize(target_size, interpolation=siglip_resize_transform.interpolation),
                    *default_siglip_transform.transforms[1:],
                ]
            )
        else:
            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the entire featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
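
    # Usage note (illustrative =>> actual FSDP setup lives outside this file): the returned
    # callable is intended for FSDP's `auto_wrap_policy` argument, e.g.
    #   FSDP(backbone, auto_wrap_policy=backbone.get_fsdp_wrapping_policy())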

    def forward(self, pixel_values, device="cpu") -> Tuple[torch.Tensor, ...]:
        """Run the transformed images through the SigLIP featurizer, returning a tuple of patch features."""
        # `pixel_values` is a list of PIL images; transform each one and stack into a [B, C, H, W] batch
        t_tensors = []
        for pixel_value in pixel_values:
            t_tensors.append(self.siglip_transform(pixel_value).unsqueeze(0))
        t_tensors = torch.cat(t_tensors, dim=0).to(device)

        # The monkey-patched forward returns one [B, num_patches, embed_dim] tensor per requested layer
        t_tensors_list = self.siglip_featurizer(t_tensors)
        return t_tensors_list

    def default_image_resolution(self) -> Tuple[int, int, int]:
        return self.siglip_data_cfg["input_size"]

    def embed_dim(self) -> int:
        return self.siglip_featurizer.embed_dim

    def num_patches(self) -> int:
        return self.siglip_featurizer.patch_embed.num_patches

    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16


class SigLIPEncoder(nn.Module):
    def __init__(
        self,
        backbone_name_or_path: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        feature_index: Union[int, Tuple[int, ...], List[int]] = 25,
    ):
        super().__init__()
        # Pass `feature_index` by keyword =>> positionally it would bind to `last_n`!
        self.image_encoder = SigLIPViTBackbone(
            backbone_name_or_path, image_resize_strategy, default_image_size, feature_index=feature_index
        )
        self.to_pil = torchvision.transforms.ToPILImage()

    def forward(self, image_tensor, device="cpu"):  # input: [B, C, H, W] tensor in [0, 1] (e.g., 768 x 768)
        # Convert each image in the batch back to PIL so the SigLIP transform can be applied
        pixel_values = []
        for image_tensor_i in image_tensor:
            pixel_values.append(self.to_pil(image_tensor_i))
        embeddings_list = self.image_encoder(pixel_values, device)

        # If only a single layer was requested, unwrap the tuple and return the tensor directly
        if len(embeddings_list) == 1:
            embeddings_list = embeddings_list[0]
        return embeddings_list
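

# Minimal smoke-test sketch (not part of the original module). It assumes the local SigLIP
# checkpoint referenced in `SigLIPViTBackbone.__init__` exists and that `models.base_vision`
# is importable; the backbone name below matches that checkpoint path.
if __name__ == "__main__":
    encoder = SigLIPEncoder(
        backbone_name_or_path="vit_so400m_patch14_siglip_384",
        image_resize_strategy="resize-naive",
        default_image_size=384,
        feature_index=25,
    )
    with torch.no_grad():
        images = torch.rand(2, 3, 768, 768)  # batch of random RGB images in [0, 1]
        features = encoder(images)           # [2, num_patches, embed_dim] for a single feature index
    print(features.shape)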