Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions sahi/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

class DetectionModel:
required_packages: list[str] | None = None
_object_prediction_list_per_image: list[list[ObjectPrediction]]
_shift_amount_indices_per_prediction: list[int]
_original_shape: np._AnyShapeT

def __init__(
self,
Expand Down Expand Up @@ -60,6 +63,7 @@ def __init__(
self._original_predictions = None
self._object_prediction_list_per_image = None
self.set_device(device)
self._shift_amount_indices_per_prediction = None

# automatically ensure dependencies
self.check_dependencies()
Expand Down Expand Up @@ -121,6 +125,16 @@ def perform_inference(self, image: np.ndarray):
A numpy array that contains the image to be predicted.
"""
raise NotImplementedError()

def perform_per_image_batch_inference(
    self,
    image_list: np.ndarray | list[np.ndarray],
    order: str = "RGB",
) -> None:
    """Batch-inference hook that concrete detection models must override.

    An implementation is expected to perform prediction using self.model and
    to set the prediction result to self._original_predictions.

    Args:
        image_list: np.ndarray or list
            A numpy array that contains the image to be predicted, or a list
            of such numpy arrays.
        order: str
            Label for the channel order of the supplied images; defaults to
            "RGB". NOTE(review): how the value is interpreted is left to the
            overriding implementation.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

def _create_object_prediction_list_from_original_predictions(
self,
Expand Down Expand Up @@ -196,3 +210,7 @@ def object_prediction_list_per_image(self) -> list[list[ObjectPrediction]]:
@property
def original_predictions(self):
    """Read-only access to the value stored in ``self._original_predictions``."""
    stored_predictions = self._original_predictions
    return stored_predictions

@property
def shift_amount_indices_per_prediction(self):
    """Read-only access to the value stored in ``self._shift_amount_indices_per_prediction``."""
    stored_indices = self._shift_amount_indices_per_prediction
    return stored_indices
125 changes: 119 additions & 6 deletions sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, List

import cv2
import numpy as np
Expand Down Expand Up @@ -61,6 +61,101 @@ def set_model(self, model: Any, **kwargs):
if not self.category_mapping:
category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
self.category_mapping = category_mapping

def perform_per_image_batch_inference(self, image_list: list[np.ndarray], order: str | None = None):
    """Run self.model on a batch of images and store the flattened per-detection results.

    All images are passed to the model in a single call. The per-image results
    are flattened into one list that is set to ``self._original_predictions``.
    ``self._shift_amount_indices_per_prediction`` holds, for every entry of
    that list, the index in ``image_list`` of the image the entry came from.
    ``self._original_shape`` is set to the shape of the first image.

    Each stored entry is:
      * ``(boxes.data, masks.data)`` when ``self.has_mask`` is set,
      * ``(concatenated obb xyxy/conf/cls, obb.xyxyxyxy)`` when ``self.is_obb`` is set,
      * ``boxes.data`` otherwise.

    Args:
        image_list: list[np.ndarray]
            A list of numpy arrays of shape (H, W, C) that contains the images to be predicted.
        order: str, optional
            The channel order of the images. When it equals "RGB" the images
            are converted to BGR before inference; for any other value,
            including the default ``None``, they are passed through unchanged.

    Raises:
        ValueError: If the model is not loaded, or if ``image_list`` is empty.
    """
    # Confirm model is loaded (cheap check, done before the heavy imports below).
    if self.model is None:
        raise ValueError("Model is not loaded, load it by calling .load_model()")

    # An empty batch previously fell through to an opaque failure at
    # ``image_list[0]``; fail early with an explicit message instead.
    if len(image_list) == 0:
        raise ValueError("image_list must contain at least one image")

    import torch
    from ultralytics import YOLO
    from ultralytics.engine.results import Masks, Results

    if order == "RGB":
        # YOLO expects numpy arrays to have BGR
        image_list = [cv2.cvtColor(image, cv2.COLOR_RGB2BGR) for image in image_list]

    kwargs = {
        "cfg": self.config_path,
        "verbose": False,
        "conf": self.confidence_threshold,
        "device": self.device,
        "batch": len(image_list),
    }
    if self.image_size is not None:
        kwargs["imgsz"] = self.image_size

    model: YOLO = self.model
    prediction_result: list[Results] = model(image_list, **kwargs)

    # The model object may not expose ``device`` (the original comment cites
    # ONNX models); default to CPU for the empty placeholder tensors below.
    device = getattr(self.model, "device", "cpu")

    original_predictions: list[Any] = []
    shift_indices: list[int] = []

    for idx, image_result in enumerate(prediction_result):
        if self.has_mask:
            if not image_result.masks:
                # Create empty masks if none exist so ``masks.data`` is always available.
                image_result.masks = Masks(torch.tensor([], device=device), image_result.boxes.orig_shape)

            # We do not filter results again as confidence threshold is already applied above
            processed_result = [(result.boxes.data, result.masks.data) for result in image_result]

        elif self.is_obb:
            # For OBB task, get OBB points in xyxyxyxy format
            processed_result = [
                (
                    # Get OBB data: xyxy, conf, cls
                    torch.cat(
                        [
                            result.obb.xyxy,  # box coordinates
                            result.obb.conf.unsqueeze(-1),  # confidence scores
                            result.obb.cls.unsqueeze(-1),  # class ids
                        ],
                        dim=1,
                    )
                    if result.obb is not None
                    else torch.empty((0, 6), device=device),
                    # Get OBB points in (N, 4, 2) format
                    result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=device),
                )
                for result in image_result
            ]

        else:
            # If model doesn't do segmentation or OBB then no need to check masks.
            # Iterate over THIS image's result only: iterating the whole batch here
            # would append every image's boxes once per image and tag them with the
            # wrong shift index.
            processed_result = [result.boxes.data for result in image_result]

        # Record which input image each flattened entry belongs to.
        shift_indices.extend([idx] * len(processed_result))
        original_predictions.extend(processed_result)

    self._original_predictions = original_predictions
    self._original_shape = image_list[0].shape
    self._shift_amount_indices_per_prediction = shift_indices

def perform_inference(self, image: np.ndarray):
"""Prediction is performed using self.model and the prediction result is set to self._original_predictions.
Expand All @@ -73,6 +168,7 @@ def perform_inference(self, image: np.ndarray):
# Confirm model is loaded

import torch
from ultralytics.engine.results import Results

if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")
Expand All @@ -82,7 +178,7 @@ def perform_inference(self, image: np.ndarray):
if self.image_size is not None:
kwargs = {"imgsz": self.image_size, **kwargs}

prediction_result = self.model(image[:, :, ::-1], **kwargs) # YOLO expects numpy arrays to have BGR
prediction_result: List[Results] | torch.Tensor | np.ndarray[np._AnyShape, np.dtype[Any]] = self.model(image[:, :, ::-1], **kwargs) # YOLO expects numpy arrays to have BGR

# Handle different result types for PyTorch vs ONNX models
# ONNX models might return results in a different format
Expand Down Expand Up @@ -199,19 +295,36 @@ def _create_object_prediction_list_from_original_predictions(
Size of the full image after shifting, should be in the form of
List[[height, width],[height, width],...]
"""
use_indices = False
print(self._shift_amount_indices_per_prediction)
if len(self._original_predictions) != len(shift_amount_list):
if self._shift_amount_indices_per_prediction is None:
raise ValueError(
f"Number of predictions ({len(self._original_predictions)}) and shift_amount_list ({len(shift_amount_list)}) do not match."
)
elif len(self._shift_amount_indices_per_prediction) == len(self._original_predictions):
use_indices = True
else:
raise ValueError(
f"Number of predictions ({len(self._original_predictions)}) and shift_amount_indices_per_prediction ({len(self._shift_amount_indices_per_prediction)}) do not match."
)

original_predictions = self._original_predictions

# compatibility for sahi v0.8.15
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

# handle all predictions
object_prediction_list_per_image = []
object_prediction_list_per_image: List[List[ObjectPrediction]] = []

for image_ind, image_predictions in enumerate(original_predictions):
shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []
if use_indices:
shift_amount = shift_amount_list[self._shift_amount_indices_per_prediction[image_ind]]
else:
shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None or image_ind >= len(full_shape_list) else full_shape_list[image_ind]
object_prediction_list: List[ObjectPrediction] = []

# Extract boxes and optional masks/obb
if self.has_mask or self.is_obb:
Expand Down
Loading