File size: 5,067 Bytes
1d7c63d
 
 
 
dd58475
 
1d7c63d
 
 
86104a0
1d7c63d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86104a0
1d7c63d
 
 
 
 
 
 
 
 
86104a0
1d7c63d
 
 
86104a0
 
 
 
 
dd58475
86104a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d7c63d
86104a0
 
 
 
 
 
 
 
 
1d7c63d
0c61c42
 
 
86104a0
 
 
0c61c42
1d7c63d
 
dd58475
 
 
 
 
 
 
 
86104a0
 
 
 
dd58475
 
 
 
 
 
 
 
 
 
 
 
 
0c61c42
 
 
 
dd58475
86104a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap,SobolAttributionMethod,HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
from labels import lookup_140
# Batch size forwarded to explainers that take one (currently only Rise in explain()).
BATCH_SIZE = 1

def show(img, p=False, **kwargs):
    """Draw an image or attribution map on the current matplotlib axes.

    The input is copied into a float32 NumPy array and reshaped into
    something ``plt.imshow`` accepts:

    * a leading axis of size 1 (e.g. a batch of one) is dropped;
    * a trailing axis of size 1 is squeezed so the array is drawn with a
      colormap;
    * a trailing axis of size 3 has its channel order reversed
      (presumably BGR -> RGB; TODO confirm against ``preprocess``).

    Values outside [0, 1] are min-max rescaled into [0, 1]. A constant
    image (zero value range) is shifted to all zeros instead of being
    divided by zero, which previously produced an all-NaN array.

    :param img: Array-like image, e.g. ``(1, H, W, C)``, ``(H, W, C)`` or
        ``(H, W)``. The caller's array is not modified.
    :param p: If not ``False``, clip values to the ``[p, 100 - p]``
        percentile range after rescaling (suppresses heatmap outliers).
    :param kwargs: Forwarded unchanged to ``plt.imshow`` (e.g. ``alpha``).
    """
    # np.array copies, so the in-place ops below never touch the caller's data.
    img = np.array(img, dtype=np.float32)

    # Drop a leading singleton axis (e.g. a batch dimension of size 1).
    if img.shape[0] == 1:
        img = img[0]

    # Single-channel images become 2-D (drawn with a colormap);
    # 3-channel images get their channel order reversed.
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]

    # Min-max rescale into [0, 1] only when values fall outside that range.
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        value_range = img.max()
        if value_range > 0:  # avoid 0/0 -> NaN for constant images
            img /= value_range

    # Optional percentile clipping to tame extreme attribution values.
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')
    


def _make_explainers(class_model, explain_method, nb_samples):
    """Build the xplique explainer(s) selected by ``explain_method``.

    :return: A one-element list for a recognised name, else an empty list.
    """
    if explain_method == "Sobol":
        return [SobolAttributionMethod(class_model, grid_size=8, nb_design=32)]
    if explain_method == "HSIC":
        return [HsicAttributionMethod(class_model,
                                      grid_size=7, nb_design=1500,
                                      sampler=LatinHypercube(binary=True))]
    if explain_method == "Rise":
        return [Rise(class_model, nb_samples=nb_samples, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)]
    if explain_method == "Saliency":
        return [Saliency(class_model)]
    return []


def explain(model, input_image, explain_method, nb_samples, size=600, n_classes=171):
    """Explain the top-5 predictions of ``model`` for one image and save heatmaps.

    The image is cropped with ``_clever_crop`` and preprocessed. The
    classification head (``model.output[1]``) then scores it, and the five
    highest-scoring classes are explained with the method named by
    ``explain_method``. For each (method, class) pair the absolute
    attribution map is overlaid on the input image and saved as
    ``phi_{e}{i}.png`` in the current working directory.

    :param model: Keras model whose second output (``model.output[1]``)
        holds the class scores.
    :param input_image: Image passed to ``_clever_crop``.
    :param explain_method: One of ``"Sobol"``, ``"HSIC"``, ``"Rise"`` or
        ``"Saliency"``. Any other value produces no explanation images.
    :param nb_samples: Number of random masks; only used by ``"Rise"``.
    :param size: Side length of the square crop / preprocessing size.
    :param n_classes: Length of the one-hot target vectors; presumably must
        equal the width of ``model.output[1]`` -- TODO confirm.
    :return: ``(classes, explanations)``. ``classes`` holds the labels (from
        ``lookup_140``) of the five top-scoring classes, highest first.
        ``explanations`` is the list of saved PNG paths, or the bare path
        string when exactly one file was written.
    """
    print('using explain_method:', explain_method)
    # Only the classification head of the model is explained.
    class_model = tf.keras.Model(model.input, model.output[1])

    explainers = _make_explainers(class_model, explain_method, nb_samples)
    if not explainers:
        print(f'unknown explain_method {explain_method!r}; no explanations generated')

    cropped, _repetitions = _clever_crop(input_image, (size, size))
    # Single preprocessed image with a leading batch axis of size 1.
    X = np.expand_dims(preprocess(cropped, size=size), 0)

    predictions = class_model.predict(X)
    # argsort is ascending: take the last five and reverse -> best first.
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    classes = [lookup_140[index] for index in top_5_indices]

    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e + 1}/{len(explainers)}')
        for i, class_index in enumerate(top_5_indices):
            print(f'{i + 1}/{len(top_5_indices)}')
            target = tf.one_hot([class_index], n_classes)
            phi = np.abs(explainer(X, target))[0]
            if phi.ndim == 3:
                # Collapse a trailing channel axis into one 2-D map.
                phi = np.mean(phi, -1)

            path = f'phi_{e}{i}.png'
            # Fresh figure per overlay, closed after saving, so layers and
            # memory do not accumulate across iterations.
            fig = plt.figure()
            show(X[0])
            show(phi, p=1, alpha=0.4)
            plt.savefig(path)
            plt.close(fig)
            explanations.append(path)

    print('Done')
    # Existing contract: a single explanation is returned as a bare path.
    if len(explanations) == 1:
        explanations = explanations[0]
    return classes, explanations