Spaces:
Runtime error
Runtime error
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
base_vision.py | |
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility | |
functions, and initialization logic. | |
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision | |
Transformer model for feature extraction. | |
""" | |
from abc import ABC, abstractmethod | |
from dataclasses import dataclass | |
from functools import partial | |
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union | |
import timm | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms.functional as TVF | |
from PIL.Image import Image | |
from timm.models.vision_transformer import Block, VisionTransformer | |
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy | |
from torchvision.transforms import Compose, Resize | |
# === Utility Functions for Monkey-Patching === | |
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so that a tuple/list result is collapsed to its first element.

    Any non-tuple, non-list return value is passed through untouched. In this module it is used to
    make a TIMM ViT's `get_intermediate_layers` (presumably returning a sequence of feature tensors)
    behave like a single-output `forward()`.
    """

    def _first_if_sequence(*args: Any, **kwargs: Any) -> Any:
        out = fn(*args, **kwargs)
        if not isinstance(out, (tuple, list)):
            return out
        return out[0]

    return _first_if_sequence
def return_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` in a transparent pass-through.

    Despite its name, the wrapper does not convert anything into a tuple: whatever `fn` returns is
    handed back unchanged.
    NOTE(review): confirm whether a tuple conversion was intended; nothing in this chunk calls it.
    """

    def _passthrough(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    return _passthrough
# === Interface for an Image Transform === | |
class ImageTransform(Protocol):
    """Structural type for any callable mapping a PIL image to model-ready pixel values.

    Implementations may return either a single tensor or a dict of named tensors.
    """

    def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Transform `img`; extra keyword arguments are implementation-specific."""
# === Custom Torchvision Image Transforms === | |
@dataclass
class LetterboxPad:
    """Pad a PIL image to a square canvas with a constant fill colour.

    `@dataclass` provides the ``__init__(padding_fill_value)`` that callers rely on
    (e.g. ``LetterboxPad(fill)`` in the "letterbox" resize strategy).

    Attributes:
        padding_fill_value: RGB fill colour used for the added border.
    """

    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a (near-)symmetric border around the height/width.

        When the size difference along an axis is odd, the extra pixel goes to the right/bottom edge,
        so the result is always exactly ``max(w, h)`` pixels on both sides.
        """
        (w, h), max_wh = image.size, max(image.size)
        pad_w, pad_h = max_wh - w, max_wh - h
        left, top = pad_w // 2, pad_h // 2
        # A 4-tuple passed to TVF.pad is interpreted as (left, top, right, bottom).
        padding = (left, top, pad_w - left, pad_h - top)
        return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")
# === Abstract Base Class for arbitrary Vision Backbones === | |
class VisionBackbone(nn.Module, ABC):
    """Abstract interface for a vision backbone (visual featurizer).

    Concrete subclasses are expected to populate ``self.featurizer`` and ``self.image_transform``.
    They must also implement every abstract member below: the FSDP policy, ``forward``, and the
    read-only properties ``default_image_resolution``, ``embed_dim``, ``num_patches`` and
    ``half_precision_dtype``.
    """

    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        """
        :param vision_backbone_id: identifier for this backbone (stored as ``self.identifier``).
        :param image_resize_strategy: name of the resize strategy; interpreted by subclasses.
        :param default_image_size: default input image size; interpreted by subclasses.
        """
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Instance attributes for a Vision Backbone -- left as None here, filled in by subclasses
        self.featurizer: Optional[nn.Module] = None
        self.image_transform: Optional[ImageTransform] = None

    def get_image_transform(self) -> ImageTransform:
        """Return the image transform built by the subclass (None until a subclass sets it)."""
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return the FSDP auto-wrap policy callable for this backbone."""
        ...

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """Expected input resolution as a 3-tuple (subclass-defined layout)."""
        ...

    @property
    @abstractmethod
    def embed_dim(self) -> int:
        """Dimensionality of the features produced by the featurizer."""
        ...

    @property
    @abstractmethod
    def num_patches(self) -> int:
        """Number of patch/grid features produced per image."""
        ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype:
        """Reduced-precision dtype this backbone should run under."""
        ...
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones === | |
class TimmViTBackbone(VisionBackbone, ABC):
    """Generic backbone wrapping any TIMM ``VisionTransformer`` for patch-feature extraction.

    The featurizer's ``forward`` is replaced so that it returns the patch features of the
    second-to-last transformer block. ``self.image_transform`` is built from the TIMM data config
    according to ``image_resize_strategy``: "resize-naive", "resize-crop" or "letterbox".
    """

    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        """
        :param vision_backbone_id: identifier for this backbone.
        :param timm_path_or_url: model name / hub path passed to ``timm.create_model``.
        :param image_resize_strategy: one of "resize-naive", "resize-crop", "letterbox".
        :param default_image_size: square resolution used for both the model (``img_size``) and the transform.
        :param override_act_layer: if given, forwarded to ``timm.create_model`` as ``act_layer``.
        :raises ValueError: if ``image_resize_strategy`` is not one of the supported strategies.
        """
        super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary.
        # `act_layer` is only passed when explicitly overridden, so TIMM's per-model default is otherwise kept.
        create_kwargs: Dict[str, Any] = dict(pretrained=True, num_classes=0, img_size=self.default_image_size)
        if self.override_act_layer is not None:
            create_kwargs["act_layer"] = self.override_act_layer
        self.featurizer: VisionTransformer = timm.create_model(self.timm_path_or_url, **create_kwargs)
        self.featurizer.eval()

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        # Checked *before* the monkey-patch below touches `.blocks` / `.get_intermediate_layers`, so an
        # unsupported model fails with this message instead of an AttributeError.
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `cobra/models/backbones/vision/base_vision.py`)!"
        )

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        #   => Note: `n` is the *set* {index of the SECOND-TO-LAST block}, i.e. return that layer's patches!
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
            default_image_transform = self._replace_leading_resize(default_image_transform, self.default_image_size)

        # Switch on `image_resize_strategy`
        if self.image_resize_strategy == "resize-naive":
            # Resize straight to (S, S), ignoring aspect ratio
            target_size = (self.default_image_size, self.default_image_size)
            self.image_transform = self._replace_leading_resize(default_image_transform, target_size)

        elif self.image_resize_strategy == "resize-crop":
            self.image_transform = default_image_transform

        elif self.image_resize_strategy == "letterbox":
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (normalization mean rescaled to the 0-255 pixel range)
            fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])

            # Build New Transform: pad to square first, then apply the default pipeline
            self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])

        else:
            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")

    @staticmethod
    def _replace_leading_resize(transform: Compose, size: Union[int, Tuple[int, int]]) -> Compose:
        """Return a new Compose whose first step (must be a Resize) becomes ``Resize(size)`` with the same interpolation."""
        assert isinstance(transform, Compose), "Unexpected `default_image_transform`!"
        assert isinstance(resize_transform := transform.transforms[0], Resize)
        return Compose([Resize(size, interpolation=resize_transform.interpolation), *transform.transforms[1:]])

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        return self.featurizer(pixel_values)

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """``(3, default_image_size, default_image_size)`` as written into ``self.data_cfg``."""
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        """Embedding dimension reported by the TIMM featurizer."""
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        """Number of patches reported by the featurizer's patch embedding."""
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        """Half-precision dtype for this backbone (``torch.bfloat16``, set in ``__init__``)."""
        return self.dtype