from typing import Dict

import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig, Qwen2VLForConditionalGeneration
from peft import LoraConfig, get_peft_model, PeftModel

from src.arguments import ModelArguments
from src.vlm_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.vlm_backbone.llava_next import LlavaNextForConditionalGeneration
class MMEBModel(nn.Module):
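    """Contrastive multimodal embedding model built on a (V)LM backbone.

    The backbone's final-layer hidden states are pooled into one vector per
    input; training optimizes an in-batch InfoNCE loss between query and
    target embeddings.
    """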
    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'cls',
                 normalize: bool = False,
                 temperature: float = 1.0,
                 ):
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
    def encode_input(self, inputs):
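        """Run the backbone and pool the last layer's hidden states into a
        single embedding per sequence."""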
        outputs = self.encoder(**inputs, return_dict=True, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        pooled_output = self._pooling(hidden_states, inputs['attention_mask'])
        return pooled_output
    def _pooling(self, last_hidden_state, attention_mask):
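        """Last-token ('last'/'eos') pooling over non-padding positions, with
        optional L2 normalization."""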
        if self.pooling == 'last' or self.pooling == 'eos':
            # Index of the last non-padding token in each sequence.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        else:
            raise NotImplementedError(f"pooling method '{self.pooling}' is not supported")
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps
    @classmethod
    def build(cls, model_args: ModelArguments, **hf_kwargs):
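        """Instantiate a fresh model from `model_args`, optionally wrapping the
        backbone with DoRA/LoRA adapters for training."""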
        # Load the base model for the requested backbone.
        lora_target_modules = None
        if model_args.model_backbone == "llava_next":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_args.model_backbone == "qwen":
            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
            base_model.padding_side = "right"
        elif model_args.model_backbone == "phi35v":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # config._attn_implementation = "eager"
            config.attn_implementation = "flash_attention_2"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_args.model_backbone == "internvl_2_5":
            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name,
                trust_remote_code=True
            )
            config = InternVLChatConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # config.vision_config.image_size = data_args.force_image_size  # assumes data_args carries the image size
            config.use_flash_attn = False
            base_model = InternVLChatModel.from_pretrained(
                model_args.model_name,
                config=config,
                tokenizer=tokenizer,
                # attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16
            )
            lora_target_modules = base_model.get_lora_target_modules()
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"
        if model_args.lora:
            if lora_target_modules is None:
                lora_target_modules = model_args.lora_target_modules.split(',')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            encoder = get_peft_model(base_model, lora_config)
        else:
            encoder = base_model
        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        return model
    @classmethod
    def load(cls, model_args: ModelArguments, **hf_kwargs):
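        """Load a trained checkpoint; when `model_args.lora` is set, LoRA
        weights are merged back into the backbone for inference."""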
        # Load the base model for the requested backbone.
        checkpoint_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        if model_args.model_backbone == "llava_next":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                # attn_implementation="flash_attention_2"
            )
            base_model.padding_side = "left"
        elif model_args.model_backbone == "phi35v":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **hf_kwargs, config=config,
                                                          attn_implementation="flash_attention_2",
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == "internvl_2_5":
            print("loading model")
            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name,
                trust_remote_code=True
            )
            config = InternVLChatConfig.from_pretrained(model_args.model_name)
            # config.vision_config.image_size = data_args.force_image_size
            config.use_flash_attn = False
            base_model = InternVLChatModel.from_pretrained(
                model_args.model_name,
                config=config,
                tokenizer=tokenizer,
                # attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16
            )
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                checkpoint_path, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"
        # Build the embedding model on top of the base.
        if model_args.lora:
            print("loading lora parameters")
            lora_config = LoraConfig.from_pretrained(checkpoint_path)
            lora_model = PeftModel.from_pretrained(base_model, checkpoint_path, config=lora_config)
            encoder = lora_model.merge_and_unload()
        else:
            encoder = base_model
        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize
        )
        return model
    def save(self, output_dir: str):
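        """Persist the encoder (and any adapters) via HF `save_pretrained`."""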
        self.encoder.save_pretrained(output_dir)
    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, neg: Dict[str, Tensor] = None):
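        """Encode queries/targets (and optional hard negatives). Returns the
        raw representations when either side is missing (inference mode),
        otherwise the contrastive loss."""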
        qry_reps = self.encode_input(qry) if qry else None  # (bsz_per_device, dim)
        tgt_reps = self.encode_input(tgt) if tgt else None  # (bsz_per_device, dim)
        neg_reps = self.encode_input(neg) if neg else None  # (bsz_per_device, dim)
        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
        # Gather representations across processes if using DDP
        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
            all_neg_reps = self._dist_gather_tensor(neg_reps) if neg_reps is not None else None
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps
            all_neg_reps = neg_reps
        # Compute query-target similarity scores
        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)
        # Append hard-negative scores as extra columns if available
        if all_neg_reps is not None:
            qry_neg_cos = self.compute_similarity(all_qry_reps, all_neg_reps)
            scores = torch.cat([scores, qry_neg_cos], dim=1)
        # InfoNCE loss: each query's positive is its in-batch target
        # (the stride is 1 when query and target batches are the same size)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_qry_reps.size(0) // all_tgt_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            # Scale the loss so gradients match the single-process setting
            # after DDP's mean reduction across ranks
            loss = loss * self.world_size
        return loss
    def _dist_gather_tensor(self, t: Tensor):
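        """All-gather `t` across ranks, keeping this rank's original tensor in
        place so gradients still flow through the local slice."""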
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns detached copies; substitute the local tensor so
        # autograd can backpropagate through this rank's contribution
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors
    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
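

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): build an encoder and embed a
    # text-only query. It assumes `ModelArguments` accepts these fields as
    # keyword arguments (all of them are referenced in `build()` above); the
    # processor call follows the standard Hugging Face API.
    from transformers import AutoProcessor

    args = ModelArguments(
        model_name="microsoft/Phi-3.5-vision-instruct",
        model_backbone="phi35v",
        pooling="last",   # 'cls' is the default but is not implemented in _pooling
        normalize=True,
        lora=False,
    )
    model = MMEBModel.build(args).eval()
    processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True)
    batch = processor(text="Find an image of a red bicycle.", return_tensors="pt")
    with torch.no_grad():
        reps = model(qry=batch)["qry_reps"]  # (1, hidden_dim)
    print(reps.shape)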