import onnx
import onnxruntime as ort
import numpy as np
import cv2
import yaml
import copy


class DetectionModel():
    def __init__(self, model_path="weights/best.onnx"):
        self.current_input = None
        self.latest_output = None
        self.model = None
        self.model_ckpt = model_path
        if self.__check_model():
            self.model = ort.InferenceSession(self.model_ckpt)
        else:
            raise Exception("Model couldn't be validated using ONNX, please check the checkpoint")
        self._load_labels()

    def __check_model(self):
        """ Validates the ONNX checkpoint with the ONNX checker """
        model = onnx.load(self.model_ckpt)
        try:
            onnx.checker.check_model(model)
            return True
        except Exception:
            return False

    def _preprocess_input(self, image: np.ndarray) -> np.ndarray:
        """ Preprocess the input image

        Resizes the image to 640x640, transposes the matrix so that it's CxHxW and normalizes the values to [0, 1].
        The result is converted to `np.float32` and returned with an extra `batch` dimension.

        Args:
            image (np.ndarray): The input image

        Returns:
            processed_image (np.ndarray): The preprocessed image as a 1x3x640x640 `np.float32` array
        """
        processed_image = copy.deepcopy(image)
        processed_image = cv2.resize(processed_image, (640, 640))
        processed_image = processed_image.transpose(2, 0, 1)
        processed_image = (processed_image / 255.0).astype(np.float32)
        processed_image = np.expand_dims(processed_image, axis=0)
        return processed_image

    def _postprocess_output(self, predictions) -> list:
        """ Postprocess the output of the model

        Rescales the predicted boxes from the 640x640 model input back to the
        original image size and maps class indices to label names.

        Args:
            predictions (np.ndarray): The output of the model

        Returns:
            detections (list[list]): The detections as a list with one entry per detection.
                The columns are as follows: [x1, y1, x2, y2, confidence, label]
        """
        w_ratio = self.current_input.shape[1] / 640
        h_ratio = self.current_input.shape[0] / 640
        detections = []
        for pred in predictions:
            detections.append([
                int(pred[0] * w_ratio),
                int(pred[1] * h_ratio),
                int(pred[2] * w_ratio),
                int(pred[3] * h_ratio),
                pred[4],
                self.ix2l[int(pred[5])],
            ])
        return detections

    def _load_labels(self):
        """ Loads the class names from `data.yaml` and builds label <-> index lookup tables """
        with open("data.yaml", "r") as f:
            data = yaml.safe_load(f)
        self.labels = data['names']
        self.l2ix = {l: i for i, l in enumerate(self.labels)}
        self.ix2l = {i: l for i, l in enumerate(self.labels)}
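
    # NOTE: `_load_labels` reads `data.yaml` from the working directory. A minimal
    # sketch of the assumed layout (the class names below are illustrative
    # placeholders, not taken from the actual model):
    #
    #   names:
    #     - person
    #     - car
    #     - truck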

    def __call__(self, image: np.ndarray):
        """ Runs detection on an image and returns the rescaled detections """
        processed_image = self._preprocess_input(image)
        self.latest_output = list(self.model.run(None, {"images": processed_image})[0][0])
        self.current_input = image
        detections = self._postprocess_output(self.latest_output)
        return detections

    def visualize(self, input_image: np.ndarray, detections: list[list]) -> np.ndarray:
        """ Visualizes the detections on the given input image

        Args:
            input_image (np.ndarray): The image the detections were computed on
            detections (list[list]): The detections as a list of lists

        Returns:
            image (np.ndarray): The image with the detections drawn on it
        """
        image = copy.deepcopy(input_image)
        for det in detections:
            x1, y1, x2, y2, conf, label = det
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(image, f"{label}: {conf:.3f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
        image = cv2.resize(image, self.current_input.shape[:2][::-1])
        return image
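

# Minimal usage sketch. The checkpoint, `data.yaml`, and image paths below are
# placeholder assumptions for illustration; adjust them to the actual files.
if __name__ == "__main__":
    model = DetectionModel(model_path="weights/best.onnx")

    # Load a test image with OpenCV (HxWxC, uint8).
    image = cv2.imread("sample.jpg")

    # Run inference; each detection is [x1, y1, x2, y2, confidence, label].
    detections = model(image)
    for det in detections:
        print(det)

    # Draw boxes and labels on a copy of the image and save the result.
    annotated = model.visualize(image, detections)
    cv2.imwrite("sample_detections.jpg", annotated)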