File size: 2,563 Bytes
1ca9e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import cv2
from PIL import Image
import numpy as np

class MultimodalGradCAM:
    """Grad-CAM saliency for a multimodal (vision + language) model.

    Hooks the last layer of the model's vision encoder to capture its
    activations and gradients, then computes a Grad-CAM heatmap over the
    image patches that most influenced the model's predicted answer token.

    Assumes the wrapped model exposes ``get_vision_encoder()`` whose
    ``.layers[-1]`` is a transformer block over a (batch, tokens, hidden)
    sequence — TODO confirm against the concrete model class.
    """

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.activations = {}   # filled by the forward hook, keyed 'vision'
        self.gradients = {}     # filled by the backward hook, keyed 'vision'
        self._hook_handles = []

        # Register hooks once, up front.
        self._register_hooks()

    @staticmethod
    def _unwrap_hidden(output):
        """Extract the hidden-state tensor from a layer's forward output.

        Transformer layers variously return a ModelOutput-style object, a
        tuple whose first element is the hidden states, or a bare tensor —
        handle all three rather than assuming ``last_hidden_state`` exists.
        """
        if hasattr(output, 'last_hidden_state'):
            return output.last_hidden_state
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    def _register_hooks(self):
        """Attach forward/backward hooks on the last vision-encoder layer."""
        def forward_hook(module, input, output):
            self.activations['vision'] = self._unwrap_hidden(output)

        def backward_hook(module, grad_input, grad_output):
            self.gradients['vision'] = grad_output[0]

        vision_encoder = self.model.get_vision_encoder()
        target_layer = vision_encoder.layers[-1]
        self._hook_handles.append(
            target_layer.register_forward_hook(forward_hook)
        )
        # register_backward_hook is deprecated and can report wrong gradients
        # for modules with multiple inputs; use the full variant instead.
        self._hook_handles.append(
            target_layer.register_full_backward_hook(backward_hook)
        )

    def remove_hooks(self):
        """Detach all registered hooks (call when done to avoid leaks)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def generate_saliency(self, image, question):
        """Return a 2-D Grad-CAM map (patch grid, values in [0, 1]).

        Args:
            image: PIL image (or anything the processor accepts as images).
            question: text prompt paired with the image.
        """
        # Preprocess inputs
        inputs = self.processor(
            text=question,
            images=image,
            return_tensors="pt",
            padding=True
        )

        # Forward pass; pick the greedy answer token at the last position.
        outputs = self.model(**inputs)
        answer_ids = outputs.logits.argmax(dim=-1)
        target_token_id = answer_ids[0, -1].item()
        target = outputs.logits[0, -1, target_token_id]

        # Backward pass populates self.gradients via the hook.
        self.model.zero_grad()
        target.backward()

        activations = self.activations['vision'].detach()   # (B, N, D) — assumed
        gradients = self.gradients['vision'].detach()       # (B, N, D)

        # Grad-CAM: weight each hidden channel by its gradient averaged over
        # the token dimension ONLY (dim=1). Averaging over dims [1, 2] — as a
        # naive port of the CNN formulation would — collapses the channel axis
        # too, reducing the weighting to a single scalar per sample and
        # discarding all per-channel importance.
        weights = gradients.mean(dim=1, keepdim=True)            # (B, 1, D)
        cam = torch.relu((weights * activations).sum(dim=-1))    # (B, N)

        cam = cam[0].cpu().numpy()

        # Fold the token axis into a square patch grid so the caller can
        # resize/overlay it. If the count is one past a perfect square,
        # assume a leading CLS token and drop it — TODO confirm that this
        # vision encoder prepends CLS.
        num_tokens = cam.shape[0]
        side = int(np.sqrt(num_tokens))
        if side * side != num_tokens:
            cam = cam[1:]
            side = int(np.sqrt(cam.shape[0]))
        cam = cam[: side * side].reshape(side, side)

        # Min-max normalize to [0, 1]; epsilon guards a constant map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        return cam

    def visualize(self, image, cam):
        """Overlay the Grad-CAM map on the original PIL image.

        Args:
            image: the original PIL image.
            cam: 2-D float map in [0, 1] from generate_saliency().
        """
        # cv2.resize takes dsize as (width, height) — PIL's Image.size is
        # already (width, height), so pass it straight through. Reversing it
        # (as the obvious NumPy-shape instinct suggests) swaps the axes.
        cam = cv2.resize(cam.astype(np.float32), image.size)

        # Convert to an RGB JET heatmap (applyColorMap emits BGR).
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Alpha-blend onto the original; clip defensively before uint8 cast.
        blended = (
            np.asarray(image, dtype=np.float32) * 0.4
            + heatmap.astype(np.float32) * 0.6
        )
        return Image.fromarray(np.uint8(np.clip(blended, 0, 255)))