yeeeon commited on
Commit
19e3c32
·
1 Parent(s): 800b9a4
Files changed (2) hide show
  1. app.py +76 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ def get_model_name(name, batch_size, learning_rate, epoch):
9
+ """ Generate a name for the model consisting of all the hyperparameter values
10
+
11
+ Args:
12
+ config: Configuration object containing the hyperparameters
13
+ Returns:
14
+ path: A string with the hyperparameter name and value concatenated
15
+ """
16
+ path = "model_{0}_bs{1}_lr{2}_epoch{3}".format(name,
17
+ batch_size,
18
+ learning_rate,
19
+ epoch)
20
+ return path
21
+
22
+ class LargeNet(nn.Module):
23
+ def __init__(self):
24
+ super(LargeNet, self).__init__()
25
+ self.name = "large"
26
+ self.conv1 = nn.Conv2d(3, 5, 5)
27
+ self.pool = nn.MaxPool2d(2, 2)
28
+ self.conv2 = nn.Conv2d(5, 10, 5)
29
+ self.fc1 = nn.Linear(10 * 29 * 29, 32)
30
+ self.fc2 = nn.Linear(32, 8)
31
+
32
+ def forward(self, x):
33
+ x = self.pool(F.relu(self.conv1(x)))
34
+ x = self.pool(F.relu(self.conv2(x)))
35
+ x = x.view(-1, 10 * 29 * 29)
36
+ x = F.relu(self.fc1(x))
37
+ x = self.fc2(x)
38
+ x = x.squeeze(1) # Flatten to [batch_size]
39
+ return x
40
+
41
+ transform = transforms.Compose([
42
+ transforms.Resize((128, 128)), # Resize to 128x128
43
+ transforms.ToTensor(), # Convert to Tensor
44
+ transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1]
45
+ ])
46
+
47
+ def load_model():
48
+ net = LargeNet() #small or large network
49
+ model_path = get_model_name(net.name, batch_size=128, learning_rate=0.001, epoch=29)
50
+ state = torch.load(model_path)
51
+ net.load_state_dict(state)
52
+
53
+ net.eval()
54
+ return net
55
+
56
+ class_names = ["Gasoline_Can", "Pebbels", "pliers", "Screw_Driver", "Toolbox", "Wrench", "other"]
57
+
58
+
59
+ def predict(image):
60
+ model = load_model()
61
+ image = transform(image).unsqueeze(0)
62
+ with torch.no_grad():
63
+ output = model(image)
64
+ _, pred = torch.max(output, 1)
65
+ return class_names[pred.item()]
66
+
67
+ interface = gr.Interface(
68
+ fn=predict,
69
+ inputs=gr.Image(type="pil"),
70
+ outputs="label",
71
+ title="Mechanical Tools Classifier",
72
+ description="Upload an image to classify it as one of the mechanical tools."
73
+ )
74
+
75
+ if __name__ == "__main__":
76
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow
5
+ numpy
6
+ pandas