# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, os.path as osp
import warnings
from abc import ABC, abstractmethod
import torch, logging
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForCausalLM,
AutoConfig,
BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel,
)
from .constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN,
IGNORE_INDEX,
IMAGE_TOKEN_INDEX,
MASK_TOKEN_INDEX,
)
from collections import OrderedDict
from .utils import get_model_config
from .language_model.builder import build_llm_and_tokenizer
from .multimodal_encoder.builder import build_vision_tower, build_context_provider
from .multimodal_projector.builder import build_mm_projector
from .configuration_llava import LlavaConfig
from transformers.modeling_utils import ContextManagers, no_init_weights
## TODO: decide whether we should use a metaclass
class LlavaMetaModel(ABC):
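    """Mixin that assembles the components of a LLaVA-style VLM: an LLM with its tokenizer,
    a vision tower, a multimodal projector, and an optional context provider. Concrete model
    classes are expected to build these components via `init_vlm()` or `load_pretrained()`.
    """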
    def init_vlm(self, config: PretrainedConfig = None, *args, **kwargs):
        # TODO(ligeng): figure out how from_config and from_pretrained work in the HF implementation.
if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
            # already initialized, skip re-initialization
return
model_dtype = getattr(config, "model_dtype", "torch.float16")
if not hasattr(config, "model_dtype"):
warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
config.model_dtype = model_dtype
# print("init_vlm(): config", config); input("DEBUG init_vlm")
cfgs = get_model_config(config)
# Only the first three are required. Others are optional.
llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = cfgs
if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError("`llm_cfg`, `vision_tower_cfg`, and `mm_projector_cfg` must all be present in the config.")
# print("init_vlm():", cfgs); input("DEBUG init_vlm")
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG init_vlm")
self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
self.vision_tower = build_vision_tower(vision_tower_cfg, config)
self.mm_projector = build_mm_projector(mm_projector_cfg, config)
self.context_provider = build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
self.post_config()
self.is_loaded = True
assert (
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
), "At least one of the components must be instantiated."
@classmethod
def load_from_config(cls, model_path_or_config, *args, **kwargs):
pass
    ## FIXME: we will use this function to load the model in the future
@classmethod
def load_pretrained(cls, model_path_or_config, *args, **kwargs):
kwargs.pop("config", None)
if isinstance(model_path_or_config, str):
config = AutoConfig.from_pretrained(model_path_or_config)
elif isinstance(model_path_or_config, LlavaConfig):
config = model_path_or_config
else:
            raise NotImplementedError(
                f"Unsupported model_path_or_config type: {type(model_path_or_config)}"
            )
model_dtype = getattr(config, "model_dtype", "torch.float16")
if not hasattr(config, "model_dtype"):
warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
config.model_dtype = model_dtype
cfgs = get_model_config(config)
# Only the first three are required. Others are optional.
llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = cfgs
if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError("`llm_cfg`, `vision_tower_cfg`, and `mm_projector_cfg` must all be present in the config.")
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained")
with ContextManagers([no_init_weights(_enable=True),]):
vlm = cls(config, *args, **kwargs)
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained finish")
if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "mm_projector"):
if vlm.is_loaded:
return vlm
vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
if mask_encoder_cfg is not None:
raise NotImplementedError("Mask encoder is not supported.")
vlm.context_provider = build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        # Note: this is a classmethod, so configure the newly built `vlm`, not `self`.
        vlm.post_config()
        vlm.is_loaded = True
# FIXME(ligeng, yunhao): llm should never be none here.
assert (
vlm.llm is not None or vlm.vision_tower is not None or vlm.mm_projector is not None
), "At least one of the components must be instantiated."
return vlm
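    # Usage sketch (hedged; `MyLlavaModel` is the same placeholder as above):
    #
    #     vlm = MyLlavaModel.load_pretrained("/path/to/checkpoint")
    #     vlm.eval()
    #
    # The classmethod resolves the per-component configs, skips weight init via
    # `no_init_weights`, builds each component, and returns the model with `is_loaded = True`.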
    ## FIXME: we will use this function to save the model in the future
def save_pretrained(self, output_dir, state_dict=None):
if state_dict is None:
            # otherwise fetch the state dict from deepspeed
# state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
state_dict = self.state_dict()
if getattr(self, "tokenizer", None):
self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
if self.get_llm():
print(f"saving llm to {osp.join(output_dir, 'llm')}")
self.llm.config._name_or_path = osp.join(output_dir, "llm")
llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
self.config.llm_cfg = self.llm.config
if self.get_vision_tower() and "radio" not in self.get_vision_tower().__class__.__name__.lower():
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
vision_tower_state_dict = OrderedDict(
{k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
)
self.vision_tower.vision_tower.save_pretrained(
os.path.join(output_dir, "vision_tower"),
state_dict=vision_tower_state_dict,
)
self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
self.config.vision_tower_cfg = self.vision_tower.config
if hasattr(self.config.vision_tower_cfg, 'auto_map'):
delattr(self.config.vision_tower_cfg, 'auto_map')
if self.get_mm_projector():
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
mm_projector_state_dict = OrderedDict(
{k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
)
self.mm_projector.save_pretrained(
os.path.join(output_dir, "mm_projector"),
state_dict=mm_projector_state_dict,
)
self.config.mm_projector_cfg = self.mm_projector.config
if self.get_context_provider():
print(f"saving context_provider to {osp.join(output_dir, 'context_provider')}")
self.context_provider.config._name_or_path = osp.join(output_dir, "context_provider")
context_provider_state_dict = OrderedDict(
{k.split("context_provider.")[-1]: v for k, v in state_dict.items() if "context_provider" in k}
)
self.context_provider.save_pretrained(
os.path.join(output_dir, "context_provider"),
state_dict=context_provider_state_dict,
)
self.config.context_provider_cfg = self.context_provider.config
## update and save top-level config
self.config._name_or_path = output_dir
self.config.architectures = [self.__class__.__name__]
self.config.save_pretrained(output_dir)
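    # Checkpoint layout produced by `save_pretrained` (derived from the code above):
    #   output_dir/config.json        top-level config referencing the sub-configs below
    #   output_dir/llm/               LLM weights and tokenizer
    #   output_dir/vision_tower/      vision encoder weights and image processor (skipped for RADIO-based towers)
    #   output_dir/mm_projector/      multimodal projector weights
    #   output_dir/context_provider/  written only when a context provider is present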
def get_llm(self):
llm = getattr(self, "llm", None)
if type(llm) is list:
llm = llm[0]
return llm
def get_lm_head(self):
lm_head = getattr(self.get_llm(), "lm_head", None)
return lm_head
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def get_mm_projector(self):
mm_projector = getattr(self, "mm_projector", None)
if type(mm_projector) is list:
mm_projector = mm_projector[0]
return mm_projector
def get_context_provider(self):
context_provider = getattr(self, "context_provider", None)
return context_provider
def post_config(self):
self.training = self.get_llm().training
## configuration
if getattr(self.config, "llm_cfg", None) is None:
self.config.llm_cfg = self.llm.config
if getattr(self.config, "vision_tower_cfg", None) is None:
self.config.vision_tower_cfg = self.vision_tower.config
if getattr(self.config, "mm_projector_cfg", None) is None:
self.config.mm_projector_cfg = self.mm_projector.config
if getattr(self.config, "context_provider_cfg", None) is None and self.context_provider is not None:
self.config.context_provider_cfg = self.context_provider.config
def freezed_module_patch(self):
        '''
        HuggingFace calls model.train() at every training step. To keep the expected behavior of
        modules such as dropout and batch norm, we call model.eval() on the frozen modules.
        '''
if self.training:
if self.get_llm() and not getattr(self.config, "tune_language_model", False):
                logging.warning(
                    "Caution: the LLM stays in training mode so that gradients are computed correctly; "
                    "be mindful of BatchNorm and Dropout behavior."
                )
if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
self.get_vision_tower().eval()
if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
self.get_mm_projector().eval()
if self.get_context_provider() and not getattr(self.config, "tune_context_provider", False):
self.get_context_provider().eval()
def encode_images(self, images):
image_features = self.get_vision_tower()(images)
image_features = self.get_mm_projector()(image_features)
return image_features
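    # Shape sketch for `encode_images` (illustrative, not exact for every backbone):
    #   images                    (num_images, 3, H, W)
    #   -> vision tower features  (num_images, num_patches, vision_hidden_size)
    #   -> mm projector output    (num_images, num_patches, llm_hidden_size), e.g. [1, 196, 2560]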
def encode_images_with_context(self, images):
context_provider = self.get_context_provider()
        # An item is a cimage (image with context) when its first 4 channels differ from its last 4 channels.
cimage_mask = torch.any((images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1)
if context_provider.treat_image_as_cimage:
# If the context provider treats the image as cimage, then all images are cimage.
cimage_mask[:] = True
if context_provider.context_image_as_queries:
# Swap the crop image and full image since the model uses the full image as queries by default
images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1)
        # Encode the first 4 channels of every item: for a plain image this is the image itself, for a cimage it is the full (context) image.
vision_tower = self.get_vision_tower()
# Encode context images (full images)
image_features = vision_tower(images[:, :4, ...]).to(self.device)
# Each cimage has 8 channels (full and crop concatenated)
cimage_concatenated = images[cimage_mask]
cimage_full_features = image_features[cimage_mask]
if context_provider.context_provider_type == "cross_attn_end_to_all":
cimage_features = self.context_provider(
cimage_full_features=cimage_full_features,
cimage_concatenated=cimage_concatenated,
vision_tower=vision_tower
).to(self.device)
elif context_provider.context_provider_type == "concat":
# Full features of cimages are computed but not used.
cimage_features = self.context_provider(
cimage_concatenated=cimage_concatenated,
vision_tower=vision_tower
).to(self.device)
else:
raise NotImplementedError(f"Context provider type {context_provider.context_provider_type} not implemented.")
# Put cimage_features into image_features
image_features[cimage_mask] = cimage_features
# Project to the llm space
image_features = self.get_mm_projector()(image_features)
return image_features
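    # Summary of `encode_images_with_context` (restating the code above):
    #   * Each item has 8 channels; by default the first 4 are the full (context) image and the
    #     last 4 are the crop, swapped when `context_image_as_queries` is set.
    #   * Items whose two halves differ are cimages; their vision-tower features are refined by
    #     the context provider ("cross_attn_end_to_all" or "concat").
    #   * Plain images keep the vanilla vision-tower features; everything is finally projected
    #     into the LLM embedding space by the mm projector.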
    ## @yunhao: is there a better way to handle function calls and attributes for the llm?
## support beam search
def _temporary_reorder_cache(self, past_key_values, sorted_idx):
return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)
def get_input_embeddings(self):
return self.get_llm().get_input_embeddings()
def get_output_embeddings(self):
return self.get_llm().get_output_embeddings()
def resize_token_embeddings(self, embed_size):
self.get_llm().resize_token_embeddings(embed_size)
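    ## The delegation methods above forward to the wrapped LLM so that generation utilities
    ## (beam-search cache reordering, embedding access/resizing) work on the composite model.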
class LlavaMetaForCausalLM(ABC):
"""This class is originally implemented by the LLaVA team and
modified by Haotian Tang and Jason Lu based on Ji Lin's implementation
to support multiple images and input packing."""
## TODO move the forward function here if there is no need to override it
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
if (
past_key_values is not None
and vision_tower is not None
and images is not None
and input_ids.shape[1] == 1
):
target_shape = past_key_values[-1][-1].shape[-2] + 1
attention_mask = torch.cat(
(
attention_mask,
torch.ones(
(
attention_mask.shape[0],
target_shape - attention_mask.shape[1],
),
dtype=attention_mask.dtype,
device=attention_mask.device,
),
),
dim=1,
)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
return (
input_ids,
position_ids,
attention_mask,
past_key_values,
None,
labels,
)
        # handle the different image input formats (list of tensors vs. 5-D tensor) used for packing
if type(images) is list:
images = torch.cat(images, dim=0)
        elif images.ndim == 5:  # batch_size x num_images x channels x height x width
images = images.flatten(0, 1)
if getattr(self, "context_provider", None):
image_features = self.encode_images_with_context(images)
else:
            # Indexing image_features along its first dimension below does not copy data or degrade performance.
            # Example feature shape: [1, 196, 2560]
            assert images.shape[1] <= 4, "images have more than 4 channels, but no context provider is configured"
image_features = self.encode_images(images).to(self.device)
# Note (kentang-mit@): image start / end is not implemented here to support pretraining.
if getattr(self.config, "turn_mm_projector", False) and getattr(self.config, "mm_use_im_start_end", False):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask
input_ids_copy = input_ids.clone()
# kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used.
input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0
input_embeds = self.llm.model.embed_tokens(input_ids_copy)
input_ids = [
cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
input_embeds_1 = [
cur_input_embeds[cur_attention_mask]
for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask)
]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
# print("BEFORE BATCH LOOP:", len(input_ids), input_ids[0].shape, input_ids[0].device, [(x == IMAGE_TOKEN_INDEX).sum() for x in input_ids])
        # kentang-mit@: If some part of the model is executed inside the loop, the loop length needs to be a constant.
for batch_idx, cur_input_ids in enumerate(input_ids):
cur_input_ids = input_ids[batch_idx]
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[0]
# cur_input_embeds_1 = self.get_llm().embed_tokens(cur_input_ids)
cur_input_embeds_1 = input_embeds_1[batch_idx]
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
                # kentang-mit@: we do not have a placeholder image for text-only data for now.
# cur_image_idx += 1
continue
cur_input_embeds = input_embeds_1[batch_idx]
image_token_indices = (
[-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
cur_input_embeds_no_im = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_input_embeds_no_im.append(cur_input_embeds[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
# cur_input_embeds = self.get_llm().embed_tokens(torch.cat(cur_input_ids_noim))
# cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.llm.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
warnings.warn("Inputs truncated!")
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
IGNORE_INDEX,
dtype=new_labels[0].dtype,
device=new_labels[0].device,
)
attention_mask = torch.zeros(
(batch_size, max_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.llm.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return (
None,
position_ids,
attention_mask,
past_key_values,
new_input_embeds,
new_labels,
)
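    # Worked example (hedged, shapes illustrative) for a single sample with one image that
    # produces k feature tokens:
    #   input_ids     = [t1, t2, IMAGE_TOKEN_INDEX, t3]
    #   inputs_embeds = [emb(t1), emb(t2), <k image feature rows>, emb(t3)]   # length 3 + k
    #   labels        = [l1, l2, IGNORE_INDEX x k, l3]
    # All samples are then padded to a common length (left or right depending on
    # `tokenizer_padding_side`) and attention_mask / position_ids are rebuilt to match;
    # input_ids are no longer needed and are returned as None.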
def repack_multimodal_data(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
):
# kentang-mit@: reorder and repack (reduce computation overhead)
# requires transformers replacement.
new_inputs_embeds = []
new_position_ids = []
new_labels = []
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
sorted_seqlens_in_batch, sorted_idx = torch.sort(seqlens_in_batch, descending=True)
# print(sorted_seqlens_in_batch)
max_seqlen = inputs_embeds.shape[1]
cur_inputs_embeds = []
cur_position_ids = []
cur_labels = []
cur_batch_len = 0
# print(sorted_seqlens_in_batch.device, len(sorted_seqlens_in_batch), max_seqlen)
for i in range(len(sorted_seqlens_in_batch)):
cur_seqlen = sorted_seqlens_in_batch[i].item()
if cur_seqlen + cur_batch_len <= max_seqlen:
cur_batch_len += cur_seqlen
# each item: num_tokens x num_channels
# remove padding on-the-fly
cur_inputs_embeds.append(inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]])
# each item: num_tokens
cur_position_ids.append(
torch.arange(
cur_inputs_embeds[-1].shape[0],
device=cur_inputs_embeds[-1].device,
)
)
# each item: num_tokens
# remove padding on-the-fly
cur_labels.append(labels[sorted_idx[i]][attention_mask[sorted_idx[i]]])
else:
new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
new_position_ids.append(torch.cat(cur_position_ids, 0))
new_labels.append(torch.cat(cur_labels, 0))
# The current batch is too long. We will start a new batch.
cur_batch_len = cur_seqlen
cur_inputs_embeds = [inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]]]
cur_position_ids = [
torch.arange(
cur_inputs_embeds[-1].shape[0],
device=cur_inputs_embeds[-1].device,
)
]
cur_labels = [labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]]
if len(cur_inputs_embeds):
new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
new_position_ids.append(torch.cat(cur_position_ids, 0))
new_labels.append(torch.cat(cur_labels, 0))
# print(new_position_ids[0].device, [x.shape for x in new_inputs_embeds], [x.shape for x in new_labels], [x.shape for x in new_position_ids])
# assert 0
new_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id
)
new_position_ids = torch.nn.utils.rnn.pad_sequence(new_position_ids, batch_first=True, padding_value=-1)
new_labels = torch.nn.utils.rnn.pad_sequence(new_labels, batch_first=True, padding_value=IGNORE_INDEX)
## yunhao: it's currently a workaround to avoid errors for seq_len < 100
new_attention_mask = new_position_ids.ne(-1)
# sanity check
assert new_attention_mask.sum() == attention_mask.sum()
# print(new_inputs_embeds.shape, (new_attention_mask.sum(1)))
# print(sorted_seqlens_in_batch.device, sorted_seqlens_in_batch, new_attention_mask.sum(1))
# return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
return (
None,
new_position_ids,
new_attention_mask,
past_key_values,
new_inputs_embeds,
new_labels,
sorted_seqlens_in_batch,
)
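    # Packing sketch (restating the loop above): sequences are sorted by length in descending
    # order and greedily concatenated into packed rows no longer than the original padded
    # length `max_seqlen`. Within a packed row, position_ids restart at 0 for every original
    # sequence and padded slots carry -1, from which the new attention mask is rebuilt.
    # `sorted_seqlens_in_batch` is returned so the patched attention implementation can split
    # the packed rows back into the original sequences.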
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
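                # Note: the new token rows above are initialized to the mean of the existing
                # embeddings, a common heuristic to keep them in-distribution before fine-tuning.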
## TODO yunhao: handle cases for <im_st> <im_end>
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
)
elif model_args.mm_use_im_patch_token:
if model_args.mm_projector:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False