YOLOv8-OBB ONNXRuntime推理部署

avatar
作者
猴君
阅读量:0

简介

本文将介绍如何使用 ONNX Runtime 进行 YOLOv8 Oriented Bounding Box (OBB) 模型推理。本例使用 Python 代码完成图像处理和目标检测,并展示如何加载模型、预处理图像、执行推理以及对结果进行后处理。

代码

以下是实现 YOLOv8 OBB 推理的完整代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
# @FileName      : YOLOv8_OBB.py
# @Time          : 2024-07-25 17:33:48
# @Author        : XuMing
# @Email         : 920972751@qq.com
# @description   : YOLOv8 Oriented Bounding Box Inference using ONNX
"""
import cv2
import math
import random
import numpy as np
import onnxruntime as ort
from loguru import logger


class RotatedBOX:
    """A single oriented detection.

    :param box: OpenCV rotated-rect tuple ``((cx, cy), (w, h), angle_degrees)``.
    :param score: Confidence of the best-scoring class.
    :param class_index: Index of the best-scoring class.
    """

    def __init__(self, box, score, class_index):
        self.box = box
        self.score = score
        self.class_index = class_index

    def __repr__(self):
        return (f"{type(self).__name__}(box={self.box!r}, "
                f"score={self.score!r}, class_index={self.class_index!r})")


class ONNXInfer:
    """Run a YOLOv8-OBB ONNX model through ONNX Runtime and decode its output."""

    def __init__(self, onnx_model, class_names, device='auto', conf_thres=0.5, nms_thres=0.4) -> None:
        """
        Create the inference session.

        :param onnx_model: Path to the ONNX model file.
        :param class_names: List of class names; its length fixes the number of
            class-score channels read from the model output.
        :param device: 'auto', 'cuda', or 'cpu'.
        :param conf_thres: Minimum class score for a detection to be kept.
        :param nms_thres: IoU threshold passed to NMS.
        """
        self.onnx_model = onnx_model
        self.class_names = class_names
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres
        self.device = self._select_device(device)

        logger.info(f"Loading model on {self.device}...")
        self.session_model = ort.InferenceSession(
            self.onnx_model,
            providers=self.device,
            sess_options=self._get_session_options()
        )

    def _select_device(self, device):
        """
        Select the appropriate device.

        :param device: 'auto', 'cuda', or 'cpu'.
        :return: List of providers.
        """
        if device == 'cuda' or (device == 'auto' and ort.get_device() == 'GPU'):
            # CPU stays in the list as a fallback provider.
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        return ['CPUExecutionProvider']

    def _get_session_options(self):
        """
        Build the session options used for the inference session.

        :return: ``ort.SessionOptions`` with extended graph optimization and
            4 intra-op threads.
        """
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        sess_options.intra_op_num_threads = 4
        return sess_options

    def _input_hw(self):
        """
        Read the spatial size of the model's first input.

        :return: ``(height, width)`` taken from dimensions 2 and 3 of the input shape.
        :raises ValueError: If those dimensions are not static integers.
        """
        dims = self.session_model.get_inputs()[0].shape[2:]
        if len(dims) != 2 or not all(isinstance(d, int) for d in dims):
            raise ValueError(
                f"Model input must have a static NCHW shape, got spatial dims {dims}")
        return dims[0], dims[1]

    def preprocess(self, img):
        """
        Preprocess the image for inference.

        The image is copied into the top-left corner of a black square canvas
        (side = longest edge) and the canvas is resized to the model input size.

        :param img: Input 3-channel image.
        :return: Tuple ``(blob, padded_image, width, height)``: the network
            input blob, the square padded canvas, and the original image
            width and height.
        """
        logger.info(
            "Preprocessing input image to [1, channels, input_h, input_w] format")
        height, width = img.shape[:2]
        length = max(height, width)
        image = np.zeros((length, length, 3), np.uint8)
        image[0:height, 0:width] = img

        input_h, input_w = self._input_hw()
        logger.debug(f"Input shape: {[input_h, input_w]}")

        # blobFromImage takes size as (width, height), the reverse of the
        # (height, width) order in the model's input shape.
        blob = cv2.dnn.blobFromImage(
            image, scalefactor=1 / 255, size=(input_w, input_h), swapRB=True)
        logger.info(f"Preprocessed image blob shape: {blob.shape}")

        return blob, image, width, height

    def predict(self, img):
        """
        Perform inference on the image.

        :param img: Input image.
        :return: List of RotatedBOX objects that survive NMS.
        :raises Exception: Whatever the ONNX Runtime session raises; it is
            logged and re-raised unchanged.
        """
        blob, resized_image, orig_width, orig_height = self.preprocess(img)
        inputs = {self.session_model.get_inputs()[0].name: blob}
        try:
            outputs = self.session_model.run(None, inputs)
        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise
        return self.postprocess(outputs, resized_image, orig_width, orig_height)

    def postprocess(self, outputs, resized_image, orig_width, orig_height):
        """
        Postprocess the model output.

        Each prediction row is read as ``[cx, cy, w, h, class scores..., angle]``
        with the angle in radians.

        :param outputs: Model outputs; only ``outputs[0]`` is used.
        :param resized_image: Padded square image used for inference.
        :param orig_width: Original image width (unused here).
        :param orig_height: Original image height (unused here).
        :return: List of RotatedBOX objects; empty when nothing passes the
            confidence threshold.
        """
        output_data = outputs[0]
        logger.info(
            f"Postprocessing output data with shape: {output_data.shape}")

        input_h, input_w = self._input_hw()
        x_factor = resized_image.shape[1] / float(input_w)
        y_factor = resized_image.shape[0] / float(input_h)

        # (1, channels, num_predictions) -> (num_predictions, channels)
        predictions = np.asarray(output_data)[0].T

        num_classes = len(self.class_names)
        class_scores = predictions[:, 4:4 + num_classes]
        class_ids = np.argmax(class_scores, axis=1)
        best_scores = class_scores[np.arange(len(predictions)), class_ids]
        keep = np.flatnonzero(best_scores > self.conf_thres)

        scale = np.array([x_factor, y_factor, x_factor, y_factor])
        detected_boxes = []
        confidences = []
        rotated_boxes = []

        for row, class_id, score in zip(predictions[keep], class_ids[keep], best_scores[keep]):
            cx, cy, width, height = (float(v) for v in row[:4] * scale)
            angle = float(row[4 + num_classes])

            # Wrap angles in [pi/2, 3pi/4] back by pi.
            # NOTE(review): presumably matches the model's angle range — confirm.
            if 0.5 * math.pi <= angle <= 0.75 * math.pi:
                angle -= math.pi

            box = ((cx, cy), (width, height), angle * 180 / math.pi)
            rotated_boxes.append(RotatedBOX(box, float(score), int(class_id)))
            # NMS below runs on the axis-aligned bounding rect of the rotated box.
            detected_boxes.append(cv2.boundingRect(cv2.boxPoints(box)))
            confidences.append(float(score))

        if not rotated_boxes:
            logger.info("Detected 0 objects after NMS")
            return []

        nms_indices = cv2.dnn.NMSBoxes(
            detected_boxes, confidences, self.conf_thres, self.nms_thres)
        # NMSBoxes may return an (N, 1) array, a flat array, or an empty tuple
        # depending on the OpenCV version; normalise to a flat index array.
        kept = np.asarray(nms_indices, dtype=int).reshape(-1)
        remain_boxes = [rotated_boxes[i] for i in kept]

        logger.info(f"Detected {len(remain_boxes)} objects after NMS")
        return remain_boxes

    def generate_colors(self, num_classes):
        """
        Generate a random color for each class.

        :param num_classes: Number of classes.
        :return: List of 3-tuples of ints in [0, 255].
        """
        return [
            (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            for _ in range(num_classes)
        ]

    def drawshow(self, original_image, detected_boxes, class_labels):
        """
        Draw detected bounding boxes and labels on the image and display it.

        The image is modified in place, and the call blocks until a key is pressed.

        :param original_image: The input image on which to draw the boxes.
        :param detected_boxes: List of detected RotatedBOX objects.
        :param class_labels: List of class labels.
        """
        # Generate random colors for each class
        colors = self.generate_colors(len(class_labels))

        for detected_box in detected_boxes:
            # Padding is added only on the bottom/right in preprocess, so box
            # coordinates already line up with the original image.
            points = cv2.boxPoints(detected_box.box).astype(np.int32)

            class_id = detected_box.class_index
            color = colors[class_id]

            # Draw the bounding box with the color for the class
            cv2.polylines(original_image, [points],
                          isClosed=True, color=color, thickness=2)
            # Put the class label text with the same color
            origin = (int(points[0][0]), int(points[0][1]))
            cv2.putText(original_image, class_labels[class_id], origin,
                        cv2.FONT_HERSHEY_PLAIN, 1.0, color, 1)

        # Display the image with drawn boxes
        cv2.imshow("Detected Objects", original_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


def main():
    """Run OBB detection on a sample image and display the result."""
    img_path = "OIP-C.jpg"
    model_path = "yolov8s-obb.onnx"
    class_names = [
        "plane", "ship", "storage tank", "baseball diamond", "tennis court",
        "basketball court", "ground track field", "harbor", "bridge",
        "large vehicle", "small vehicle", "helicopter", "roundabout",
        "soccer ball field", "swimming pool"
    ]

    img = cv2.imread(img_path)
    if img is None:
        logger.error(f"Failed to load image: {img_path}")
        return

    app = ONNXInfer(onnx_model=model_path, class_names=class_names)
    predictions = app.predict(img)
    app.drawshow(img, predictions, class_names)


if __name__ == "__main__":
    main()

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!