阅读量:0
简介
本文介绍如何使用 ONNX Runtime 对 YOLOv8 旋转目标检测（Oriented Bounding Box，OBB）模型进行推理。本例使用 Python 编写代码完成图像处理与目标检测，并依次展示如何加载模型、预处理图像、执行推理以及对推理结果进行后处理和可视化。
代码
以下是实现 YOLOv8 OBB 推理的完整代码（运行前需安装 opencv-python、numpy、onnxruntime 或 onnxruntime-gpu，以及 loguru）：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
# @FileName : YOLOv8_OBB.py
# @Time : 2024-07-25 17:33:48
# @Author : XuMing
# @Email : 920972751@qq.com
# @description : YOLOv8 Oriented Bounding Box Inference using ONNX
"""
import math
import random

import cv2
import numpy as np
import onnxruntime as ort
from loguru import logger


class RotatedBOX:
    """A single oriented detection produced by ``ONNXInfer.postprocess``."""

    def __init__(self, box, score, class_index):
        # OpenCV RotatedRect tuple ((cx, cy), (w, h), angle_in_degrees), in
        # pixel coordinates of the padded input image (which coincide with the
        # original image, because padding is added only on the bottom/right).
        self.box = box
        # Highest class score of this detection.
        self.score = score
        # Index of the winning class (position in the class-name list).
        self.class_index = class_index

    def __repr__(self):
        return (f"{type(self).__name__}(box={self.box!r}, "
                f"score={self.score!r}, class_index={self.class_index!r})")


class ONNXInfer:
    """Run a YOLOv8-OBB ONNX model with ONNX Runtime and decode its output."""

    def __init__(self, onnx_model, class_names, device='auto',
                 conf_thres=0.5, nms_thres=0.4) -> None:
        """
        Create the inference session.

        :param onnx_model: Path to the ONNX model file.
        :param class_names: Class names, indexed by class id.
        :param device: 'auto', 'cuda', or 'cpu'.
        :param conf_thres: A detection is kept only if its best class score
            is strictly greater than this value.
        :param nms_thres: IoU threshold for rotated non-maximum suppression.
        :raises ValueError: If the model input does not have a static
            NCHW height/width.
        """
        self.onnx_model = onnx_model
        self.class_names = class_names
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres
        self.device = self._select_device(device)
        logger.info(f"Loading model on {self.device}...")
        self.session_model = ort.InferenceSession(
            self.onnx_model,
            providers=self.device,
            sess_options=self._get_session_options(),
        )
        # Query the input metadata once instead of on every call.
        model_input = self.session_model.get_inputs()[0]
        self.input_name = model_input.name
        self.input_hw = self._static_hw(model_input.shape)

    @staticmethod
    def _static_hw(shape):
        """Return (height, width) from an NCHW shape; reject dynamic dims."""
        hw = tuple(shape[2:4])
        if len(hw) != 2 or not all(isinstance(d, int) and d > 0 for d in hw):
            raise ValueError(
                f"Expected a static NCHW model input, got shape {shape}")
        return hw

    def _select_device(self, device):
        """
        Select the execution providers.

        :param device: 'auto', 'cuda', or 'cpu'.
        :return: List of providers. CPU is always appended as a fallback.
        """
        cuda_available = 'CUDAExecutionProvider' in ort.get_available_providers()
        if device == 'cuda' or (device == 'auto' and cuda_available):
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        return ['CPUExecutionProvider']

    def _get_session_options(self):
        """Build session options: extended graph optimizations, 4 intra-op threads."""
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        sess_options.intra_op_num_threads = 4
        return sess_options

    def preprocess(self, img):
        """
        Pad the image to a square and convert it to a network input blob.

        :param img: Input image in BGR order (as returned by ``cv2.imread``).
            Grayscale and BGRA images are converted to BGR first.
        :return: ``(blob, padded_image, width, height)``.
            - ``blob``: float32 array of shape [1, 3, input_h, input_w],
              scaled to [0, 1] and in RGB order.
            - ``padded_image``: the square uint8 image the blob was built from.
            - ``width``/``height``: the original image size.
        """
        logger.info(
            "Preprocessing input image to [1, channels, input_h, input_w] format")
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        height, width = img.shape[:2]
        length = max(height, width)
        # Pad only on the bottom/right so that box coordinates computed on the
        # padded image are already valid coordinates in the original image.
        image = np.zeros((length, length, 3), np.uint8)
        image[0:height, 0:width] = img

        input_h, input_w = self.input_hw
        logger.debug(f"Input shape: {[input_h, input_w]}")
        # OpenCV takes the target size as (width, height), not (height, width).
        blob = cv2.dnn.blobFromImage(
            image, scalefactor=1 / 255, size=(input_w, input_h), swapRB=True)
        logger.info(f"Preprocessed image blob shape: {blob.shape}")
        return blob, image, width, height

    def predict(self, img):
        """
        Perform inference on the image.

        :param img: Input BGR image.
        :return: List of RotatedBOX objects (see ``postprocess``).
        :raises Exception: Re-raises whatever ONNX Runtime raises during ``run``.
        """
        blob, resized_image, orig_width, orig_height = self.preprocess(img)
        try:
            outputs = self.session_model.run(None, {self.input_name: blob})
        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise
        return self.postprocess(outputs, resized_image, orig_width, orig_height)

    def postprocess(self, outputs, resized_image, orig_width, orig_height):
        """
        Decode the raw model output into rotated boxes and apply rotated NMS.

        :param outputs: Model outputs. ``outputs[0]`` is indexed as
            [batch, channels, anchors] with batch size 1, where the channels are
            (cx, cy, w, h, class scores..., angle in radians).
        :param resized_image: The square padded image the blob was built
            from; used to map network coordinates back to image pixels.
        :param orig_width: Original image width. Unused: padding is
            bottom/right only, so no offset correction is needed.
        :param orig_height: Original image height (unused, see above).
        :return: List of RotatedBOX objects that survive NMS (may be empty).
        """
        output_data = np.asarray(outputs[0])
        logger.info(
            f"Postprocessing output data with shape: {output_data.shape}")
        input_h, input_w = self.input_hw
        x_factor = resized_image.shape[1] / float(input_w)
        y_factor = resized_image.shape[0] / float(input_h)

        # [channels, anchors] -> [anchors, channels] for the single batch item.
        preds = output_data[0].T
        # Derive the class count from the output (4 box + classes + 1 angle),
        # so the angle column is found even if class_names is mis-sized.
        num_classes = preds.shape[1] - 5
        if num_classes != len(self.class_names):
            logger.warning(
                f"Model outputs {num_classes} classes but "
                f"{len(self.class_names)} class names were given")

        class_scores = preds[:, 4:4 + num_classes]
        class_ids = class_scores.argmax(axis=1)
        scores = class_scores.max(axis=1)
        keep = scores > self.conf_thres
        if not keep.any():
            logger.info("Detected 0 objects after NMS")
            return []

        boxes = preds[keep, :4] * np.array([x_factor, y_factor, x_factor, y_factor])
        angles = preds[keep, 4 + num_classes]
        # Fold angles in [pi/2, 3pi/4] down by pi. A rectangle rotated by pi is
        # the same rectangle, so this only canonicalizes the angle value.
        angles = np.where(
            (angles >= 0.5 * math.pi) & (angles <= 0.75 * math.pi),
            angles - math.pi, angles)
        kept_ids = class_ids[keep]
        kept_scores = [float(s) for s in scores[keep]]

        rects = [
            ((float(cx), float(cy)), (float(w), float(h)), math.degrees(float(a)))
            for (cx, cy, w, h), a in zip(boxes, angles)
        ]
        # Rotated IoU: axis-aligned NMS on bounding rects of rotated boxes
        # over-suppresses tightly packed, diagonally oriented objects.
        # Like the original, suppression is class-agnostic.
        nms_indices = cv2.dnn.NMSBoxesRotated(
            rects, kept_scores, self.conf_thres, self.nms_thres)
        # Depending on the OpenCV version the result is an (N,) or (N, 1)
        # array, or an empty tuple when nothing is kept.
        nms_indices = np.asarray(nms_indices, dtype=np.int64).reshape(-1)

        remain_boxes = [
            RotatedBOX(rects[i], kept_scores[i], int(kept_ids[i]))
            for i in nms_indices
        ]
        logger.info(f"Detected {len(remain_boxes)} objects after NMS")
        return remain_boxes

    def generate_colors(self, num_classes):
        """
        Generate one random color per class.

        :param num_classes: Number of colors to generate.
        :return: List of 3-tuples of ints in [0, 255]. OpenCV drawing
            functions interpret them as BGR.
        """
        return [tuple(random.randint(0, 255) for _ in range(3))
                for _ in range(num_classes)]

    def drawshow(self, original_image, detected_boxes, class_labels):
        """
        Draw detected boxes and labels on the image (in place) and display it.

        Blocks until a key is pressed in the display window.

        :param original_image: Image to draw on; it is modified in place.
        :param detected_boxes: List of detected RotatedBOX objects.
        :param class_labels: List of class labels, indexed by class id.
        """
        # One color per class; extend the palette if a detection carries a
        # class id beyond the provided labels so indexing cannot fail.
        num_colors = max([len(class_labels)]
                         + [box.class_index + 1 for box in detected_boxes])
        colors = self.generate_colors(num_colors)

        for detected_box in detected_boxes:
            # Box coordinates are already in original-image pixels (padding is
            # bottom/right only), so no rescaling is needed. int32 is the point
            # type cv2.polylines requires (np.int0 was removed in NumPy 2.0).
            points = np.round(cv2.boxPoints(detected_box.box)).astype(np.int32)
            class_id = detected_box.class_index
            color = colors[class_id]
            cv2.polylines(original_image, [points], isClosed=True,
                          color=color, thickness=2)
            label = (class_labels[class_id] if class_id < len(class_labels)
                     else str(class_id))
            cv2.putText(original_image, label,
                        (int(points[0][0]), int(points[0][1])),
                        cv2.FONT_HERSHEY_PLAIN, 1.0, color, 1)

        cv2.imshow("Detected Objects", original_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == "__main__":
    img_path = "OIP-C.jpg"
    model_path = "yolov8s-obb.onnx"
    class_names = [
        "plane", "ship", "storage tank", "baseball diamond", "tennis court",
        "basketball court", "ground track field", "harbor", "bridge",
        "large vehicle", "small vehicle", "helicopter", "roundabout",
        "soccer ball field", "swimming pool"
    ]

    img = cv2.imread(img_path)
    if img is None:
        logger.error(f"Failed to load image: {img_path}")
    else:
        app = ONNXInfer(onnx_model=model_path, class_names=class_names)
        predictions = app.predict(img)
        # logger.info(f"Inference results: {predictions}")
        app.drawshow(img, predictions, class_names)