# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import os.path as osp
import warnings
from abc import ABC, abstractmethod

import torch
import logging

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    AutoConfig,
    BitsAndBytesConfig,
    PretrainedConfig,
    PreTrainedModel,
)

from .constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
    MASK_TOKEN_INDEX,
)

from collections import OrderedDict
from .utils import get_model_config
from .language_model.builder import build_llm_and_tokenizer
from .multimodal_encoder.builder import build_vision_tower, build_context_provider
from .multimodal_projector.builder import build_mm_projector
from .configuration_llava import LlavaConfig

from transformers.modeling_utils import ContextManagers, no_init_weights


def _resolve_sub_configs(config):
    """Default ``config.model_dtype`` if missing and return the validated sub-configs.

    Shared by ``LlavaMetaModel.init_vlm`` and ``LlavaMetaModel.load_pretrained``.

    Args:
        config: Top-level model config. Mutated in place: ``model_dtype`` is set to
            ``"torch.float16"`` (with a warning) when the attribute is absent.

    Returns:
        The 5-tuple ``(llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg,
        context_provider_cfg)`` produced by ``get_model_config``.

    Raises:
        ValueError: if any of the three required sub-configs (LLM, vision tower,
            projector) is ``None``.
    """
    if not hasattr(config, "model_dtype"):
        warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
        config.model_dtype = "torch.float16"

    # Only the first three are required. Others are optional.
    llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = get_model_config(config)
    if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
        raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
    return llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg


