# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
dinosiglip_vit.py

Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
"""
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Tuple

import timm
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize

from models.base_vision import ImageTransform, VisionBackbone, return_tuple, unpack_tuple


@dataclass
class DinoSigLIPImageTransform:
    dino_image_transform: ImageTransform
    siglip_image_transform: ImageTransform
    is_cobra: bool = True

    def __call__(self, img: Image.Image, **kwargs: str) -> Dict[str, torch.Tensor]:
        return {
            "dino": self.dino_image_transform(img, **kwargs).unsqueeze(0),
            "siglip": self.siglip_image_transform(img, **kwargs).unsqueeze(0),
        }


class SigLIPViTBackbone(VisionBackbone):
    def __init__(
        self,
        backbone_name_or_path: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        last_n: int = 2,
        feature_index=25,
    ) -> None:
        super().__init__(backbone_name_or_path, image_resize_strategy, default_image_size=default_image_size)
        # Resolve the pretrained config, pointing it at a local checkpoint instead of the HF / TIMM Hub
        siglip_pretrained_cfg = timm.models.create_model(backbone_name_or_path).default_cfg
        siglip_pretrained_cfg["file"] = "ckpts/vit_so400m_patch14_siglip_384/open_clip_pytorch_model.bin"

        # Initialize the SigLIP featurizer (ViT) from the local checkpoint
        self.siglip_featurizer: VisionTransformer = timm.create_model(
            backbone_name_or_path, pretrained=True, num_classes=0, img_size=self.default_image_size,
            pretrained_cfg=siglip_pretrained_cfg,
        )
        self.siglip_featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        #   => Note: By default, `feature_index = 25` returns the *SECOND-TO-LAST* block's patch tokens
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        # `get_intermediate_layers` returns the output tokens of the requested blocks
        print(f"siglip has {len(self.siglip_featurizer.blocks)} intermediate feature layers.")  # 27
        # self.siglip_featurizer.forward = unpack_tuple(
        #     partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - last_n})
        # )
        if isinstance(feature_index, (tuple, list)):
            feature_index = set(feature_index)
        else:
            feature_index = {feature_index}
        self.siglip_featurizer.forward = return_tuple(
            partial(self.siglip_featurizer.get_intermediate_layers, n=feature_index)
        )

        # Get the data config for the featurizer =>> Note :: Override default image size for larger-resolution models
        self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
        self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)

        # Initialize the default SigLIP transform
        default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)

        # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
        assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!"
        assert isinstance(sl_resize_transform := default_siglip_transform.transforms[0], Resize)
        default_siglip_transform = Compose(
            [
                Resize(self.default_image_size, interpolation=sl_resize_transform.interpolation),
                *default_siglip_transform.transforms[1:],
            ]
        )

        if self.image_resize_strategy == "resize-naive":
            assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!"
            assert isinstance(siglip_resize_transform := default_siglip_transform.transforms[0], Resize)

            target_size = (self.default_image_size, self.default_image_size)
            siglip_transform = Compose(
                [
                    Resize(target_size, interpolation=siglip_resize_transform.interpolation),
                    *default_siglip_transform.transforms[1:],
                ]
            )

            self.siglip_transform = siglip_transform
        else:
            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, pixel_values, device: str = "cpu") -> Tuple[torch.Tensor, ...]:
        """Runs the transformed images through the SigLIP featurizer, returning patch features from the requested blocks."""
        # pixel_values: iterable of PIL images; the transform yields (C, H, W) tensors with values in [0, 1]
        t_tensors = []
        for pixel_value in pixel_values:
            t_tensors.append(self.siglip_transform(pixel_value).unsqueeze(0))
        t_tensors = torch.cat(t_tensors, dim=0).to(device)

        return self.siglip_featurizer(t_tensors)
        
    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        return self.siglip_data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        return self.siglip_featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        return self.siglip_featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16


class SigLIPEncoder(nn.Module):
    def __init__(self, backbone_name_or_path: str, image_resize_strategy: str, default_image_size: int = 224, feature_index=25) -> None:
        super().__init__()
        # Pass `feature_index` by keyword so it is not silently consumed by the `last_n` parameter
        self.image_encoder = SigLIPViTBackbone(
            backbone_name_or_path, image_resize_strategy, default_image_size, feature_index=feature_index
        )
        self.to_pil = torchvision.transforms.ToPILImage()

    def forward(self, image_tensor, device="cpu"):  # input image size = 768
        # Convert each (C, H, W) tensor back to PIL so the SigLIP transform can resize / normalize it
        pixel_values = [self.to_pil(image_tensor_i) for image_tensor_i in image_tensor]

        embeddings_list = self.image_encoder(pixel_values, device)
        # If only a single intermediate layer was requested, unwrap the 1-tuple
        if len(embeddings_list) == 1:
            embeddings_list = embeddings_list[0]
        return embeddings_list
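

# Minimal usage sketch (not part of the original module); assumes the timm model name
# "vit_so400m_patch14_siglip_384" and the local checkpoint path referenced above are
# available, and that images arrive as (B, 3, H, W) float tensors in [0, 1].
if __name__ == "__main__":
    encoder = SigLIPEncoder(
        "vit_so400m_patch14_siglip_384",  # assumed timm model name
        "resize-naive",
        default_image_size=384,
        feature_index=25,
    )
    images = torch.rand(2, 3, 768, 768)  # dummy batch of RGB images
    with torch.no_grad():
        features = encoder(images, device="cpu")
    print(features.shape)  # expected: (batch, num_patches, embed_dim) for a single feature_index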