import torch
import cv2
from PIL import Image
import numpy as np
class MultimodalGradCAM:
    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.activations = {}
        self.gradients = {}
        # Eval mode so dropout etc. don't make the saliency nondeterministic
        self.model.eval()
        # Register hooks on the vision encoder
        self._register_hooks()

    def _register_hooks(self):
        # Hook the last vision transformer layer: capture its activations on
        # the forward pass and its gradients on the backward pass
        def forward_hook(module, input, output):
            # HF-style layers usually return a tuple whose first element is
            # the hidden states; fall back to .last_hidden_state otherwise
            hidden = output[0] if isinstance(output, tuple) else output.last_hidden_state
            self.activations['vision'] = hidden

        def backward_hook(module, grad_input, grad_output):
            self.gradients['vision'] = grad_output[0]

        vision_encoder = self.model.get_vision_encoder()
        vision_encoder.layers[-1].register_forward_hook(forward_hook)
        # register_backward_hook is deprecated; use the full variant
        vision_encoder.layers[-1].register_full_backward_hook(backward_hook)
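        # Note: get_vision_encoder() and .layers are properties of this
        # wrapper, not a transformers API. For a plain HF CLIP vision tower
        # the same hooks would attach like this (sketch, assuming CLIPModel's
        # layout):
        #
        #   layer = self.model.vision_model.encoder.layers[-1]
        #   layer.register_forward_hook(forward_hook)
        #   layer.register_full_backward_hook(backward_hook)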
    def generate_saliency(self, image, question):
        # Preprocess inputs
        inputs = self.processor(
            text=question,
            images=image,
            return_tensors="pt",
            padding=True,
        )

        # Forward pass
        outputs = self.model(**inputs)
        answer_ids = outputs.logits.argmax(dim=-1)

        # Target token: the prediction at the last position
        target_token_id = answer_ids[0, -1].item()
        target = outputs.logits[0, -1, target_token_id]

        # Backward pass populates self.gradients via the hook
        self.model.zero_grad()
        target.backward()

        # Both tensors have shape (batch, tokens, channels)
        activations = self.activations['vision'].detach()
        gradients = self.gradients['vision'].detach()

        # Grad-CAM: average gradients over tokens to get per-channel weights,
        # then sum the weighted activations over channels for each token
        weights = gradients.mean(dim=1, keepdim=True)   # (B, 1, C)
        cam = (weights * activations).sum(dim=-1)       # (B, tokens)
        cam = torch.relu(cam)

        # Drop the leading CLS token (assumed present) and lay the patch
        # tokens back out on their square 2D grid
        cam = cam[0]
        grid = int((cam.shape[0] - 1) ** 0.5)
        cam = cam[1:].reshape(grid, grid)

        # Normalize to [0, 1]
        cam = cam.cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam
    def visualize(self, image, cam):
        # Resize the CAM to the original image size; cv2.resize takes its
        # target as (width, height), which is exactly PIL's image.size order
        cam = cv2.resize(cam, image.size)

        # Convert to an RGB heatmap
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Superimpose on the original image
        superimposed = np.array(image.convert("RGB")) * 0.4 + heatmap * 0.6
        return Image.fromarray(np.uint8(superimposed))
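

# Usage sketch. Everything below is an assumption for illustration: the
# checkpoint name is a placeholder, and the loaded model must expose
# get_vision_encoder() with a .layers list (see _register_hooks above),
# which stock transformers classes do not guarantee.
if __name__ == "__main__":
    from transformers import AutoModelForVision2Seq, AutoProcessor

    model_name = "Salesforce/blip2-opt-2.7b"  # hypothetical checkpoint
    model = AutoModelForVision2Seq.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    grad_cam = MultimodalGradCAM(model, processor)
    image = Image.open("example.jpg").convert("RGB")  # placeholder input
    cam = grad_cam.generate_saliency(image, "What is in the picture?")
    grad_cam.visualize(image, cam).save("saliency_overlay.png")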