diff --git a/sahi/models/ultralytics.py b/sahi/models/ultralytics.py index 8538cfd96..740badfa2 100644 --- a/sahi/models/ultralytics.py +++ b/sahi/models/ultralytics.py @@ -61,7 +61,7 @@ def set_model(self, model: Any, **kwargs): category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)} self.category_mapping = category_mapping - def perform_inference(self, image: np.ndarray): + def perform_inference(self, image: np.ndarray | list[np.ndarray]): """Prediction is performed using self.model and the prediction result is set to self._original_predictions. Args: @@ -81,7 +81,12 @@ def perform_inference(self, image: np.ndarray): if self.image_size is not None: kwargs = {"imgsz": self.image_size, **kwargs} - prediction_result = self.model(image[:, :, ::-1], **kwargs) # YOLO expects numpy arrays to have BGR + if isinstance(image, list): + image_bgr = [image_item[:, :, ::-1] for image_item in image] # YOLO expects numpy arrays to have BGR + else: + image_bgr = image[:, :, ::-1] # YOLO expects numpy arrays to have BGR + + prediction_result = self.model(image_bgr, **kwargs) # Handle different result types for PyTorch vs ONNX models # ONNX models might return results in a different format @@ -132,7 +137,10 @@ def perform_inference(self, image: np.ndarray): prediction_result = [result.boxes.data for result in prediction_result] self._original_predictions = prediction_result - self._original_shape = image.shape + if isinstance(image, list): + self._original_shape = [image_item.shape for image_item in image] + else: + self._original_shape = image.shape @property def category_names(self): @@ -211,6 +219,10 @@ def _create_object_prediction_list_from_original_predictions( shift_amount = shift_amount_list[image_ind] full_shape = None if full_shape_list is None else full_shape_list[image_ind] object_prediction_list = [] + if isinstance(self._original_shape, list): + original_shape = self._original_shape[image_ind] + else: + original_shape = self._original_shape # Extract boxes and optional masks/obb if self.has_mask or self.is_obb: @@ -248,7 +260,7 @@ def _create_object_prediction_list_from_original_predictions( bool_mask = masks_or_points[pred_ind] # Resize mask to original image size bool_mask = cv2.resize( - bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0]) + bool_mask.astype(np.uint8), (original_shape[1], original_shape[0]) ) segmentation = get_coco_segmentation_from_bool_mask(bool_mask) else: # is_obb @@ -266,7 +278,7 @@ def _create_object_prediction_list_from_original_predictions( segmentation=segmentation, category_name=category_name, shift_amount=shift_amount, - full_shape=self._original_shape[:2] if full_shape is None else full_shape, # (height, width) + full_shape=original_shape[:2] if full_shape is None else full_shape, # (height, width) ) object_prediction_list.append(object_prediction) diff --git a/sahi/predict.py b/sahi/predict.py index 6a16c2a51..e206719f8 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -156,6 +156,7 @@ def get_sliced_prediction( exclude_classes_by_id: list[int] | None = None, progress_bar: bool = False, progress_callback=None, + batch_size: int = 1, ) -> PredictionResult: """Function for slice image + get predicion for each slice + combine predictions in full image. @@ -215,6 +216,8 @@ def get_sliced_prediction( progress_callback: callable A callback function that will be called after each slice is processed. The function should accept two arguments: (current_slice, total_slices) + batch_size: int + Number of slices to process per inference batch. Default: 1. Returns: A Dict with fields: object_prediction_list: a list of sahi.prediction.ObjectPrediction @@ -224,8 +227,10 @@ def get_sliced_prediction( # for profiling durations_in_seconds = dict() - # currently only 1 batch supported - num_batch = 1 + if batch_size < 1: + raise ValueError(f"batch_size should be greater than 0, got {batch_size}") + + num_batch = batch_size # create slices from full image time_start = time.perf_counter() slice_image_result = slice_image( @@ -264,7 +269,7 @@ def get_sliced_prediction( postprocess_time = 0 time_start = time.perf_counter() # create prediction input - num_group = int(num_slices / num_batch) + num_group = (num_slices + num_batch - 1) // num_batch if verbose == 1 or verbose == 2: tqdm.write(f"Performing prediction on {num_slices} slices.") @@ -275,29 +280,66 @@ def get_sliced_prediction( object_prediction_list = [] # perform sliced prediction - for group_ind in slice_iterator: - # prepare batch (currently supports only 1 batch) + for group_index in slice_iterator: + # prepare batch image_list = [] shift_amount_list = [] - for image_ind in range(num_batch): - image_list.append(slice_image_result.images[group_ind * num_batch + image_ind]) - shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind]) - # perform batch prediction - prediction_result = get_prediction( - image=image_list[0], - detection_model=detection_model, - shift_amount=shift_amount_list[0], - full_shape=[ - slice_image_result.original_image_height, - slice_image_result.original_image_width, - ], - exclude_classes_by_name=exclude_classes_by_name, - exclude_classes_by_id=exclude_classes_by_id, - ) + group_start = group_index * num_batch + group_end = min(group_start + num_batch, num_slices) + for image_index in range(group_start, group_end): + image_list.append(slice_image_result.images[image_index]) + shift_amount_list.append(slice_image_result.starting_pixels[image_index]) + + if len(image_list) == 1: + prediction_result = get_prediction( + image=image_list[0], + detection_model=detection_model, + shift_amount=shift_amount_list[0], + full_shape=[ + slice_image_result.original_image_height, + slice_image_result.original_image_width, + ], + exclude_classes_by_name=exclude_classes_by_name, + exclude_classes_by_id=exclude_classes_by_id, + ) + object_prediction_list_per_image = [prediction_result.object_prediction_list] + else: + image_as_pil_list = [read_image_as_pil(image) for image in image_list] + full_shape_list = [ + [slice_image_result.original_image_height, slice_image_result.original_image_width] + for _ in image_as_pil_list + ] + try: + detection_model.perform_inference([np.ascontiguousarray(image_as_pil) for image_as_pil in image_as_pil_list]) + detection_model.convert_original_predictions( + shift_amount=shift_amount_list, + full_shape=full_shape_list, + ) + object_prediction_list_per_image = detection_model.object_prediction_list_per_image + except Exception as e: + logger.warning(f"Batch sliced inference failed, falling back to single inference per slice. Error: {e}") + object_prediction_list_per_image = [] + for image_index, image in enumerate(image_list): + prediction_result = get_prediction( + image=image, + detection_model=detection_model, + shift_amount=shift_amount_list[image_index], + full_shape=full_shape_list[image_index], + exclude_classes_by_name=exclude_classes_by_name, + exclude_classes_by_id=exclude_classes_by_id, + ) + object_prediction_list_per_image.append(prediction_result.object_prediction_list) + # convert sliced predictions to full predictions - for object_prediction in prediction_result.object_prediction_list: - if object_prediction: # if not empty - object_prediction_list.append(object_prediction.get_shifted_object_prediction()) + for image_object_prediction_list in object_prediction_list_per_image: + image_object_prediction_list = filter_predictions( + image_object_prediction_list, + exclude_classes_by_name, + exclude_classes_by_id, + ) + for object_prediction in image_object_prediction_list: + if object_prediction: # if not empty + object_prediction_list.append(object_prediction.get_shifted_object_prediction()) # merge matching predictions during sliced prediction if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length: @@ -307,7 +349,7 @@ def get_sliced_prediction( # Call progress callback if provided if progress_callback is not None: - progress_callback(group_ind + 1, num_group) + progress_callback(group_end, num_slices) # perform standard prediction if num_slices > 1 and perform_standard_pred: diff --git a/tests/test_predict.py b/tests/test_predict.py index f9f87c472..fda1920f5 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -343,6 +343,86 @@ def test_get_sliced_prediction_yolo11(): assert num_car > 0 +def _serialize_object_prediction_for_comparison(object_prediction): + segmentation = None + if object_prediction.mask is not None: + segmentation = [[round(float(value), 3) for value in segment] for segment in object_prediction.mask.segmentation] + + return { + "category_id": object_prediction.category.id, + "category_name": object_prediction.category.name, + "bbox_xyxy": [round(float(value), 3) for value in object_prediction.bbox.to_xyxy()], + "score": round(float(object_prediction.score.value), 3), + "segmentation": segmentation, + } + + +def _serialize_prediction_list_for_comparison(object_prediction_list): + serialized_list = [_serialize_object_prediction_for_comparison(obj_pred) for obj_pred in object_prediction_list] + return sorted( + serialized_list, + key=lambda item: ( + item["category_id"], + item["category_name"], + item["bbox_xyxy"][0], + item["bbox_xyxy"][1], + item["bbox_xyxy"][2], + item["bbox_xyxy"][3], + item["score"], + ), + ) + + +def test_get_sliced_prediction_batch_size_exact_output(): + # init model + download_yolo11n_model() + + yolo11_detection_model = UltralyticsDetectionModel( + model_path=UltralyticsConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=False, + image_size=IMAGE_SIZE, + ) + yolo11_detection_model.load_model() + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + + common_prediction_kwargs = { + "image": image_path, + "detection_model": yolo11_detection_model, + "slice_height": 512, + "slice_width": 512, + "overlap_height_ratio": 0.1, + "overlap_width_ratio": 0.2, + "perform_standard_pred": False, + "postprocess_type": "GREEDYNMM", + "postprocess_match_threshold": 0.5, + "postprocess_match_metric": "IOS", + "postprocess_class_agnostic": True, + } + + prediction_result_batch_size_1 = get_sliced_prediction(batch_size=1, **common_prediction_kwargs) + prediction_result_batch_size_4 = get_sliced_prediction(batch_size=4, **common_prediction_kwargs) + prediction_result_batch_size_8 = get_sliced_prediction(batch_size=8, **common_prediction_kwargs) + + serialized_predictions_batch_size_1 = _serialize_prediction_list_for_comparison( + prediction_result_batch_size_1.object_prediction_list + ) + serialized_predictions_batch_size_4 = _serialize_prediction_list_for_comparison( + prediction_result_batch_size_4.object_prediction_list + ) + serialized_predictions_batch_size_8 = _serialize_prediction_list_for_comparison( + prediction_result_batch_size_8.object_prediction_list + ) + + assert len(serialized_predictions_batch_size_1) > 0 + assert serialized_predictions_batch_size_4 == serialized_predictions_batch_size_1 + assert serialized_predictions_batch_size_8 == serialized_predictions_batch_size_1 + + @pytest.mark.skipif(sys.version_info[:2] != (3, 11), reason="MMDet tests only run on Python 3.11") def test_mmdet_yolox_tiny_prediction(): # Skip if mmdet is not installed