File size: 6,724 Bytes
1d7c63d
 
 
 
dd58475
 
1d7c63d
 
 
86104a0
1c93d1b
1d7c63d
 
1c93d1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5579c05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d7c63d
99ddcfc
f8b140a
1d7c63d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5579c05
 
1d7c63d
 
 
 
 
 
 
5579c05
1d7c63d
 
 
 
 
 
 
 
 
86104a0
1d7c63d
 
 
86104a0
 
 
 
 
dd58475
86104a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb27f59
86104a0
 
 
 
 
 
 
 
 
eb27f59
 
 
0c61c42
 
 
86104a0
 
 
0c61c42
1d7c63d
 
dd58475
 
 
 
 
 
 
 
1330097
 
5579c05
 
86104a0
 
dd58475
 
 
 
 
 
 
 
 
 
 
 
 
0c61c42
 
 
 
dd58475
86104a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap,SobolAttributionMethod,HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
from labels import lookup_140
import cv2
# Batch size forwarded to explainers that accept one (the Rise explainer built in explain()).
BATCH_SIZE = 1

def preprocess_image(image, output_size=(300, 300)):
    """Letterbox an image to a square with black borders, then resize it.

    :param image: Array of shape (height, width) or (height, width, channels),
        as accepted by ``cv2.copyMakeBorder`` / ``cv2.resize``.
    :param output_size: Target size passed to ``cv2.resize`` as ``dsize``,
        i.e. ``(width, height)``.
    :return: The padded and resized image.
    """
    # shape is (height, width[, channels])
    h, w = image.shape[:2]

    # Split the padding so that an odd difference still yields an exactly
    # square canvas (the extra pixel goes to the right / bottom edge).
    diff = abs(h - w)
    before = diff // 2
    after = diff - before

    if h > w:
        # Taller than wide: pad left and right.
        image_padded = cv2.copyMakeBorder(image, 0, 0, before, after, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        # Wider than tall (or already square): pad top and bottom.
        image_padded = cv2.copyMakeBorder(image, before, after, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])

    image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)

    return image_resized


def transform(image, original_size, output_size):
    """
    Resize an XAI output back to the original aspect ratio and pad it to a square.

    :param image: 2-D map or (height, width, channels) image, as accepted by cv2.
    :param original_size: ``(height, width)`` of the original image.
    :param output_size: Final ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``.
    :return: The image stretched to ``original_size``, letterboxed to a square
        with black borders, then resized to ``output_size``.
    """
    h, w = (int(v) for v in original_size)
    # cv2.resize takes dsize as (width, height). Passing (h, w) would transpose
    # the aspect ratio and make the padding below act on the wrong axis.
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)

    # Split the padding so an odd difference still gives an exactly square canvas.
    diff = abs(h - w)
    before = diff // 2
    after = diff - before
    if h > w:
        # Taller than wide: pad left and right.
        image = cv2.copyMakeBorder(image, 0, 0, before, after, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        # Wider than tall (or square): pad top and bottom.
        image = cv2.copyMakeBorder(image, before, after, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])

    image = cv2.resize(image, output_size, interpolation=cv2.INTER_AREA)
    return image

def show(img, original_size, output_size, p=False, **kwargs):
    """Draw an image or attribution map onto the current matplotlib axes.

    :param img: Array shaped (H, W), (H, W, 1) or (H, W, 3), optionally with a
        leading axis of length 1, which is dropped.
    :param original_size: ``(height, width)`` forwarded to ``transform``.
    :param output_size: Final size forwarded to ``transform``.
    :param p: If not False, clip values to the [p, 100 - p] percentile range.
    :param kwargs: Extra keyword arguments for ``plt.imshow`` (e.g. ``alpha``).
    """
    # Drop a leading singleton (batch) axis.
    if img.shape[0] == 1:
        img = img[0]

    if img.shape[-1] == 1:
        # Single-channel map: squeeze so imshow applies a colormap.
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        # Reverse the channel order for display.
        # NOTE(review): presumably converts BGR -> RGB; confirm the input's channel order.
        img = img[:, :, ::-1]

    # Rescale to [0, 1] only when values fall outside that range. The slices
    # above are views of the caller's array, so out-of-place arithmetic is used
    # (in-place ops would mutate e.g. the model input, and fail on int dtypes).
    if img.max() > 1 or img.min() < 0:
        img = img - img.min()
        peak = img.max()
        # A constant array has zero range: return zeros instead of dividing by 0.
        img = img / peak if peak > 0 else img.astype(np.float64)

    # Optional percentile clipping to suppress outliers.
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    img = transform(img, original_size=original_size, output_size=output_size)
    plt.imshow(img, **kwargs)
    plt.axis('off')
    


