import torch
from torch import nn
from typing import Optional, Tuple

from .multimodal_config import MultiModalConfig
from .multimodal_components import CausalLM, MultiModalProjector
from ..vision.siglip_model import SigLip
from ..utils.kv_cache import KVCache


class PaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, config: MultiModalConfig):
        super().__init__()
        self.config = config
        # SigLIP vision encoder: maps pixel values to a sequence of patch embeddings.
        self.vision_tower = SigLip(config.vision_config)
        # Projects vision embeddings into the language model's embedding space.
        self.multi_modal_projector = MultiModalProjector(config)
        self.vocab_size = config.vocab_size
        self.language_model = CausalLM(config.text_config)
        self.pad_token_id = config.pad_token_id if config.pad_token_id is not None else -1

    def tie_weights(self):
        # Share weights between the LM's input embeddings and output head.
        return self.language_model.tie_weights()

    def _merge_input_ids_with_image_features(
        self,
        image_features: torch.Tensor,
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ):
        _, _, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        dtype, device = inputs_embeds.dtype, inputs_embeds.device

        # Scale the image features the same way the LM scales its token embeddings.
        scaled_image_features = image_features / (self.config.hidden_size**0.5)

        # Assemble the combined sequence: image embeddings go into the slots
        # occupied by image placeholder tokens, text embeddings fill the rest,
        # and padding positions stay zero.
        final_embedding = torch.zeros(batch_size, sequence_length, embed_dim, dtype=dtype, device=device)
        text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
        image_mask = input_ids == self.config.image_token_index
        pad_mask = input_ids == self.pad_token_id

        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)

        final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
        final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        q_len = inputs_embeds.shape[1]
        if kv_cache is None or kv_cache.num_items() == 0:
            # Prefill phase. The input is never padded, so every token may
            # attend to every other token (PaliGemma applies full attention
            # over the image + prompt prefix): the additive mask is all zeros.
            causal_mask = torch.full((batch_size, q_len, q_len), fill_value=0, dtype=dtype, device=device)
        else:
            # Decoding phase: a single query token attends to everything
            # already in the KV cache plus itself.
            assert q_len == 1
            kv_len = kv_cache.num_items() + q_len
            causal_mask = torch.full((batch_size, q_len, kv_len), fill_value=0, dtype=dtype, device=device)
        # Add the head dimension: (batch, q_len, kv_len) -> (batch, 1, q_len, kv_len).
        causal_mask = causal_mask.unsqueeze(1)

        if kv_cache is not None and kv_cache.num_items() > 0:
            # During decoding, the position of the single query token is the
            # count of non-padded tokens seen so far.
            position_ids = attention_mask.cumsum(-1)[:, -1]
            if position_ids.dim() == 1:
                # Shape (batch,) -> (batch, 1) to match the single query token.
                position_ids = position_ids.unsqueeze(-1)
        else:
            # Prefill: one position per token; padded positions get a dummy 1.
            position_ids = attention_mask.cumsum(-1).masked_fill_(attention_mask == 0, 1).to(device)

        return final_embedding, causal_mask, position_ids
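
    # Illustration of the merge above (hypothetical numbers): with 256 image
    # patches and a 10-token prompt, input_ids has shape (1, 266) whose first
    # 256 entries are image_token_index. masked_scatter copies the
    # (1, 256, embed_dim) projected image features into exactly those 256
    # slots of final_embedding, while torch.where keeps the 10 text
    # embeddings in place.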

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple:
        # This implementation does not support padded inputs.
        assert torch.all(attention_mask == 1), "The input cannot be padded"

        # Embed the text tokens; image placeholder tokens get embeddings too,
        # but those are overwritten during the merge.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Encode the image and project it into the LM embedding space:
        # (batch, channels, H, W) -> (batch, num_patches, embed_dim).
        selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
        image_features = self.multi_modal_projector(selected_image_feature)
        # Splice the image embeddings into the placeholder slots and build
        # the attention mask and position ids for this step.
        inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, kv_cache
        )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        return outputs
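

# --- Usage sketch (illustrative only) ---
# A minimal end-to-end call, assuming MultiModalConfig exposes the fields used
# above (vision_config, text_config, vocab_size, image_token_index,
# pad_token_id, hidden_size). Concrete values below (image size, patch count)
# are hypothetical; in practice they come from the checkpoint's preprocessor.
#
#     config = MultiModalConfig(...)  # load from a checkpoint in practice
#     model = PaliGemmaForConditionalGeneration(config).eval()
#     model.tie_weights()
#
#     # input_ids: one image_token_index placeholder per image patch, followed
#     # by the text prompt tokens. num_patches / prompt_ids are stand-ins for
#     # values produced by the preprocessor.
#     input_ids = torch.tensor([[config.image_token_index] * num_patches + prompt_ids])
#     attention_mask = torch.ones_like(input_ids)  # padding is not supported
#     pixel_values = torch.randn(1, 3, 224, 224)   # hypothetical image size
#
#     with torch.no_grad():
#         outputs = model(
#             input_ids=input_ids,
#             pixel_values=pixel_values,
#             attention_mask=attention_mask,
#             kv_cache=KVCache(),  # reuse the same cache for decoding steps
#         )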