class DetectionModel:
    """Base class for detection backends (partial reconstruction of the
    batched-inference additions from the patch).

    Subclasses run inference and populate ``_original_predictions``; for
    batched per-slice inference they also record, for each prediction, the
    index of the source image so shift amounts can be matched up later.
    """

    required_packages: list[str] | None = None

    # Predictions grouped per source image; populated after inference.
    _object_prediction_list_per_image: list | None
    # For batched slice inference: source-image index for every prediction.
    _shift_amount_indices_per_prediction: list[int] | None
    # Shape tuple (H, W, C) of the (first) input image.
    # BUGFIX: was annotated ``np._AnyShapeT`` — a private numpy internal that
    # is not a public API and does not describe a runtime shape tuple.
    _original_shape: tuple[int, ...] | None

    def perform_per_image_batch_inference(self, image_list, order: str = "RGB") -> None:
        """Abstract hook: run inference on a batch of images.

        Implementations must set ``self._original_predictions`` and
        ``self._shift_amount_indices_per_prediction``.

        Args:
            image_list: np.ndarray or list[np.ndarray]
                Image(s) to predict. (BUGFIX: docstring previously documented
                a parameter named ``image`` that does not exist.)
            order: str
                Channel order of the inputs; defaults to "RGB".

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    @property
    def shift_amount_indices_per_prediction(self):
        """Source-image index per prediction from the last batched inference."""
        return self._shift_amount_indices_per_prediction
def perform_per_image_batch_inference(self, image_list, order: str = "RGB"):
    """Run batched YOLO prediction on a list of images.

    Stores one entry per detection in ``self._original_predictions`` and,
    in parallel, the index of the detection's source image in
    ``self._shift_amount_indices_per_prediction``.

    Args:
        image_list: list[np.ndarray]
            Images of shape (H, W, C) to predict.
        order: str
            Channel order of the inputs. BUGFIX: defaults to "RGB" to match
            the base-class signature — the previous default of ``None``
            meant the only caller (``get_prediction``, which supplies RGB
            arrays and no ``order``) skipped the RGB->BGR conversion and fed
            YOLO the wrong channel order.

    Raises:
        ValueError: if the model is not loaded or ``image_list`` is empty.
    """
    import torch

    if self.model is None:
        raise ValueError("Model is not loaded, load it by calling .load_model()")
    if not image_list:
        raise ValueError("Input image list is empty.")

    if order == "RGB":
        # YOLO expects numpy arrays to have BGR channel order.
        image_list = [cv2.cvtColor(image, cv2.COLOR_RGB2BGR) for image in image_list]

    kwargs = {
        "cfg": self.config_path,
        "verbose": False,
        "conf": self.confidence_threshold,
        "device": self.device,
        "batch": len(image_list),
    }
    if self.image_size is not None:
        kwargs = {"imgsz": self.image_size, **kwargs}

    prediction_result = self.model(image_list, **kwargs)

    # BUGFIX: these are plain Python lists, not tensors/ndarrays as the
    # previous annotation claimed.
    temp_results_list: list = []  # one entry per kept detection
    temp_shift_idxs: list[int] = []  # source-image index per detection

    for idx, image_result in enumerate(prediction_result):
        if self.has_mask:
            from ultralytics.engine.results import Masks

            if not image_result.masks:
                # Create empty masks if none exist (ONNX models may lack a
                # ``device`` attribute, hence the getattr fallback).
                device = getattr(self.model, "device", "cpu")
                image_result.masks = Masks(
                    torch.tensor([], device=device), image_result.boxes.orig_shape
                )
            # Confidence threshold already applied by the model call above.
            processed_result = [
                (result.boxes.data, result.masks.data) for result in image_result
            ]
        elif self.is_obb:
            device = getattr(self.model, "device", "cpu")
            processed_result = [
                (
                    # OBB data as (xyxy, conf, cls) columns.
                    torch.cat(
                        [
                            result.obb.xyxy,
                            result.obb.conf.unsqueeze(-1),
                            result.obb.cls.unsqueeze(-1),
                        ],
                        dim=1,
                    )
                    if result.obb is not None
                    else torch.empty((0, 6), device=device),
                    # OBB corner points in (N, 4, 2) format.
                    result.obb.xyxyxyxy
                    if result.obb is not None
                    else torch.empty((0, 4, 2), device=device),
                )
                for result in image_result
            ]
        else:
            # BUGFIX: iterate ``image_result`` (this image's detections), not
            # ``prediction_result`` (all images). The original repeated every
            # image's full boxes tensor once per batch image, duplicating
            # predictions and desynchronizing the shift indices below.
            processed_result = [result.boxes.data for result in image_result]

        # Record the source-image index once per detection.
        temp_shift_idxs.extend([idx] * len(processed_result))
        temp_results_list.extend(processed_result)

    self._original_predictions = temp_results_list
    self._original_shape = image_list[0].shape
    self._shift_amount_indices_per_prediction = temp_shift_idxs
""" + use_indices = False + print(self._shift_amount_indices_per_prediction) + if len(self._original_predictions) != len(shift_amount_list): + if self._shift_amount_indices_per_prediction is None: + raise ValueError( + f"Number of predictions ({len(self._original_predictions)}) and shift_amount_list ({len(shift_amount_list)}) do not match." + ) + elif len(self._shift_amount_indices_per_prediction) == len(self._original_predictions): + use_indices = True + else: + raise ValueError( + f"Number of predictions ({len(self._original_predictions)}) and shift_amount_indices_per_prediction ({len(self._shift_amount_indices_per_prediction)}) do not match." + ) + original_predictions = self._original_predictions # compatibility for sahi v0.8.15 @@ -206,12 +316,15 @@ def _create_object_prediction_list_from_original_predictions( full_shape_list = fix_full_shape_list(full_shape_list) # handle all predictions - object_prediction_list_per_image = [] + object_prediction_list_per_image: List[List[ObjectPrediction]] = [] for image_ind, image_predictions in enumerate(original_predictions): - shift_amount = shift_amount_list[image_ind] - full_shape = None if full_shape_list is None else full_shape_list[image_ind] - object_prediction_list = [] + if use_indices: + shift_amount = shift_amount_list[self._shift_amount_indices_per_prediction[image_ind]] + else: + shift_amount = shift_amount_list[image_ind] + full_shape = None if full_shape_list is None or image_ind >= len(full_shape_list) else full_shape_list[image_ind] + object_prediction_list: List[ObjectPrediction] = [] # Extract boxes and optional masks/obb if self.has_mask or self.is_obb: diff --git a/sahi/predict.py b/sahi/predict.py index 6a16c2a51..2c1e5926a 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -1,9 +1,11 @@ from __future__ import annotations +import math import os import time from collections.abc import Generator from functools import cmp_to_key +from typing import List import numpy as np from PIL import Image @@ 
-44,30 +46,38 @@ LOW_MODEL_CONFIDENCE = 0.1 -def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id): +def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id, shift_amount_indices = []): return [ - obj_pred - for obj_pred in object_prediction_list + (obj_pred, shift_amount) + for obj_pred, shift_amount in zip(object_prediction_list, shift_amount_indices) + if obj_pred.category.name not in (exclude_classes_by_name or []) + and obj_pred.category.id not in (exclude_classes_by_id or []) + ] + +def filter_predictions_with_shift_indices(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id, shift_amount_indices = []): + return [ + (obj_pred, shift_amount) + for obj_pred, shift_amount in zip(object_prediction_list, shift_amount_indices) if obj_pred.category.name not in (exclude_classes_by_name or []) and obj_pred.category.id not in (exclude_classes_by_id or []) ] def get_prediction( - image, + image: str | np.ndarray | list[np.ndarray], detection_model, - shift_amount: list | None = None, + shift_amount: list[int, int] | list[list[int, int]] | None = None, full_shape=None, postprocess: PostprocessPredictions | None = None, verbose: int = 0, exclude_classes_by_name: list[str] | None = None, exclude_classes_by_id: list[int] | None = None, -) -> PredictionResult: +) -> PredictionResult | list[PredictionResult]: """Function for performing prediction for given image using given detection_model. 
Args: - image: str or np.ndarray - Location of image or numpy image matrix to slice + image: str, np.ndarray or list of np.ndarray + Location of image, numpy image matrix to slice or list of numpy image matrices detection_model: model.DetectionMode shift_amount: List To shift the box and mask predictions from sliced image to full @@ -90,16 +100,27 @@ def get_prediction( durations_in_seconds: a dict containing elapsed times for profiling """ durations_in_seconds = dict() - - # read image as pil - image_as_pil = read_image_as_pil(image) + # get prediction # ensure shift_amount is a list instance (avoid mutable default arg) if shift_amount is None: shift_amount = [0, 0] - time_start = time.perf_counter() - detection_model.perform_inference(np.ascontiguousarray(image_as_pil)) + # check if image is a list + if isinstance(image, list): + time_start = time.perf_counter() + if len(image) == 0: + raise ValueError("Input image list is empty.") + + if not isinstance(image[0], np.ndarray): + raise ValueError("Input image list should be a list of numpy arrays.") + + detection_model.perform_per_image_batch_inference(image) + else: + # read image as pil + image_as_pil = read_image_as_pil(image) + time_start = time.perf_counter() + detection_model.perform_inference(np.ascontiguousarray(image_as_pil)) time_end = time.perf_counter() - time_start durations_in_seconds["prediction"] = time_end @@ -113,7 +134,39 @@ def get_prediction( shift_amount=shift_amount, full_shape=full_shape, ) + object_prediction_list: list[ObjectPrediction] = detection_model.object_prediction_list + if isinstance(image, list): + shift_amount_indices_per_prediction = detection_model.shift_amount_indices_per_prediction + object_prediction_list = filter_predictions_with_shift_indices(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id, shift_amount_indices_per_prediction) + object_prediction_dict = {} + + for obj_preds, shift_amount_index in object_prediction_list: + if shift_amount_index not 
in object_prediction_dict: + object_prediction_dict[shift_amount_index] = [] + if postprocess is not None: + obj_preds = postprocess(obj_preds) + object_prediction_dict[shift_amount_index].append(obj_preds) + + time_end = time.perf_counter() - time_start + durations_in_seconds["postprocess"] = time_end / len(object_prediction_dict) + + if verbose == 1: + print( + "Prediction performed in", + durations_in_seconds["prediction"], + "seconds.", + ) + + return [ + PredictionResult( + image=image[i], + object_prediction_list=obj_preds, + durations_in_seconds=durations_in_seconds, + ) + for i, obj_preds in object_prediction_dict.items() + ] + object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id) # postprocess matching predictions @@ -156,6 +209,7 @@ def get_sliced_prediction( exclude_classes_by_id: list[int] | None = None, progress_bar: bool = False, progress_callback=None, + num_batch: int = 1, ) -> PredictionResult: """Function for slice image + get predicion for each slice + combine predictions in full image. 
@@ -224,8 +278,6 @@ def get_sliced_prediction( # for profiling durations_in_seconds = dict() - # currently only 1 batch supported - num_batch = 1 # create slices from full image time_start = time.perf_counter() slice_image_result = slice_image( @@ -264,7 +316,7 @@ def get_sliced_prediction( postprocess_time = 0 time_start = time.perf_counter() # create prediction input - num_group = int(num_slices / num_batch) + num_group = math.ceil(num_slices / num_batch) if verbose == 1 or verbose == 2: tqdm.write(f"Performing prediction on {num_slices} slices.") @@ -273,20 +325,22 @@ def get_sliced_prediction( else: slice_iterator = range(num_group) - object_prediction_list = [] + object_prediction_list: List[ObjectPrediction] = [] # perform sliced prediction for group_ind in slice_iterator: # prepare batch (currently supports only 1 batch) - image_list = [] - shift_amount_list = [] - for image_ind in range(num_batch): - image_list.append(slice_image_result.images[group_ind * num_batch + image_ind]) - shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind]) + image_list: list[np.ndarray] = [] + shift_amount_list: list[list[int]] = [] + for image_ind in range(num_batch if num_slices > num_batch else num_slices): + idx = min(group_ind * num_batch + image_ind, num_slices - 1) + image_list.append(slice_image_result.images[idx]) + shift_amount_list.append(slice_image_result.starting_pixels[idx]) + # perform batch prediction prediction_result = get_prediction( - image=image_list[0], + image=image_list, detection_model=detection_model, - shift_amount=shift_amount_list[0], + shift_amount=shift_amount_list, full_shape=[ slice_image_result.original_image_height, slice_image_result.original_image_width, @@ -294,10 +348,17 @@ def get_sliced_prediction( exclude_classes_by_name=exclude_classes_by_name, exclude_classes_by_id=exclude_classes_by_id, ) - # convert sliced predictions to full predictions - for object_prediction in 
prediction_result.object_prediction_list: - if object_prediction: # if not empty - object_prediction_list.append(object_prediction.get_shifted_object_prediction()) + + if isinstance(prediction_result, list): + for prediction in prediction_result: + for object_prediction in prediction.object_prediction_list: + if object_prediction: # if not empty + object_prediction_list.append(object_prediction.get_shifted_object_prediction()) + else: + # convert sliced predictions to full predictions + for object_prediction in prediction_result.object_prediction_list: + if object_prediction: # if not empty + object_prediction_list.append(object_prediction.get_shifted_object_prediction()) # merge matching predictions during sliced prediction if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length: diff --git a/sahi/prediction.py b/sahi/prediction.py index 975e38a07..b5c8d2723 100644 --- a/sahi/prediction.py +++ b/sahi/prediction.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -from typing import Any +from typing import Any, Optional, TypedDict import numpy as np from PIL import Image @@ -154,13 +154,19 @@ def __repr__(self): score: {self.score}, category: {self.category}>""" +class DurationMetrics(TypedDict): + prediction: Optional[float] + postprocess: Optional[float] + slice: Optional[float] + model_load: Optional[float] + export_files: Optional[float] class PredictionResult: def __init__( self, object_prediction_list: list[ObjectPrediction], image: Image.Image | str | np.ndarray, - durations_in_seconds: dict[str, Any] = dict(), + durations_in_seconds: DurationMetrics = dict(), ): self.image: Image.Image = read_image_as_pil(image) self.image_width, self.image_height = self.image.size diff --git a/tests/test_predict.py b/tests/test_predict.py index f9f87c472..b2fec192d 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -12,6 +12,7 @@ from sahi.utils.file import download_from_url from .utils.ultralytics import 
UltralyticsConstants, download_yolo11n_model +import torch MODEL_DEVICE = "cpu" CONFIDENCE_THRESHOLD = 0.5 @@ -342,6 +343,153 @@ def test_get_sliced_prediction_yolo11(): num_car += 1 assert num_car > 0 +def test_get_sliced_batch_prediction_yolo11(): + # init model + download_yolo11n_model() + + yolo11_detection_model = UltralyticsDetectionModel( + model_path=UltralyticsConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=False, + image_size=IMAGE_SIZE, + ) + yolo11_detection_model.load_model() + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + + slice_height = 512 + slice_width = 512 + overlap_height_ratio = 0.1 + overlap_width_ratio = 0.2 + postprocess_type = "GREEDYNMM" + match_metric = "IOS" + match_threshold = 0.5 + class_agnostic = True + + # get sliced prediction + prediction_result = get_sliced_prediction( + image=image_path, + detection_model=yolo11_detection_model, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + perform_standard_pred=False, + postprocess_type=postprocess_type, + postprocess_match_threshold=match_threshold, + postprocess_match_metric=match_metric, + postprocess_class_agnostic=class_agnostic, + num_batch=4, + ) + object_prediction_list = prediction_result.object_prediction_list + + # compare + assert len(object_prediction_list) > 0 + num_person = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "person": + num_person += 1 + assert num_person == 0 + num_truck = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "truck": + num_truck += 1 + assert num_truck == 0 + num_car = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "car": + num_car += 1 + assert num_car > 0 + +def 
def test_batch_vs_single_prediction_on_cuda_yolo11():
    """On CUDA, batched sliced prediction should be faster on average than
    single-slice prediction.

    NOTE(review): this is a timing benchmark, not a functional test; mean
    comparisons can be flaky on shared hardware.
    """
    # make sure the yolo11n weights are available locally
    download_yolo11n_model()
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available. Please install the CUDA version of PyTorch")

    device = "cuda:0"

    def _make_model():
        # one fresh model per timing run so neither run reuses warm state
        model = UltralyticsDetectionModel(
            model_path=UltralyticsConstants.YOLO11N_MODEL_PATH,
            confidence_threshold=CONFIDENCE_THRESHOLD,
            device=device,
            category_remapping=None,
            load_at_init=False,
            image_size=IMAGE_SIZE,
        )
        model.load_model()
        return model

    batch_model = _make_model()
    single_model = _make_model()

    common_kwargs = dict(
        image="tests/data/small-vehicles1.jpeg",
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.1,
        overlap_width_ratio=0.2,
        perform_standard_pred=False,
        postprocess_type="GREEDYNMM",
        postprocess_match_threshold=0.5,
        postprocess_match_metric="IOS",
        postprocess_class_agnostic=True,
    )

    # BUGFIX: N was 1000 (2000 full sliced predictions), hours of runtime in
    # a unit-test suite. A small sample still exposes gross regressions.
    N = 20

    batch_durations = [
        get_sliced_prediction(detection_model=batch_model, num_batch=4, **common_kwargs)
        .durations_in_seconds["prediction"]
        for _ in range(N)
    ]
    single_durations = [
        get_sliced_prediction(detection_model=single_model, **common_kwargs)
        .durations_in_seconds["prediction"]
        for _ in range(N)
    ]

    # BUGFIX: the original printed np.mean(values1) for BOTH columns; print
    # is also moved before the assert so it shows up on failure.
    print(
        f"Mean duration on batch: {np.mean(batch_durations)}; "
        f"Mean duration on single: {np.mean(single_durations)}"
    )
    assert np.mean(batch_durations) < np.mean(single_durations)