# model/multimodal/multimodal_model.py
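"""PaliGemma-style multimodal model for conditional generation.

Wires a SigLIP vision tower and a multimodal projector onto a causal
language model: projected image features are spliced into the token
embedding sequence at the image-token positions before the language
model runs.
"""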
import torch
from torch import nn
from typing import Optional, Tuple

from .multimodal_config import MultiModalConfig
from .multimodal_components import CausalLM, MultiModalProjector
from ..vision.siglip_model import SigLip
from ..utils.kv_cache import KVCache


class PaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, config: MultiModalConfig):
        super().__init__()
        self.config = config
        # Vision encoder and the projector that maps its output into the
        # language model's embedding space.
        self.vision_tower = SigLip(config.vision_config)
        self.multi_modal_projector = MultiModalProjector(config)
        self.vocab_size = config.vocab_size
        self.language_model = CausalLM(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

    def tie_weights(self):
        return self.language_model.tie_weights()

    def _merge_input_ids_with_image_features(
        self,
        image_features: torch.Tensor,
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ):
        _, _, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        dtype, device = inputs_embeds.dtype, inputs_embeds.device
        # Scale the projected image features to match the embedding scale used
        # by the language model.
        scaled_image_features = image_features / (self.config.hidden_size**0.5)

        # Build the merged sequence: text embeddings at text positions, image
        # features at image-token positions, zeros at padding positions.
        final_embedding = torch.zeros(batch_size, sequence_length, embed_dim, dtype=dtype, device=device)
        text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
        image_mask = input_ids == self.config.image_token_index
        pad_mask = input_ids == self.pad_token_id
        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
        final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        # Attention mask: the inputs are never padded, so nothing is masked out
        # (fill value 0 means full attention). During prefill the mask covers the
        # whole sequence; during decoding the single new query attends to every
        # cached key.
        q_len = inputs_embeds.shape[1]
        if kv_cache is None or kv_cache.num_items() == 0:
            causal_mask = torch.full(
                (batch_size, q_len, q_len), fill_value=0, dtype=dtype, device=device
            )
        else:
            # When generating with a KV cache, exactly one new token is processed.
            assert q_len == 1
            kv_len = kv_cache.num_items() + q_len
            causal_mask = torch.full(
                (batch_size, q_len, kv_len), fill_value=0, dtype=dtype, device=device
            )
        # Add the head dimension: (batch, 1, q_len, kv_len).
        causal_mask = causal_mask.unsqueeze(1)

        if kv_cache is not None and kv_cache.num_items() > 0:
            # The query's position is the last position implied by the mask.
            position_ids = attention_mask.cumsum(-1)[:, -1]
            if position_ids.dim() == 1:
                position_ids = position_ids.unsqueeze(0)
        else:
            # Position ids over the full prompt; masked positions are set to 1.
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1).to(device)

        return final_embedding, causal_mask, position_ids

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple:
        assert torch.all(attention_mask == 1), "The input cannot be padded"
        # Embed the token ids, encode the image, and project the image patches
        # into the text embedding space.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
        image_features = self.multi_modal_projector(selected_image_feature)
        # Splice the image features into the token embedding sequence and build
        # the attention mask and position ids for the language model.
        inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, kv_cache
        )
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        return outputs
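

# Minimal usage sketch (hedged): how this model might be driven for a single
# prefill step. MultiModalConfig's constructor arguments, the number of image
# tokens, the 224x224 resolution, and `prompt_ids` below are illustrative
# assumptions, not values defined in this repository.
#
#   config = MultiModalConfig(...)  # project-specific fields
#   model = PaliGemmaForConditionalGeneration(config)
#   model.tie_weights()
#
#   input_ids = torch.tensor([[config.image_token_index] * num_image_tokens + prompt_ids])
#   pixel_values = torch.randn(1, 3, 224, 224)   # assumed input resolution
#   attention_mask = torch.ones_like(input_ids)  # padding is not supported
#
#   outputs = model(input_ids=input_ids, pixel_values=pixel_values,
#                   attention_mask=attention_mask, kv_cache=KVCache())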