Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def set_model(self, model: Any, **kwargs):
category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
self.category_mapping = category_mapping

def perform_inference(self, image: np.ndarray):
def perform_inference(self, image: np.ndarray | list[np.ndarray]):
"""Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Args:
Expand All @@ -81,7 +81,12 @@ def perform_inference(self, image: np.ndarray):
if self.image_size is not None:
kwargs = {"imgsz": self.image_size, **kwargs}

prediction_result = self.model(image[:, :, ::-1], **kwargs) # YOLO expects numpy arrays to have BGR
if isinstance(image, list):
image_bgr = [image_item[:, :, ::-1] for image_item in image] # YOLO expects numpy arrays to have BGR
else:
image_bgr = image[:, :, ::-1] # YOLO expects numpy arrays to have BGR

prediction_result = self.model(image_bgr, **kwargs)

# Handle different result types for PyTorch vs ONNX models
# ONNX models might return results in a different format
Expand Down Expand Up @@ -132,7 +137,10 @@ def perform_inference(self, image: np.ndarray):
prediction_result = [result.boxes.data for result in prediction_result]

self._original_predictions = prediction_result
self._original_shape = image.shape
if isinstance(image, list):
self._original_shape = [image_item.shape for image_item in image]
else:
self._original_shape = image.shape

@property
def category_names(self):
Expand Down Expand Up @@ -211,6 +219,10 @@ def _create_object_prediction_list_from_original_predictions(
shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []
if isinstance(self._original_shape, list):
original_shape = self._original_shape[image_ind]
else:
original_shape = self._original_shape

# Extract boxes and optional masks/obb
if self.has_mask or self.is_obb:
Expand Down Expand Up @@ -248,7 +260,7 @@ def _create_object_prediction_list_from_original_predictions(
bool_mask = masks_or_points[pred_ind]
# Resize mask to original image size
bool_mask = cv2.resize(
bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0])
bool_mask.astype(np.uint8), (original_shape[1], original_shape[0])
)
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
else: # is_obb
Expand All @@ -266,7 +278,7 @@ def _create_object_prediction_list_from_original_predictions(
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=self._original_shape[:2] if full_shape is None else full_shape, # (height, width)
full_shape=original_shape[:2] if full_shape is None else full_shape, # (height, width)
)
object_prediction_list.append(object_prediction)

Expand Down
90 changes: 66 additions & 24 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def get_sliced_prediction(
exclude_classes_by_id: list[int] | None = None,
progress_bar: bool = False,
progress_callback=None,
batch_size: int = 1,
) -> PredictionResult:
"""Function for slice image + get prediction for each slice + combine predictions in full image.

