|
import torch |
|
import cv2 |
|
from PIL import Image |
|
import numpy as np |
|
|
|
class MultimodalGradCAM:
    """Grad-CAM saliency maps for the vision branch of a multimodal model.

    Hooks the final layer of the model's vision encoder to capture its
    hidden-state activations and their gradients, then weights the
    activations by token-averaged gradients of the predicted answer logit
    (standard Grad-CAM).
    """

    def __init__(self, model, processor):
        """
        Args:
            model: multimodal model exposing ``get_vision_encoder()`` (whose
                result has a ``layers`` sequence) and returning an object
                with ``.logits`` when called.
            processor: callable mapping ``(text=..., images=...)`` to model
                inputs, e.g. a HuggingFace processor.
        """
        self.model = model
        self.processor = processor
        self.activations = {}    # 'vision' -> last-layer hidden states
        self.gradients = {}      # 'vision' -> gradient w.r.t. those states
        self._hook_handles = []  # kept so hooks can be detached later

        self._register_hooks()

    def _save_gradient(self, grad):
        # Tensor hook: invoked during backward with d(target)/d(hidden).
        self.gradients['vision'] = grad

    def _register_hooks(self):
        """Attach a forward hook to the last vision-encoder layer.

        The deprecated ``Module.register_backward_hook`` is avoided: its
        ``grad_output`` is not guaranteed to correspond to the activation
        stored by the forward hook.  Instead, a tensor hook is registered
        on the captured activation itself, so the gradient always matches
        the activation it belongs to.
        """
        def forward_hook(module, inputs, output):
            # NOTE(review): assumes the final layer's output exposes
            # ``last_hidden_state`` — confirm for the concrete model class.
            hidden = output.last_hidden_state
            self.activations['vision'] = hidden
            if hidden.requires_grad:
                # Fresh tensor every forward pass, so no hook accumulation.
                hidden.register_hook(self._save_gradient)

        vision_encoder = self.model.get_vision_encoder()
        handle = vision_encoder.layers[-1].register_forward_hook(forward_hook)
        self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach every hook this instance registered (idempotent)."""
        while self._hook_handles:
            self._hook_handles.pop().remove()

    @staticmethod
    def _to_grid(cam):
        """Best-effort reshape of a 1-D per-token map into a 2-D patch grid.

        Heuristic: if the token count is a perfect square, reshape directly;
        if count-1 is a perfect square, drop the first token (presumably a
        CLS token — TODO confirm for the actual model) and reshape.
        Otherwise the 1-D map is returned unchanged.
        """
        if cam.ndim != 1:
            return cam
        n = cam.shape[0]
        side = int(round(n ** 0.5))
        if side * side == n:
            return cam.reshape(side, side)
        side = int(round((n - 1) ** 0.5))
        if side * side == n - 1:
            return cam[1:].reshape(side, side)
        return cam

    def generate_saliency(self, image, question):
        """Compute a Grad-CAM map for the model's predicted answer token.

        Args:
            image: image accepted by ``self.processor``.
            question: text prompt accepted by ``self.processor``.

        Returns:
            ``np.ndarray`` normalized to [0, 1]; 2-D when the token count
            forms a square grid (see ``_to_grid``), otherwise 1-D.

        Raises:
            RuntimeError: if no vision gradient was captured (e.g. the
                model ran under ``no_grad``/``inference_mode``).
        """
        inputs = self.processor(
            text=question,
            images=image,
            return_tensors="pt",
            padding=True,
        )

        outputs = self.model(**inputs)
        answer_ids = outputs.logits.argmax(dim=-1)

        # Back-propagate from the logit of the predicted token at the last
        # sequence position.
        target_token_id = answer_ids[0, -1].item()
        target = outputs.logits[0, -1, target_token_id]

        self.model.zero_grad()
        target.backward()

        if 'vision' not in self.gradients:
            raise RuntimeError(
                "No vision gradients captured; the hooked activation did "
                "not require grad (e.g. model run under no_grad)."
            )

        activations = self.activations['vision'].detach()  # (B, tokens, C)
        gradients = self.gradients['vision'].detach()      # (B, tokens, C)

        # Grad-CAM: per-channel weights are the gradient averaged over the
        # token dimension only (dim=1).  Averaging over dim=[1, 2], as the
        # previous version did, collapses the channel dimension too and
        # degenerates the weights to a single scalar per sample.
        weights = gradients.mean(dim=1, keepdim=True)      # (B, 1, C)
        cam = (weights * activations).sum(dim=-1)          # (B, tokens)
        cam = torch.relu(cam)

        cam = cam.squeeze().cpu().numpy()
        cam = self._to_grid(cam)
        # Normalize to [0, 1]; epsilon guards a flat (all-equal) map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        return cam

    def visualize(self, image, cam):
        """Overlay *cam* on *image* as a JET heatmap.

        Args:
            image: PIL image; assumed 3-channel RGB (RGBA would shape-
                mismatch the 3-channel heatmap — TODO confirm callers).
            cam: 2-D saliency map with values in [0, 1].

        Returns:
            ``PIL.Image`` with the heatmap blended over the input (40/60).
        """
        # cv2.resize takes dsize as (width, height); PIL's ``.size`` is
        # already (width, height), so no reversal.  The previous
        # ``image.size[::-1]`` swapped the axes for non-square images.
        cam = cv2.resize(np.float32(cam), image.size)

        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Max blended value is 0.4*255 + 0.6*255 = 255, so no clipping
        # is needed before the uint8 cast.
        superimposed = np.array(image) * 0.4 + heatmap * 0.6
        return Image.fromarray(np.uint8(superimposed))
|
|