# Spaces: Running on Zero
# ------------------------------------------------------------------------
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA) and S^2 (https://github.com/bfshi/scaling_on_scales)
# Copyright 2024 Jiachen Li
# ------------------------------------------------------------------------
import torch | |
import torch.nn as nn | |
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig | |
import torch.nn.functional as F | |
from transformers.activations import ACT2FN | |
import math | |
from einops import rearrange | |
from .clip import CLIPVisionTransformer | |
from .clip_smoe import CLIPSMoEVisionTransformer | |
class CLIPVisionTower(nn.Module):
    """CLIP vision encoder wrapper producing (optionally multi-scale) token features.

    The tower builds either a ``CLIPSMoEVisionTransformer`` (when
    ``args.clip_smoe`` is truthy) or a plain ``CLIPVisionTransformer`` from the
    config of ``vision_tower``.  For a batched tensor input, ``forward`` runs
    the encoder on the input image and on up-scaled copies (one per entry of
    ``args.scales`` after the first), which are cut into input-sized tiles,
    encoded, stitched back together, pooled to the base token grid and
    concatenated with the base features along the channel dimension.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        """Build the image processor and the vision transformer.

        Args:
            vision_tower: model name or path handed to
                ``CLIPVisionConfig.from_pretrained`` and
                ``CLIPImageProcessor.from_pretrained``.
            args: namespace providing ``mm_vision_select_layer``, ``clip_smoe``,
                ``scales`` and, when ``clip_smoe`` is set, ``num_experts`` and
                ``num_selected``; ``mm_vision_select_feature`` is optional and
                defaults to ``'patch'``.
            delay_load: accepted for interface compatibility; not used here,
                the model is always constructed immediately.
        """
        super().__init__()

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.clip_smoe = args.clip_smoe
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.scales = args.scales

        if args.clip_smoe:
            self.vision_model = CLIPSMoEVisionTransformer(
                self.cfg_only,
                num_experts=args.num_experts,
                num_selected=args.num_selected,
            )
        else:
            self.vision_model = CLIPVisionTransformer(self.cfg_only)

        self.is_loaded = True

    def feature_select(self, image_features):
        """Select which tokens of the encoder output are kept.

        ``'patch'`` drops the first token along dim 1 (presumably the CLS
        token -- confirm against the encoder); ``'cls_patch'`` keeps every
        token.

        Raises:
            ValueError: if ``self.select_feature`` is neither of the above.
        """
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    def split_chessboard(self, x, num_split):
        """Cut ``x`` into ``num_split**2`` tiles stacked on the batch dimension.

        x: b * c * h * w
        Tiles are taken row by row (row index outer, column index inner), so
        the result has shape (b * num_split**2) * c * (h / num_split) * (w / num_split).
        """
        B, C, H, W = x.shape
        assert H % num_split == 0 and W % num_split == 0
        h, w = H // num_split, W // num_split
        tiles = [
            x[:, :, row * h:(row + 1) * h, col * w:(col + 1) * w]
            for row in range(num_split)
            for col in range(num_split)
        ]
        return torch.cat(tiles, dim=0)

    def merge_chessboard(self, x, num_split):
        """Stitch tiles produced by ``split_chessboard`` back into one map.

        x: b * c * h * w
        Assumes x contains num_split**2 tiles concatenated along the batch
        dimension in row-major tile order (inverse of ``split_chessboard``).
        """
        B, C, H, W = x.shape
        assert B % (num_split ** 2) == 0
        b = B // (num_split ** 2)
        rows = []
        for row in range(num_split):
            # Tiles of one chessboard row sit next to each other on the batch axis.
            row_tiles = [
                x[(row * num_split + col) * b:(row * num_split + col + 1) * b]
                for col in range(num_split)
            ]
            rows.append(torch.cat(row_tiles, dim=-1))
        return torch.cat(rows, dim=-2)

    def _build_pyramid(self, images):
        """Return ``(image_pyramids, num_splits)`` for a batched image tensor.

        The first pyramid level is ``images`` itself; every further entry of
        ``self.scales`` yields a bicubically resized copy split into tiles.
        """
        input_size = images.shape[3]
        img_sizes = [int(input_size * scale) for scale in self.scales]
        num_splits = [math.ceil(size / input_size) for size in img_sizes]

        image_pyramids = [images]
        # The first scale is represented by the untouched input, so skip it.
        for size, num_split in zip(img_sizes[1:], num_splits[1:]):
            # Interpolate in float32, then cast back to the input dtype.
            scaled = F.interpolate(images.to(torch.float32), size=size, mode='bicubic').to(images.dtype)
            image_pyramids.append(self.split_chessboard(scaled, num_split=num_split))
        return image_pyramids, num_splits

    def _merge_tiles(self, tile_features, num_split, base_tokens, out_dtype):
        """Stitch per-tile token features and pool them to the base token grid.

        The token axis is reshaped to a square grid, so the token count must be
        a perfect square (which holds for square patch grids without an extra
        token).
        """
        side = int(tile_features.shape[1] ** 0.5)
        grid = rearrange(tile_features, 'b (h w) c -> b c h w', h=side, w=side)
        grid = self.merge_chessboard(grid, num_split=num_split)
        grid = F.interpolate(grid.to(torch.float32), size=int(base_tokens ** 0.5), mode='area').to(out_dtype)
        return rearrange(grid, 'b c h w -> b (h w) c')

    def forward(self, images):
        """Encode images and return ``(features, balance_loss, router_z_loss)``.

        Args:
            images: either a list of single-image tensors (each is moved to the
                tower's device/dtype and given a batch dimension) or a batched
                tensor whose last dimension is used as the input size.

        Returns:
            A 3-tuple.  For tensor input the first element is the per-scale
            features concatenated on the last dimension; for list input it is a
            list of per-image features.  The two losses are the means of the
            per-scale values returned by the SMoE encoder, or ``None`` when
            ``clip_smoe`` is off or the input is a list.
        """
        if isinstance(images, list):
            # NOTE(review): this branch calls the encoder with
            # output_hidden_states=True and expects a tensor back; it does not
            # unpack the (features, balance_loss, router_z_loss) tuple that the
            # SMoE encoder returns in the tensor branch -- confirm whether list
            # input is meant to be supported with clip_smoe.
            image_features = []
            for image in images:
                image_forward_out = self.vision_model(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
            # Previously this branch fell off the end and returned None; return
            # the same 3-tuple shape as the tensor branch.
            return image_features, None, None

        image_pyramids, num_splits = self._build_pyramid(images)

        image_features = []
        balance_losses = []
        router_z_losses = []
        for i, (x, num_split) in enumerate(zip(image_pyramids, num_splits)):
            if self.clip_smoe:
                out_x, balance_loss, router_z_loss = self.vision_model(x)
                balance_losses.append(balance_loss)
                router_z_losses.append(router_z_loss)
            else:
                out_x = self.vision_model(x)

            out_x = self.feature_select(out_x)
            if i > 0:
                # Higher scales were encoded tile by tile; bring them back to
                # the token grid of the base scale before concatenation.
                out_x = self._merge_tiles(out_x, num_split, image_features[0].shape[1], x.dtype)
            image_features.append(out_x)

        image_features = torch.cat(image_features, dim=-1)
        if self.clip_smoe:
            return image_features, torch.stack(balance_losses).mean(), torch.stack(router_z_losses).mean()
        return image_features, None, None

    # The accessors below are read as plain attributes inside this class
    # (self.device, self.dtype, self.config.hidden_size, self.hidden_size), so
    # they must be properties.
    # NOTE(review): dummy_feature, num_patches_per_side and num_patches are not
    # read inside this class; they are made properties to match the upstream
    # LLaVA tower interface -- confirm against callers.

    @property
    def dummy_feature(self):
        """A zero tensor of shape (1, hidden_size) on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype reported by the wrapped vision model."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision model."""
        return self.vision_model.device

    @property
    def config(self):
        """Config of the loaded vision model, or the standalone config otherwise."""
        if self.is_loaded:
            return self.vision_model.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """``hidden_size`` of the active config."""
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        """Number of patches along one image side (image_size // patch_size)."""
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        """Total number of patches per image (patches per side, squared)."""
        return (self.config.image_size // self.config.patch_size) ** 2