# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
base_vision.py

Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.

We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union

import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize


# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so that a tuple/list result is reduced to its first element.

    Any other kind of result is passed through unchanged.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, (tuple, list)) else result

    return wrapper


def return_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` in a pass-through callable that returns its result unmodified."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result

    return wrapper


# === Interface for an Image Transform ===
class ImageTransform(Protocol):
    """Structural type for any callable that maps a PIL image to a tensor (or a dict of named tensors)."""

    def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...


# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
    """Transform that pads a PIL image to a square canvas using a constant fill color."""

    # Per-channel fill value used for the padded border.
    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a border around the height/width.

        The border is split evenly between opposite sides. When the width/height difference is odd, the one
        leftover pixel is added to the right/bottom edge so that the output is always exactly square.
        """
        width, height = image.size
        side = max(width, height)
        extra_w, extra_h = side - width, side - height

        # Padding order is (left, top, right, bottom); right/bottom absorb any odd remainder.
        left, top = extra_w // 2, extra_h // 2
        padding = (left, top, extra_w - left, extra_h - top)
        return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")


# === Abstract Base Class for arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
    """Abstract interface for a vision backbone: a featurizer module paired with its image transform.

    Subclasses are expected to populate `featurizer` and `image_transform`, and to implement the abstract
    methods and properties declared below.
    """

    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        """Record the backbone's identifier, resize strategy, and default (square) image size."""
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Instance attributes for a Vision Backbone; left unset here and assigned by subclasses.
        self.featurizer: Optional[nn.Module] = None
        self.image_transform: Optional[ImageTransform] = None

    def get_image_transform(self) -> ImageTransform:
        """Return the image transform currently stored on this backbone."""
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable: ...

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]: ...

    @property
    @abstractmethod
    def embed_dim(self) -> int: ...

    @property
    @abstractmethod
    def num_patches(self) -> int: ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype: ...
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
    """Vision backbone backed by a TIMM Vision Transformer.

    Loads the model through `timm.create_model`, patches its `forward()` to return intermediate-layer patch
    features, and builds an image transform according to `image_resize_strategy`.
    """

    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        """Create the featurizer and its image transform.

        Args:
            vision_backbone_id: Identifier stored on the backbone.
            timm_path_or_url: Model name/path handed to `timm.create_model`.
            image_resize_strategy: One of "resize-naive", "resize-crop", or "letterbox".
            default_image_size: Side length of the (square) input image.
            override_act_layer: If given, forwarded to `timm.create_model` as `act_layer`.

        Raises:
            AssertionError: If the loaded model is not a TIMM `VisionTransformer`, or the default TIMM
                transform does not have the expected structure.
            ValueError: If `image_resize_strategy` is not supported.
        """
        super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
        # => `act_layer` is only passed when an override was requested.
        create_kwargs: Dict[str, Any] = {
            "pretrained": True,
            "num_classes": 0,
            "img_size": self.default_image_size,
        }
        if self.override_act_layer is not None:
            create_kwargs["act_layer"] = self.override_act_layer
        self.featurizer: VisionTransformer = timm.create_model(self.timm_path_or_url, **create_kwargs)

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        # => Checked *before* touching ViT-specific attributes (`blocks`, `get_intermediate_layers`) so that an
        #    unsupported model fails with this message rather than an unrelated AttributeError.
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `cobra/models/backbones/vision/base_vision.py`)!"
        )
        self.featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        #   => `n` is passed as a one-element collection holding a block index (presumably so that TIMM selects
        #      that specific block rather than a count of trailing blocks -- confirm against the TIMM docs).
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
            default_image_transform = self._replace_leading_resize(default_image_transform, self.default_image_size)

        # Switch on `image_resize_strategy`
        if self.image_resize_strategy == "resize-naive":
            # Resize directly to a square (height, width) target.
            target_size = (self.default_image_size, self.default_image_size)
            self.image_transform = self._replace_leading_resize(default_image_transform, target_size)

        elif self.image_resize_strategy == "resize-crop":
            self.image_transform = default_image_transform

        elif self.image_resize_strategy == "letterbox":
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (rescaled normalization mean if applicable)
            fill = tuple(int(x * 255) for x in self.data_cfg["mean"])

            # Build New Transform
            self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])

        else:
            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")

    @staticmethod
    def _replace_leading_resize(transform: Any, size: Union[int, Tuple[int, int]]) -> Compose:
        """Return a copy of `transform` whose first step is a `Resize` to `size`, keeping its interpolation."""
        assert isinstance(transform, Compose), "Unexpected `default_image_transform`!"
        assert isinstance(resize_transform := transform.transforms[0], Resize)
        return Compose([Resize(size, interpolation=resize_transform.interpolation), *transform.transforms[1:]])

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        return self.featurizer(pixel_values)

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """The `input_size` entry of the TIMM data config, set in `__init__` to (3, size, size)."""
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        """Embedding dimension reported by the featurizer."""
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        """Number of patches reported by the featurizer's patch embedding."""
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        """Half-precision dtype for this backbone (`torch.bfloat16`, as set in `__init__`)."""
        return self.dtype