## TODO decide whether should we use metaclass
class LlavaMetaModel(ABC):
    """Mixin that owns the LLM, vision tower, multimodal projector and optional context provider.

    The host class is expected to provide ``self.config``, ``self.state_dict()`` and
    ``self.device`` (all used below) -- presumably a ``transformers.PreTrainedModel``;
    confirm against the concrete subclasses.
    """

    def init_vlm(self, config: PretrainedConfig = None, *args, **kwargs):
        """Build all VLM components from ``config``; a no-op if any component already exists.

        Args:
            config: Top-level model config (see ``_resolve_sub_configs`` for the in-place
                ``model_dtype`` default).
            *args, **kwargs: Forwarded to ``build_llm_and_tokenizer``.

        Raises:
            ValueError: if a required sub-config is missing.
        """
        # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation.
        if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
            # already initialized, skipped
            return

        # NOTE(review): mask_encoder_cfg is ignored here, whereas load_pretrained rejects it.
        llm_cfg, vision_tower_cfg, mm_projector_cfg, _mask_encoder_cfg, context_provider_cfg = _resolve_sub_configs(
            config
        )

        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.context_provider = (
            build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        )

        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        """Placeholder; not implemented and returns ``None``."""
        pass

    ## FIXME we will use this function to load model in the future
    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        """Instantiate ``cls`` from a checkpoint path/identifier or a ``LlavaConfig``.

        Args:
            model_path_or_config: A string passed to ``AutoConfig.from_pretrained`` or an
                already-built ``LlavaConfig``.
            *args, **kwargs: Forwarded to ``cls(...)`` and ``build_llm_and_tokenizer``. Any
                ``config`` keyword is discarded.

        Returns:
            The model instance with its components built.

        Raises:
            NotImplementedError: for an unsupported argument type, or when a mask-encoder
                sub-config is present and the components have to be built here.
            ValueError: if a required sub-config is missing.
        """
        kwargs.pop("config", None)

        if isinstance(model_path_or_config, str):
            config = AutoConfig.from_pretrained(model_path_or_config)
        elif isinstance(model_path_or_config, LlavaConfig):
            config = model_path_or_config
        else:
            raise NotImplementedError(
                f"wrong type, {type(model_path_or_config)} {isinstance(model_path_or_config, LlavaConfig)}"
            )

        llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = _resolve_sub_configs(
            config
        )

        # Construct the shell without running HF weight initialization.
        with ContextManagers([no_init_weights(_enable=True)]):
            vlm = cls(config, *args, **kwargs)

        # The constructor may already have built and loaded everything (e.g. via init_vlm).
        if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "mm_projector"):
            if getattr(vlm, "is_loaded", False):
                return vlm

        # Reject unsupported configs before the expensive builders run.
        if mask_encoder_cfg is not None:
            raise NotImplementedError("Mask encoder is not supported.")

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
        vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
        vlm.context_provider = (
            build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        )

        # This is a classmethod: finalize the freshly built instance, not ``self``.
        vlm.post_config()
        vlm.is_loaded = True

        # FIXME(ligeng, yunhao): llm should never be none here.
        assert (
            vlm.llm is not None or vlm.vision_tower is not None or vlm.mm_projector is not None
        ), "At least one of the components must be instantiated."
        return vlm

    ## FIXME we will use this function to save the model in the future
    def save_pretrained(self, output_dir, state_dict=None):
        """Save every component to its own sub-directory of ``output_dir``, then the top-level config.

        Writes ``llm/`` (LLM weights and tokenizer), ``vision_tower/`` (weights and image
        processor; skipped when the tower's class name contains "radio"), ``mm_projector/``
        and, when present, ``context_provider/``. Each component's weights are selected from
        ``state_dict`` by key substring, and ``self.config.<component>_cfg`` is updated to the
        saved component config.

        Args:
            output_dir: Destination directory.
            state_dict: Full-model state dict; defaults to ``self.state_dict()``.
        """
        if state_dict is None:
            # other wise fetch from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            llm_dir = osp.join(output_dir, "llm")
            print(f"saving llm to {llm_dir}")
            self.llm.config._name_or_path = llm_dir
            llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
            self.llm.save_pretrained(llm_dir, state_dict=llm_state_dict)
            self.config.llm_cfg = self.llm.config

        if self.get_vision_tower() and "radio" not in self.get_vision_tower().__class__.__name__.lower():
            vision_tower_dir = osp.join(output_dir, "vision_tower")
            print(f"saving vision_tower to {vision_tower_dir}")
            self.vision_tower.config._name_or_path = vision_tower_dir
            vision_tower_state_dict = OrderedDict(
                {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
            )
            self.vision_tower.vision_tower.save_pretrained(
                vision_tower_dir,
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(vision_tower_dir)
            self.config.vision_tower_cfg = self.vision_tower.config
            # Removed only after the tower itself was saved, so the tower's own
            # config on disk keeps ``auto_map``.
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_mm_projector():
            mm_projector_dir = osp.join(output_dir, "mm_projector")
            print(f"saving mm_projector to {mm_projector_dir}")
            self.mm_projector.config._name_or_path = mm_projector_dir
            mm_projector_state_dict = OrderedDict(
                {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
            )
            self.mm_projector.save_pretrained(
                mm_projector_dir,
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        if self.get_context_provider():
            context_provider_dir = osp.join(output_dir, "context_provider")
            print(f"saving context_provider to {context_provider_dir}")
            self.context_provider.config._name_or_path = context_provider_dir
            context_provider_state_dict = OrderedDict(
                {k.split("context_provider.")[-1]: v for k, v in state_dict.items() if "context_provider" in k}
            )
            self.context_provider.save_pretrained(
                context_provider_dir,
                state_dict=context_provider_state_dict,
            )
            self.config.context_provider_cfg = self.context_provider.config

        ## update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        """Return the LLM (first element if stored as a list), or ``None``."""
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        """Return the LLM's ``lm_head`` attribute, or ``None`` if absent."""
        return getattr(self.get_llm(), "lm_head", None)

    def get_vision_tower(self):
        """Return the vision tower (first element if stored as a list), or ``None``."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        """Return the multimodal projector (first element if stored as a list), or ``None``."""
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def get_context_provider(self):
        """Return the context provider (first element if stored as a list), or ``None``."""
        context_provider = getattr(self, "context_provider", None)
        if type(context_provider) is list:
            context_provider = context_provider[0]
        return context_provider

    def post_config(self):
        """Mirror the LLM's training flag and fill missing ``self.config.*_cfg`` from the built components."""
        self.training = self.get_llm().training
        ## configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.get_llm().config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.get_vision_tower().config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.get_mm_projector().config
        context_provider = self.get_context_provider()
        if getattr(self.config, "context_provider_cfg", None) is None and context_provider is not None:
            self.config.context_provider_cfg = context_provider.config

    def freezed_module_patch(self):
        """Put frozen sub-modules back into eval mode while the model is training.

        Huggingface will call model.train() at each training_step. To ensure the expected
        behaviors for modules like dropout, batchnorm, etc., we need to call model.eval()
        for the freezed modules. A module counts as frozen when its ``tune_*`` config flag
        is falsy. The LLM is never switched to eval here; a warning is logged instead.
        """
        if self.training:
            if self.get_llm() and not getattr(self.config, "tune_language_model", False):
                logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
            if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
                self.get_vision_tower().eval()
            if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
                self.get_mm_projector().eval()
            if self.get_context_provider() and not getattr(self.config, "tune_context_provider", False):
                self.get_context_provider().eval()

    def encode_images(self, images):
        """Encode ``images`` with the vision tower and project the features into the LLM space."""
        image_features = self.get_vision_tower()(images)
        image_features = self.get_mm_projector()(image_features)
        return image_features

    def encode_images_with_context(self, images):
        """Encode images that may carry context, then project the features into the LLM space.

        Expects each image to be two 4-channel halves, ``images[:, :4]`` and ``images[:, 4:]``,
        concatenated (8 channels). An image is a "cimage" (image with context) when its two
        halves differ, or unconditionally when ``context_provider.treat_image_as_cimage`` is
        set. Features of cimages come from the context provider; every other image keeps the
        vision-tower features of its first half.

        Raises:
            NotImplementedError: for an unknown ``context_provider_type``.
        """
        context_provider = self.get_context_provider()
        # An image is a cimage when its two 4-channel halves differ anywhere; identical halves
        # mean a plain image with no extra context.
        cimage_mask = torch.any((images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1)

        if context_provider.treat_image_as_cimage:
            # If the context provider treats the image as cimage, then all images are cimage.
            cimage_mask[:] = True

        if context_provider.context_image_as_queries:
            # Swap the crop image and full image since the model uses the full image as queries by default
            images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1)

        # Process the first 4 channels for all images: for image it's the image, for cimage it's the full image
        vision_tower = self.get_vision_tower()
        # Encode context images (full images)
        image_features = vision_tower(images[:, :4, ...]).to(self.device)
        # Each cimage has 8 channels (full and crop concatenated)
        cimage_concatenated = images[cimage_mask]
        cimage_full_features = image_features[cimage_mask]
        if context_provider.context_provider_type == "cross_attn_end_to_all":
            cimage_features = context_provider(
                cimage_full_features=cimage_full_features,
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower,
            ).to(self.device)
        elif context_provider.context_provider_type == "concat":
            # Full features of cimages are computed but not used.
            cimage_features = context_provider(
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower,
            ).to(self.device)
        else:
            raise NotImplementedError(f"Context provider type {context_provider.context_provider_type} not implemented.")

        # Put cimage_features into image_features
        image_features[cimage_mask] = cimage_features

        # Project to the llm space
        image_features = self.get_mm_projector()(image_features)

        return image_features

    ## @yunhao: is there a better way to handle function call and attributes for llm?
    ## support beam search
    def _temporary_reorder_cache(self, past_key_values, sorted_idx):
        """Delegate cache reordering (used by beam search) to the LLM."""
        return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)

    def get_input_embeddings(self):
        """Return the LLM's input embedding module."""
        return self.get_llm().get_input_embeddings()

    def get_output_embeddings(self):
        """Return the LLM's output embedding module."""
        return self.get_llm().get_output_embeddings()

    def resize_token_embeddings(self, embed_size):
        """Resize the LLM's token embeddings to ``embed_size`` entries."""
        self.get_llm().resize_token_embeddings(embed_size)


class LlavaMetaForCausalLM(ABC):
    """This class is originally implemented by the LLaVA team and
    modified by Haotian Tang and Jason Lu based on Ji Lin's implementation
    to support multiple images and input packing."""

    ## TODO move the forward function here if there is no need to override it
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
    ):
        """Splice image features into the token embeddings and build matching labels/masks.

        Every ``IMAGE_TOKEN_INDEX`` position in ``input_ids`` is replaced by the next entry of
        the encoded image features, with ``IGNORE_INDEX`` labels over the image span. Padding
        (per ``attention_mask``) is removed, sequences are optionally truncated to
        ``self.llm.config.tokenizer_model_max_length``, then re-padded on the side given by
        ``self.llm.config.tokenizer_padding_side`` (default "right").

        Returns:
            ``(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
            When there is no vision tower, no images, or a single-token step, the inputs are
            returned unchanged with ``inputs_embeds=None`` (extending ``attention_mask`` for a
            cached decoding step). Otherwise ``input_ids`` is ``None`` and ``inputs_embeds``
            holds the spliced embeddings; ``labels``/``attention_mask``/``position_ids`` are
            ``None`` exactly when they were passed as ``None``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and images is not None
                and input_ids.shape[1] == 1
            ):
                # Cached decoding step: grow the mask to cover the cache plus the new token.
                # Assumes past_key_values[-1][-1] is a (..., cached_len, head_dim) tensor -- TODO confirm.
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat(
                    (
                        attention_mask,
                        torch.ones(
                            (
                                attention_mask.shape[0],
                                target_shape - attention_mask.shape[1],
                            ),
                            dtype=attention_mask.dtype,
                            device=attention_mask.device,
                        ),
                    ),
                    dim=1,
                )
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )
        # handle different image dtypes for packing
        if type(images) is list:
            images = torch.cat(images, dim=0)
        elif images.ndim == 5:  # batch_size x seq_len x image_channels
            images = images.flatten(0, 1)
        if getattr(self, "context_provider", None):
            image_features = self.encode_images_with_context(images)
        else:
            # Since we slice it with index below, turning it into a list splits things by the first index which does not result in data copy or degrade performance.
            # Example dimension: [1, 196, 2560]
            assert images.shape[1] <= 4, "images have more than 4 channels, but context provider is not included"
            image_features = self.encode_images(images).to(self.device)
        # Note (kentang-mit@): image start / end is not implemented here to support pretraining.
        # NOTE(review): "turn_mm_projector" looks like a typo for "tune_mm_projector" (the key
        # used by freezed_module_patch), so this guard is presumably never enabled. Left as-is
        # because correcting it would start raising for existing configs.
        if getattr(self.config, "turn_mm_projector", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask
        input_ids_copy = input_ids.clone()
        # kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used.
        input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0
        input_embeds = self.llm.model.embed_tokens(input_ids_copy)

        input_ids = [
            cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        input_embeds_1 = [
            cur_input_embeds[cur_attention_mask]
            for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask)
        ]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0

        # kentang-mit@: If some part of the model is executed in the loop, the the loop length needs to be a constant.
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum().item())
            if num_images == 0:
                cur_image_features = image_features[0]
                # Appending an empty slice of the image features leaves the sequence unchanged;
                # presumably it keeps the vision modules in the autograd graph for text-only
                # samples (e.g. for distributed training) -- confirm.
                cur_input_embeds = torch.cat([input_embeds_1[batch_idx], cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                # kenang-mit@: we do not have placeholdr image for text-only data now.
                # cur_image_idx += 1
                continue

            cur_input_embeds = input_embeds_1[batch_idx]
            cur_labels = labels[batch_idx]
            # Sentinel -1 and sequence length bracket the image-token positions, so consecutive
            # pairs delimit the text spans between images.
            image_token_indices = (
                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_labels_noim = []
            cur_input_embeds_no_im = []
            for start, end in zip(image_token_indices, image_token_indices[1:]):
                cur_labels_noim.append(cur_labels[start + 1 : end])
                cur_input_embeds_no_im.append(cur_input_embeds[start + 1 : end])

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.llm.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
                warnings.warn("Inputs truncated!")
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        left_padding = getattr(self.llm.config, "tokenizer_padding_side", "right") == "left"
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]),
                dtype=cur_new_embed.dtype,
                device=cur_new_embed.device,
            )
            if left_padding:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )

    def repack_multimodal_data(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    ):
        """Pack unpadded sequences into fewer rows to reduce computation.

        Sequences are sorted by length (longest first) and scanned in that order: each one is
        appended to the current row if the row stays within ``inputs_embeds.shape[1]`` tokens,
        otherwise the row is closed and a new one started. Position ids restart at 0 for
        every packed sequence. ``input_ids`` and ``position_ids`` are not read.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds, labels,
            sorted_seqlens_in_batch)``. Padded slots have position id -1, label
            ``IGNORE_INDEX`` and a ``False`` attention mask.
        """
        # kentang-mit@: reorder and repack (reduce computation overhead)
        # requires transformers replacement.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        sorted_seqlens_in_batch, sorted_idx = torch.sort(seqlens_in_batch, descending=True)
        max_seqlen = inputs_embeds.shape[1]
        # Boolean view for indexing: an integer 0/1 mask would gather rows 0/1 instead of
        # selecting the unpadded positions.
        keep_mask = attention_mask.bool()

        new_inputs_embeds = []
        new_position_ids = []
        new_labels = []
        cur_inputs_embeds = []
        cur_position_ids = []
        cur_labels = []
        cur_batch_len = 0

        def _flush_row():
            # Close the current packed row.
            new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
            new_position_ids.append(torch.cat(cur_position_ids, 0))
            new_labels.append(torch.cat(cur_labels, 0))

        for cur_seqlen, row in zip(sorted_seqlens_in_batch.tolist(), sorted_idx.tolist()):
            if cur_inputs_embeds and cur_seqlen + cur_batch_len > max_seqlen:
                # The current batch is too long. We will start a new batch.
                _flush_row()
                cur_inputs_embeds, cur_position_ids, cur_labels = [], [], []
                cur_batch_len = 0
            cur_batch_len += cur_seqlen
            # each item: num_tokens x num_channels; padding removed on-the-fly
            cur_inputs_embeds.append(inputs_embeds[row][keep_mask[row]])
            # each item: num_tokens
            cur_position_ids.append(
                torch.arange(
                    cur_inputs_embeds[-1].shape[0],
                    device=cur_inputs_embeds[-1].device,
                )
            )
            # each item: num_tokens; padding removed on-the-fly
            cur_labels.append(labels[row][keep_mask[row]])

        if cur_inputs_embeds:
            _flush_row()

        # NOTE(review): embeddings are padded with the integer pad_token_id as a float value;
        # those slots are masked out below, so the value should not matter -- confirm.
        new_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id
        )

        new_position_ids = torch.nn.utils.rnn.pad_sequence(new_position_ids, batch_first=True, padding_value=-1)

        new_labels = torch.nn.utils.rnn.pad_sequence(new_labels, batch_first=True, padding_value=IGNORE_INDEX)
        ## yunhao: it's currently a workaround to avoid errors for seq_len < 100
        new_attention_mask = new_position_ids.ne(-1)
        # sanity check: packing must neither drop nor duplicate tokens
        assert new_attention_mask.sum() == attention_mask.sum()

        return (
            None,
            new_position_ids,
            new_attention_mask,
            past_key_values,
            new_inputs_embeds,
            new_labels,
            sorted_seqlens_in_batch,
        )

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to ``tokenizer`` and prepare the matching embeddings.

        - ``mm_use_im_patch_token``: adds ``DEFAULT_IMAGE_PATCH_TOKEN`` and resizes embeddings.
        - ``mm_use_im_start_end``: adds the image start/end tokens, initialises their input and
          output embeddings to the mean of the pre-existing rows, and optionally copies them
          from ``model_args.pretrain_mm_mlp_adapter``.
        - otherwise, with ``mm_use_im_patch_token`` and a truthy ``model_args.mm_projector``:
          freezes the input and output embeddings.

        Raises:
            ValueError: if the adapter's embedding shape matches neither layout.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg
            ## TODO yunhao: handle cases for
            if model_args.pretrain_mm_mlp_adapter:
                # SECURITY: torch.load unpickles the file; only load trusted adapter checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                # input_embeddings is only bound above when num_new_tokens > 0; this assert
                # guarantees that.
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_im_patch_token:
            # NOTE(review): if model_args.mm_projector is a projector-type string it is always
            # truthy, so this always freezes; confirm whether a tune_* flag was intended.
            if model_args.mm_projector:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False