def explain(model, input_image, h, w, explain_method, nb_samples, size=600, n_classes=171):
    """
    Explain the model's top-5 predictions for one image and save overlay plots.

    :param model: Keras model; its second output (``model.output[1]``) is the
        classification head that gets explained.
    :param input_image: Image accepted by ``tf.image.resize``. It is divided by
        255 below, so pixel values are presumably in [0, 255]; confirm with the caller.
    :param h: Original image height (forwarded to ``show`` as ``original_size``).
    :param w: Original image width.
    :param explain_method: One of "Sobol", "HSIC", "Rise" or "Saliency". Any other
        value builds no explainer, so only the class names are computed.
    :param nb_samples: Number of samples for Rise (unused by the other methods).
    :param size: Side length the image is resized to before inference.
    :param n_classes: Length of the one-hot target vectors given to the explainer.
        NOTE(review): should match the width of the classification output;
        the default (171) differs from the ``lookup_140`` label table name. Confirm.
    :return: ``(classes, explanations)``. ``classes`` are the top-5 labels
        (highest score first). ``explanations`` are the saved PNG paths: a single
        string when exactly one file was written, otherwise a list.
    """
    print('using explain_method:', explain_method)
    # Only the classification part of the model is explained.
    class_model = tf.keras.Model(model.input, model.output[1])

    explainers = []
    if explain_method == "Sobol":
        explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
    elif explain_method == "HSIC":
        explainers.append(HsicAttributionMethod(class_model,
                                                grid_size=7, nb_design=1500,
                                                sampler=LatinHypercube(binary=True)))
    elif explain_method == "Rise":
        explainers.append(Rise(class_model, nb_samples=nb_samples, batch_size=BATCH_SIZE,
                               grid_size=15, preservation_probability=0.5))
    elif explain_method == "Saliency":
        explainers.append(Saliency(class_model))

    X = tf.image.resize(input_image, (size, size))
    X = tf.reshape(X, (size, size, 3)) / 255
    # Add the batch axis once; the same array feeds predict() and the explainers.
    X = np.expand_dims(X, 0)
    predictions = class_model.predict(X)

    # Indices of the five highest scores, best first.
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    classes = [lookup_140[index] for index in top_5_indices]

    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i, class_index in enumerate(top_5_indices):
            Y = tf.one_hot([class_index], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            if len(phi.shape) == 3:
                # Collapse per-channel attributions into a single 2-D map.
                phi = np.mean(phi, -1)
            # Apply Gaussian smoothing to the attribution map.
            phi_smoothed = cv2.GaussianBlur(phi, (5, 5), sigmaX=1.0, sigmaY=1.0)

            path = f'phi_{e}{i}.png'
            # Draw every overlay on its own figure and always close it. Reusing
            # the implicit current figure piles images up across iterations and
            # calls, and keeps them alive in memory.
            fig = plt.figure()
            try:
                show(X[0], original_size=(h, w), output_size=(size, size))
                show(phi_smoothed, original_size=(h, w), output_size=(size, size), p=1, alpha=0.2)
                fig.savefig(path)
            finally:
                plt.close(fig)
            explanations.append(path)

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return classes, explanations