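"""Detection utilities built on a CLIP vision backbone: score an image (or
the first frame of a video) with a detection head on the CLS embedding, and
overlay the CLS token's attention on the input as a heatmap."""
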
import torch
import numpy as np
import cv2
from PIL import Image


def vis_attn(image, patch_attention_map, alpha=0.5):
    """Overlay a 1-D patch attention map on an RGB image as a JET heatmap."""
    image = np.array(image)
    H, W, _ = image.shape

    # The flat token sequence is assumed to cover a square patch grid
    # (e.g. 196 tokens -> a 14x14 grid for ViT-B/16 at 224x224 input).
    seq_len = patch_attention_map.shape[0]
    grid_size = int(seq_len ** 0.5)

    # Reshape to the patch grid and upsample to the full image resolution.
    patch_attention_map = patch_attention_map.reshape(grid_size, grid_size)
    patch_attention_map = cv2.resize(
        patch_attention_map.cpu().detach().numpy(), (W, H), interpolation=cv2.INTER_CUBIC
    )

    # Min-max normalize to [0, 255]; the epsilon guards against a constant map.
    patch_attention_map = (patch_attention_map - patch_attention_map.min()) / (
        patch_attention_map.max() - patch_attention_map.min() + 1e-8
    )
    patch_attention_map = np.uint8(255 * patch_attention_map)

    # applyColorMap returns BGR, so convert to RGB before blending with the
    # RGB image; otherwise the overlay's color channels come out swapped.
    heatmap = cv2.applyColorMap(patch_attention_map, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    blended_image = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
    blended_image = Image.fromarray(blended_image)

    return blended_image
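

# Illustrative smoke test for vis_attn (not part of the original pipeline):
# the shapes and the helper name below are assumptions chosen to match a
# ViT-B/16 backbone at 224x224 input. Call manually to verify the overlay.
def _vis_attn_smoke_test():
    dummy_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
    dummy_attn = torch.rand(196)  # 14x14 patch grid flattened to 196 tokens
    overlay = vis_attn(dummy_image, dummy_attn)
    assert overlay.size == dummy_image.size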


def detect_image(image, processor, clip_model, detection_model):
    """Score a PIL image with the detection head and visualize CLS attention."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        # output_attentions=True is needed so outputs.attentions is populated.
        outputs = clip_model(**inputs, output_attentions=True)

    # The CLS token embedding from the last layer feeds the detection head.
    last_hidden_states = outputs.last_hidden_state[:, 0, :]

    pred_score = float(detection_model(last_hidden_states)[0][0].cpu().detach().numpy())
    assert 0 <= pred_score <= 1

    # Sum the attention maps of the first six transformer layers; stacking
    # allocates a fresh tensor rather than accumulating into
    # outputs.attentions[0] in place.
    accum_attn = torch.stack(outputs.attentions[:6]).sum(dim=0)

    # Average over heads, then take the CLS row's attention to the patch tokens.
    head_mean_attn = accum_attn.mean(dim=1)[0]
    cls_attention_map = head_mean_attn[0, 1:]

    blended_image = vis_attn(image, cls_attention_map)

    results = {
        "pred_score": pred_score,
        "attn_map": blended_image,
    }

    return results

def detect_video(frames, processor, clip_model, detection_model):
    """Score a video by running the image detector on its first frame.

    The per-frame logic is identical to detect_image, so delegate to it
    rather than duplicating the pipeline.
    """
    return detect_image(frames[0], processor, clip_model, detection_model)
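

# Usage sketch (illustrative only). The checkpoint name, the linear+sigmoid
# detection head, and the file paths below are assumptions for demonstration;
# in practice the trained detection head would be loaded from a checkpoint.
if __name__ == "__main__":
    from transformers import CLIPImageProcessor, CLIPVisionModel

    model_name = "openai/clip-vit-base-patch16"  # assumed backbone
    processor = CLIPImageProcessor.from_pretrained(model_name)
    clip_model = CLIPVisionModel.from_pretrained(model_name).eval()

    # Placeholder head: maps the CLS embedding to a scalar score in [0, 1].
    detection_model = torch.nn.Sequential(
        torch.nn.Linear(clip_model.config.hidden_size, 1),
        torch.nn.Sigmoid(),
    ).eval()

    image = Image.open("example.jpg").convert("RGB")  # hypothetical input path
    results = detect_image(image, processor, clip_model, detection_model)
    print(f"pred_score: {results['pred_score']:.4f}")
    results["attn_map"].save("attention_overlay.png")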