File size: 3,644 Bytes
34fc4ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f8bfb3
34fc4ed
 
 
6f8bfb3
34fc4ed
 
 
6f8bfb3
34fc4ed
 
ddcd2c4
 
 
 
 
 
 
 
 
 
 
 
 
34fc4ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70964ca
 
 
 
 
 
34fc4ed
 
 
c5d6aa2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import numpy as np
import tensorflow as tf
import torch 
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
import timm
from PIL import Image
from torchvision import transforms

from Models import ResNet, EfficientNet, BaseLine

def get_model(model_name, classes, device):
    """Build and load the classifier selected in the UI.

    Args:
        model_name: One of 'Inception-V3', 'VGG', 'EfficientNet-B0',
            'ResNet-50' or 'Base Line Model'.
        classes: Sequence of class labels. Only its length is used, to size
            the output layer of the PyTorch models.
        device: Device the PyTorch models are moved to with ``.to()``; also
            used as ``map_location`` when loading their checkpoints.

    Returns:
        For 'Inception-V3' and 'VGG': a ``tf.lite.Interpreter`` whose tensors
        have been allocated. Otherwise: a PyTorch module moved to ``device``
        with its checkpoint weights loaded.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    if model_name == 'Inception-V3':
        # NOTE(review): this loads the same 'vgg.tflite' file as the 'VGG'
        # branch, so both dropdown choices run the same network. Looks like
        # a copy-paste slip -- confirm the intended Inception-V3 model path.
        model = tf.lite.Interpreter(model_path='vgg.tflite')
        model.allocate_tensors()

    elif model_name == 'VGG':
        model = tf.lite.Interpreter(model_path='vgg.tflite')
        model.allocate_tensors()

    elif model_name == 'EfficientNet-B0':
        model = EfficientNet(len(classes)).to(device)
        model.load_state_dict(torch.load('EfficientNet-Model.pt', map_location=torch.device(device)))

    elif model_name == 'ResNet-50':
        model = ResNet(len(classes)).to(device)
        model.load_state_dict(torch.load('model-resnet50.pt', map_location=torch.device(device)))

    elif model_name == 'Base Line Model':
        model = BaseLine(len(classes)).to(device)
        model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))

    else:
        # Previously an unknown name fell through and `return model` raised
        # an opaque UnboundLocalError; fail with an actionable message instead.
        valid = ('Inception-V3', 'VGG', 'EfficientNet-B0', 'ResNet-50', 'Base Line Model')
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of {', '.join(valid)}"
        )

    return model
 
def get_transform(input_img, device):
    """Turn an image into a normalized single-image batch on ``device``.

    The image is converted with ``transforms.ToTensor`` and then normalized
    per channel. A leading batch dimension of size 1 is added before the
    result is moved to ``device``.
    """
    # Per-channel statistics; these are the standard ImageNet mean/std values.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    batch = pipeline(input_img).unsqueeze(0)
    return batch.to(device)

def _predict_torch(model, input_img, device):
    """Run a PyTorch classifier on one image and return softmax scores as NumPy."""
    model.eval()
    batch = get_transform(input_img, device)
    # Inference only: disable autograd so no graph is built for the forward pass.
    with torch.no_grad():
        outputs = model(batch)
    # The output is indexed as (batch, class) by the caller, so softmax runs
    # over dim=1. Passing it explicitly avoids PyTorch's implicit-dim
    # deprecation warning. `.cpu()` is a no-op for CPU tensors, so a single
    # code path serves both CUDA and CPU.
    return F.softmax(outputs, dim=1).cpu().numpy()


def _predict_tflite(interpreter, input_img):
    """Run a TFLite interpreter on one image and return its first output tensor."""
    # Divide pixel values by 255 (assumes 8-bit input) and add a leading
    # batch dimension; the interpreter is fed float32.
    batch = np.expand_dims(np.array(input_img) / 255., 0).astype(np.float32)
    input_index = interpreter.get_input_details()[0]["index"]
    interpreter.set_tensor(input_index, batch)
    interpreter.invoke()
    output_index = interpreter.get_output_details()[0]['index']
    return interpreter.get_tensor(output_index)


def make_predictions(input_img, model_name):
    """Classify ``input_img`` with the model selected by ``model_name``.

    Args:
        input_img: PIL image, as supplied by the Gradio ``gr.Image(type="pil")``
            input.
        model_name: One of 'EfficientNet-B0', 'ResNet-50', 'Base Line Model'
            (PyTorch) or 'Inception-V3', 'VGG' (TFLite).

    Returns:
        ``(label, confidences)``:
            - ``label`` is the name of the highest-scoring class.
            - ``confidences`` maps every class name to its score as a float.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    classes = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(model_name, classes, device)

    if model_name in ('EfficientNet-B0', 'ResNet-50', 'Base Line Model'):
        pred = _predict_torch(model, input_img, device)
    elif model_name in ('Inception-V3', 'VGG'):
        pred = _predict_tflite(model, input_img)
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")

    # Both paths yield a batch of one; scores for that single image are row 0.
    scores = pred[0]
    class_idx = int(np.argmax(scores))
    label = classes[class_idx]
    confidences = {name: float(score) for name, score in zip(classes, scores)}

    return label, confidences

# Model names offered in the dropdown; each must be handled by get_model().
MODEL_CHOICES = ['EfficientNet-B0', 'ResNet-50', 'Inception-V3', 'VGG', 'Base Line Model']

demo = gr.Interface(
    fn=make_predictions,
    inputs=[
        # NOTE(review): `shape=` resizes uploads to 150x150 before they reach
        # make_predictions. It is a Gradio 3.x argument that was dropped from
        # gr.Image in Gradio 4 -- pin gradio<4 or resize in code before upgrading.
        gr.Image(shape=(150, 150), type="pil"),
        gr.Dropdown(choices=MODEL_CHOICES, value='EfficientNet-B0', label='Choose Model'),
    ],
    # gr.Textbox / gr.Label replace the deprecated gr.outputs.* namespace and
    # match the top-level component style already used for the inputs.
    outputs=[gr.Textbox(label="Output Class"), gr.Label(label='Confidences')],
    title="MultiClass Classifier",
    examples=[
        ["Buildings.jpg", 'EfficientNet-B0'],
        ["Forest.jpg", 'EfficientNet-B0'],
        ['Street.jpg', 'EfficientNet-B0'],
        ['glacier.jpg', 'EfficientNet-B0'],
        ['mountain.jpg', 'EfficientNet-B0'],
        ['sea.jpg', 'EfficientNet-B0'],
    ],
)

demo.launch(debug=True, inline=True)