Expand Down Expand Up @@ -215,6 +216,8 @@ def get_sliced_prediction(
progress_callback: callable
A callback function that will be called after each slice is processed.
The function should accept two arguments: (current_slice, total_slices)
batch_size: int
Number of slices to process per inference batch. Default: 1.
Returns:
A Dict with fields:
object_prediction_list: a list of sahi.prediction.ObjectPrediction
Expand All @@ -224,8 +227,10 @@ def get_sliced_prediction(
# for profiling
durations_in_seconds = dict()

# currently only 1 batch supported
num_batch = 1
if batch_size < 1:
raise ValueError(f"batch_size should be greater than 0, got {batch_size}")

num_batch = batch_size
# create slices from full image
time_start = time.perf_counter()
slice_image_result = slice_image(
Expand Down Expand Up @@ -264,7 +269,7 @@ def get_sliced_prediction(
postprocess_time = 0
time_start = time.perf_counter()
# create prediction input
num_group = int(num_slices / num_batch)
num_group = (num_slices + num_batch - 1) // num_batch
if verbose == 1 or verbose == 2:
tqdm.write(f"Performing prediction on {num_slices} slices.")

Expand All @@ -275,29 +280,66 @@ def get_sliced_prediction(

object_prediction_list = []
# perform sliced prediction
for group_ind in slice_iterator:
# prepare batch (currently supports only 1 batch)
for group_index in slice_iterator:
# prepare batch
image_list = []
shift_amount_list = []
for image_ind in range(num_batch):
image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
# perform batch prediction
prediction_result = get_prediction(
image=image_list[0],
detection_model=detection_model,
shift_amount=shift_amount_list[0],
full_shape=[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
group_start = group_index * num_batch
group_end = min(group_start + num_batch, num_slices)
for image_index in range(group_start, group_end):
image_list.append(slice_image_result.images[image_index])
shift_amount_list.append(slice_image_result.starting_pixels[image_index])

if len(image_list) == 1:
prediction_result = get_prediction(
image=image_list[0],
detection_model=detection_model,
shift_amount=shift_amount_list[0],
full_shape=[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list_per_image = [prediction_result.object_prediction_list]
else:
image_as_pil_list = [read_image_as_pil(image) for image in image_list]
full_shape_list = [
[slice_image_result.original_image_height, slice_image_result.original_image_width]
for _ in image_as_pil_list
]
try:
detection_model.perform_inference([np.ascontiguousarray(image_as_pil) for image_as_pil in image_as_pil_list])
detection_model.convert_original_predictions(
shift_amount=shift_amount_list,
full_shape=full_shape_list,
)
object_prediction_list_per_image = detection_model.object_prediction_list_per_image
except Exception as e:
logger.warning(f"Batch sliced inference failed, falling back to single inference per slice. Error: {e}")
object_prediction_list_per_image = []
for image_index, image in enumerate(image_list):
prediction_result = get_prediction(
image=image,
detection_model=detection_model,
shift_amount=shift_amount_list[image_index],
full_shape=full_shape_list[image_index],
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list_per_image.append(prediction_result.object_prediction_list)

# convert sliced predictions to full predictions
for object_prediction in prediction_result.object_prediction_list:
if object_prediction: # if not empty
object_prediction_list.append(object_prediction.get_shifted_object_prediction())
for image_object_prediction_list in object_prediction_list_per_image:
image_object_prediction_list = filter_predictions(
image_object_prediction_list,
exclude_classes_by_name,
exclude_classes_by_id,
)
for object_prediction in image_object_prediction_list:
if object_prediction: # if not empty
object_prediction_list.append(object_prediction.get_shifted_object_prediction())

# merge matching predictions during sliced prediction
if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
Expand All @@ -307,7 +349,7 @@ def get_sliced_prediction(

# Call progress callback if provided
if progress_callback is not None:
progress_callback(group_ind + 1, num_group)
progress_callback(group_end, num_slices)

# perform standard prediction
if num_slices > 1 and perform_standard_pred:
Expand Down
80 changes: 80 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,86 @@ def test_get_sliced_prediction_yolo11():
assert num_car > 0


def _serialize_object_prediction_for_comparison(object_prediction):
segmentation = None
if object_prediction.mask is not None:
segmentation = [[round(float(value), 3) for value in segment] for segment in object_prediction.mask.segmentation]

return {
"category_id": object_prediction.category.id,
"category_name": object_prediction.category.name,
"bbox_xyxy": [round(float(value), 3) for value in object_prediction.bbox.to_xyxy()],
"score": round(float(object_prediction.score.value), 3),
"segmentation": segmentation,
}


def _serialize_prediction_list_for_comparison(object_prediction_list):
    """Serialize predictions and sort them deterministically so two runs can be compared.

    Sorting by (category, bbox corners, score) removes any dependence on the
    order in which predictions were produced.
    """

    def _sort_key(entry):
        x_min, y_min, x_max, y_max = entry["bbox_xyxy"]
        return (
            entry["category_id"],
            entry["category_name"],
            x_min,
            y_min,
            x_max,
            y_max,
            entry["score"],
        )

    serialized = (
        _serialize_object_prediction_for_comparison(prediction) for prediction in object_prediction_list
    )
    return sorted(serialized, key=_sort_key)


def test_get_sliced_prediction_batch_size_exact_output():
    """Sliced prediction must yield identical results regardless of batch_size."""
    # init model
    download_yolo11n_model()

    detection_model = UltralyticsDetectionModel(
        model_path=UltralyticsConstants.YOLO11N_MODEL_PATH,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        device=MODEL_DEVICE,
        category_remapping=None,
        load_at_init=False,
        image_size=IMAGE_SIZE,
    )
    detection_model.load_model()

    # shared arguments; only batch_size varies between runs
    shared_kwargs = {
        "image": "tests/data/small-vehicles1.jpeg",
        "detection_model": detection_model,
        "slice_height": 512,
        "slice_width": 512,
        "overlap_height_ratio": 0.1,
        "overlap_width_ratio": 0.2,
        "perform_standard_pred": False,
        "postprocess_type": "GREEDYNMM",
        "postprocess_match_threshold": 0.5,
        "postprocess_match_metric": "IOS",
        "postprocess_class_agnostic": True,
    }

    # run once per batch size and serialize for order-independent comparison
    serialized_by_batch_size = {
        batch_size: _serialize_prediction_list_for_comparison(
            get_sliced_prediction(batch_size=batch_size, **shared_kwargs).object_prediction_list
        )
        for batch_size in (1, 4, 8)
    }

    baseline = serialized_by_batch_size[1]
    assert len(baseline) > 0
    assert serialized_by_batch_size[4] == baseline
    assert serialized_by_batch_size[8] == baseline


@pytest.mark.skipif(sys.version_info[:2] != (3, 11), reason="MMDet tests only run on Python 3.11")
def test_mmdet_yolox_tiny_prediction():
# Skip if mmdet is not installed
Expand Down