import torch


class SMOLLm_VISION_ImageCaptioning(torch.nn.Module):
    """Captioning head: projects image features into the LLM's embedding
    space and prepends them as a single pseudo-token."""

    def __init__(self, llm_model, hidden_dim):
        super().__init__()
        self.llm_model = llm_model
        # 768 = CLIP ViT-L/14 image-feature size; hidden_dim is the LLM's
        # embedding size (960 for SmolLM-360M).
        self.fc = torch.nn.Linear(768, hidden_dim)
        self.act = torch.nn.GELU()

    def forward(self, images, input_ids, attention_mask):
        # Project the image features into the LLM's embedding space.
        image_features = self.act(self.fc(images))

        # Embed the caption tokens; the embedding lookup stays frozen.
        with torch.no_grad():
            llama_embeds = self.llm_model.get_input_embeddings()(input_ids)

        # Prepend the image embedding as one pseudo-token: [image, tok_0, ...].
        # The attention mask must cover this extra position.
        combined_inputs = torch.cat([image_features.unsqueeze(1).float(), llama_embeds], dim=1)

        outputs = self.llm_model(inputs_embeds=combined_inputs, attention_mask=attention_mask)

        # Drop the logit at the image position so outputs align with the text tokens.
        return outputs.logits[:, 1:, :], combined_inputs
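

# A minimal construction-and-mask sketch (assumptions, not pinned by this file:
# SmolLM-360M as the LLM, whose hidden size of 960 matches the projection's
# output; bsz and text_mask below are hypothetical placeholders). Because
# forward() prepends one image pseudo-token, the mask needs one extra leading
# position:
#
#   from transformers import AutoModelForCausalLM
#   llm = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-360M")
#   model = SMOLLm_VISION_ImageCaptioning(llm, hidden_dim=llm.config.hidden_size)
#   att = torch.cat([torch.ones(bsz, 1, dtype=torch.long), text_mask], dim=1)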


class SmoLLM_processor:
    """Wraps a frozen CLIP model and its processor to pool image features."""

    def __init__(self, image_model, image_processor):
        self.image_model = image_model
        self.image_processor = image_processor

    def get_features(self, image):
        # Use the wrapped processor/model instead of module-level globals,
        # and follow the model's device rather than hard-coding 'cuda:0'.
        inputs = self.image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.image_model.get_image_features(
                **inputs.to(self.image_model.device)
            ).squeeze()
        return image_features
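

# End-to-end usage sketch. Assumptions labeled here, not fixed by this file:
# "openai/clip-vit-large-patch14" (768-dim image features) on the vision side,
# "HuggingFaceTB/SmolLM-360M" (hidden size 960) as the LLM, and "example.jpg"
# as a hypothetical image path.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

    clip_name = "openai/clip-vit-large-patch14"
    processor = SmoLLM_processor(CLIPModel.from_pretrained(clip_name),
                                 CLIPProcessor.from_pretrained(clip_name))

    llm = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-360M")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
    model = SMOLLm_VISION_ImageCaptioning(llm, hidden_dim=llm.config.hidden_size)

    image_features = processor.get_features(Image.open("example.jpg"))
    tokens = tokenizer("a photo of", return_tensors="pt")
    # One extra mask position for the prepended image pseudo-token.
    att = torch.cat([torch.ones(1, 1, dtype=torch.long), tokens["attention_mask"]], dim=1)
    logits, _ = model(image_features.unsqueeze(0), tokens["input_ids"], att)
    print(logits.shape)  # (1, seq_len, vocab_size)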