import onnx
import onnxruntime as ort
import numpy as np

import cv2
import yaml

import copy

class DetectionModel:
    """ONNX Runtime wrapper for an object detection checkpoint."""

    def __init__(self, model_path="weights/best.onnx"):
        self.current_input = None
        self.latest_output = None
        self.model = None
        self.model_ckpt = model_path

        if self.__check_model():
            self.model = ort.InferenceSession(self.model_ckpt)
        else:
            raise RuntimeError("Model could not be validated by ONNX, please check the checkpoint")

        self._load_labels()

    def __check_model(self):
        """Validate the ONNX checkpoint before creating an inference session."""
        model = onnx.load(self.model_ckpt)
        try:
            onnx.checker.check_model(model)
            return True
        except onnx.checker.ValidationError:
            return False

    def _preprocess_input(self, image: np.ndarray) -> np.ndarray:
        """ Preprocess the input image

        Resizes the image to 640x640, transposes the matrix to CxHxW and normalizes the pixel values to [0, 1].
        The result is converted to `np.float32` and returned with an extra `batch` dimension.

        Args:
            image (np.ndarray): The input image

        Returns:
            processed_image (np.ndarray): The preprocessed image as a 1x3x640x640 `np.float32` array

        """
        processed_image = copy.deepcopy(image)
        # Resize to the model's expected input resolution
        processed_image = cv2.resize(processed_image, (640, 640))
        # HxWxC -> CxHxW
        processed_image = processed_image.transpose(2, 0, 1)
        # Scale pixel values to [0, 1]
        processed_image = (processed_image / 255.0).astype(np.float32)
        # Add the batch dimension: CxHxW -> 1xCxHxW
        processed_image = np.expand_dims(processed_image, axis=0)

        return processed_image


    def _postprocess_output(self, predictions) -> list:
        """ Postprocess the output of the model

        Args:
            predictions (list): The raw model output, one prediction per row

        Returns:
            detections (list): The detections as a list of N rows, one per detection.
                               The columns are as follows: [x1, y1, x2, y2, confidence, label]

        """
        # Scale boxes from the 640x640 model space back to the original image size
        w_ratio = self.current_input.shape[1] / 640
        h_ratio = self.current_input.shape[0] / 640

        detections = []
        for pred in predictions:
            detections.append([
                int(pred[0] * w_ratio),
                int(pred[1] * h_ratio),
                int(pred[2] * w_ratio),
                int(pred[3] * h_ratio),
                pred[4],
                self.ix2l[int(pred[5])],
            ])

        return detections

    def _load_labels(self):
        """Load the class names from `data.yaml` and build label <-> index lookup tables."""
        with open("data.yaml", "r") as f:
            data = yaml.safe_load(f)

        self.labels = data['names']
        self.l2ix = {l: i for i, l in enumerate(self.labels)}
        self.ix2l = {i: l for i, l in enumerate(self.labels)}

    def __call__(self, image: np.ndarray) -> list:
        processed_image = self._preprocess_input(image)

        # Run inference; the exported model exposes a single input named "images"
        self.latest_output = list(self.model.run(None, {"images": processed_image})[0][0])
        self.current_input = image

        detections = self._postprocess_output(self.latest_output)

        return detections

    def visualize(self, input_image: np.ndarray, detections: list[list]) -> np.ndarray:
        """ Visualizes the detections on the input image

        Args:
            input_image (np.ndarray): The image on which to draw the detections
            detections (list[list]): The detections as a list of lists

        Returns:
            image (np.ndarray): The image with the detections drawn on it

        """
        image = copy.deepcopy(input_image)
        for det in detections:
            x1, y1, x2, y2, conf, label = det
            # Draw the bounding box and the label with its confidence score
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(image, f"{label}: {conf:.3f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

        # Resize back to the original input resolution (width, height)
        image = cv2.resize(image, self.current_input.shape[:2][::-1])

        return image
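

# Minimal usage sketch (assumptions: the checkpoint "weights/best.onnx" and "data.yaml" sit next to
# this file, and "sample.jpg" is a hypothetical test image).
if __name__ == "__main__":
    detector = DetectionModel("weights/best.onnx")

    # Run inference on a single image and draw the resulting boxes
    image = cv2.imread("sample.jpg")
    detections = detector(image)

    annotated = detector.visualize(image, detections)
    cv2.imwrite("sample_detections.jpg", annotated)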