diff --git a/inference/models/sam3/visual_segmentation.py b/inference/models/sam3/visual_segmentation.py index e8506fbb8e..8622055a0f 100644 --- a/inference/models/sam3/visual_segmentation.py +++ b/inference/models/sam3/visual_segmentation.py @@ -448,11 +448,14 @@ def find_prior_prompt_in_cache( """ Performs search over the cache to see if prior used prompts are subset of this one. """ + num_points = initial_prompt_set.num_points() + if num_points <= 1: + return None # there is only 1 point, hence no prior prompt can be found logits_for_image = [cache[k] for k in cache if k[0] == image_id] maxed_size = 0 best_match: Optional[np.ndarray] = None - desired_size = initial_prompt_set.num_points() - 1 + desired_size = num_points - 1 for cached_dict in logits_for_image[::-1]: logits = cached_dict["logits"] prompt_set: Sam2PromptSet = cached_dict["prompt_set"] diff --git a/inference_models/inference_models/models/sam3/LICENSE.txt b/inference_models/inference_models/models/sam3/LICENSE.txt new file mode 100644 index 0000000000..16ee5f6318 --- /dev/null +++ b/inference_models/inference_models/models/sam3/LICENSE.txt @@ -0,0 +1,61 @@ +SAM License +Last Updated: November 19, 2025 + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein. + + +“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement. + +“Documentation” means the specifications, manuals and documentation accompanying +SAM Materials distributed by Meta. + + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + + +“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). + + +“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom. + + +“Trade Controls” means any of the following: Sanctions and applicable export and import controls. + +By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement. + + +1. License Rights and Redistribution. + + +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials. + +b. Redistribution and Use. +i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials. + + +ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication. + + +iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws. +iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials. +v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons. +2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. + + +a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + +b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + + +8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. \ No newline at end of file diff --git a/inference_models/inference_models/models/sam3/README.md b/inference_models/inference_models/models/sam3/README.md new file mode 100644 index 0000000000..45f61c402d --- /dev/null +++ b/inference_models/inference_models/models/sam3/README.md @@ -0,0 +1,177 @@ +This model uses [SAM 3 (Segment Anything 3)](https://github.com/facebookresearch/sam3) from Meta. + +# Instance Segmentation with Box Prompts + +```python +import cv2 as cv +import numpy as np +import supervision as sv + +from inference_models import AutoModel + +model = AutoModel.from_pretrained("sam3") + +mask_annotator = sv.MaskAnnotator() +box_annotator = sv.BoxAnnotator(color=sv.Color.BLACK) + +img = cv.imread("image.png") +predictions = model.segment_images( + images=img, + boxes=[[(100, 200, 300, 400)]], # xyxy format +) + +masks = predictions[0].masks.cpu().numpy() +detections = sv.Detections( + xyxy=sv.mask_to_xyxy(masks=masks), + mask=masks, +) + +annotated_frame = mask_annotator.annotate(scene=img, detections=detections) +annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections) + +cv.imshow("", annotated_frame) +cv.waitKey(0) +``` + +# Instance Segmentation with Point Prompts + +```python +import cv2 as cv +import numpy as np +import supervision as sv + +from inference_models import AutoModel + +model = AutoModel.from_pretrained("sam3") + +mask_annotator = sv.MaskAnnotator() + +img = cv.imread("image.png") +predictions = model.segment_images( + images=img, + point_coordinates=[[[250, 300]]], # xy format + point_labels=[[[1]]], # 1 = foreground, 0 = background +) + +masks = predictions[0].masks.cpu().numpy() +detections = sv.Detections( + xyxy=sv.mask_to_xyxy(masks=masks), + mask=masks, +) + +annotated_frame = mask_annotator.annotate(scene=img, detections=detections) + +cv.imshow("", annotated_frame) +cv.waitKey(0) +``` + +# Text-Prompted Segmentation + +SAM3 supports text-based prompting to segment objects described in natural language. + +```python +import cv2 as cv +import numpy as np +import supervision as sv + +from inference_models import AutoModel + +model = AutoModel.from_pretrained("sam3") + +mask_annotator = sv.MaskAnnotator() + +img = cv.imread("image.png") +results = model.segment_with_text( + images=img, + prompts=[ + {"text": "person"}, + {"text": "dog"}, + ], +) + +# Process results for each prompt +for prompt_result in results[0]: + masks = prompt_result["masks"] + scores = prompt_result["scores"] + + if len(masks) > 0: + detections = sv.Detections( + xyxy=sv.mask_to_xyxy(masks=masks), + mask=masks, + ) + img = mask_annotator.annotate(scene=img, detections=detections) + +cv.imshow("", img) +cv.waitKey(0) +``` + +# Visual Prompting with Text + +Combine bounding box prompts with text descriptions for more precise segmentation. + +```python +import cv2 as cv +import numpy as np +import supervision as sv + +from inference_models import AutoModel + +model = AutoModel.from_pretrained("sam3") + +mask_annotator = sv.MaskAnnotator() + +img = cv.imread("image.png") +results = model.segment_with_text( + images=img, + prompts=[ + { + "text": "shirt", + "boxes": [[100, 150, 300, 400]], # xyxy format + "box_labels": [1], # 1 = positive, 0 = negative + }, + ], +) + +masks = results[0][0]["masks"] +if len(masks) > 0: + detections = sv.Detections( + xyxy=sv.mask_to_xyxy(masks=masks), + mask=masks, + ) + img = mask_annotator.annotate(scene=img, detections=detections) + +cv.imshow("", img) +cv.waitKey(0) +``` + +# Embeddings Caching + +For interactive applications, you can cache image embeddings to speed up subsequent predictions. + +```python +import cv2 as cv +from inference_models import AutoModel +from inference_models.models.sam3.cache import Sam3ImageEmbeddingsInMemoryCache + +# Initialize cache +embeddings_cache = Sam3ImageEmbeddingsInMemoryCache.init(size_limit=10) + +model = AutoModel.from_pretrained( + "sam3", + sam3_image_embeddings_cache=embeddings_cache, +) + +img = cv.imread("image.png") + +# First call computes and caches embeddings +predictions = model.segment_images( + images=img, + boxes=[[(100, 200, 300, 400)]], +) + +# Subsequent calls with same image use cached embeddings +predictions = model.segment_images( + images=img, + boxes=[[(150, 250, 350, 450)]], # Different prompt, same image +) +``` diff --git a/inference_models/inference_models/models/sam3/__init__.py b/inference_models/inference_models/models/sam3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_models/inference_models/models/sam3/cache.py b/inference_models/inference_models/models/sam3/cache.py new file mode 100644 index 0000000000..4a24a6e637 --- /dev/null +++ b/inference_models/inference_models/models/sam3/cache.py @@ -0,0 +1,162 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict, defaultdict +from threading import Lock +from typing import DefaultDict, List, Optional + +import torch + +from inference_models.errors import EnvironmentConfigurationError +from inference_models.models.sam3.entities import ( + SAM3ImageEmbeddings, + SAM3MaskCacheEntry, +) + + +class Sam3ImageEmbeddingsCache(ABC): + + @abstractmethod + def retrieve_embeddings(self, key: str) -> Optional[SAM3ImageEmbeddings]: + pass + + @abstractmethod + def save_embeddings(self, key: str, embeddings: SAM3ImageEmbeddings) -> None: + pass + + +class Sam3ImageEmbeddingsCacheNullObject(Sam3ImageEmbeddingsCache): + + def retrieve_embeddings(self, key: str) -> Optional[SAM3ImageEmbeddings]: + pass + + def save_embeddings(self, key: str, embeddings: SAM3ImageEmbeddings) -> None: + pass + + +class Sam3ImageEmbeddingsInMemoryCache(Sam3ImageEmbeddingsCache): + + @classmethod + def init( + cls, size_limit: Optional[int], send_to_cpu: bool = True + ) -> "Sam3ImageEmbeddingsInMemoryCache": + return cls( + state=OrderedDict(), + size_limit=size_limit, + send_to_cpu=send_to_cpu, + ) + + def __init__( + self, + state: OrderedDict, + size_limit: Optional[int], + send_to_cpu: bool = True, + ): + self._state = state + self._size_limit = size_limit + self._send_to_cpu = send_to_cpu + self._state_lock = Lock() + + def retrieve_embeddings(self, key: str) -> Optional[SAM3ImageEmbeddings]: + return self._state.get(key) + + def save_embeddings(self, key: str, embeddings: SAM3ImageEmbeddings) -> None: + with self._state_lock: + if key in self._state: + return None + self._ensure_cache_has_capacity() + if self._send_to_cpu: + embeddings = embeddings.to(device=torch.device("cpu")) + self._state[key] = embeddings + + def _ensure_cache_has_capacity(self): + if self._size_limit is None: + return + if self._size_limit < 1: + raise EnvironmentConfigurationError( + message=f"In memory cache size for SAM3 embeddings was set to invalid value. " + f"If you are running inference locally - adjust settings of your deployment. If you see this " + f"error running on Roboflow platform - contact us to get help.", + help_url="https://todo", + ) + while len(self._state) > self._size_limit: + _ = self._state.popitem(last=False) + + +class Sam3LowResolutionMasksCache(ABC): + + @abstractmethod + def retrieve_all_masks_for_image(self, key: str) -> List[SAM3MaskCacheEntry]: + pass + + @abstractmethod + def save_mask(self, key: str, mask: SAM3MaskCacheEntry) -> None: + pass + + +class Sam3LowResolutionMasksCacheNullObject(Sam3LowResolutionMasksCache): + + def retrieve_all_masks_for_image(self, key: str) -> List[SAM3MaskCacheEntry]: + return [] + + def save_mask(self, key: str, mask: SAM3MaskCacheEntry) -> None: + pass + + +class Sam3LowResolutionMasksInMemoryCache(Sam3LowResolutionMasksCache): + + @classmethod + def init( + cls, size_limit: Optional[int], send_to_cpu: bool = True + ) -> "Sam3LowResolutionMasksInMemoryCache": + return cls( + ordering_state=OrderedDict(), + cache_state=defaultdict(list), + size_limit=size_limit, + send_to_cpu=send_to_cpu, + ) + + def __init__( + self, + ordering_state: OrderedDict, + cache_state: DefaultDict[str, List[SAM3MaskCacheEntry]], + size_limit: Optional[int], + send_to_cpu: bool = True, + ): + self._ordering_state = ordering_state + self._cache_state = cache_state + self._size_limit = size_limit + self._send_to_cpu = send_to_cpu + self._state_lock = Lock() + + def retrieve_all_masks_for_image(self, key: str) -> List[SAM3MaskCacheEntry]: + return self._cache_state.get(key, []) + + def save_mask(self, key: str, mask: SAM3MaskCacheEntry) -> None: + with self._state_lock: + if (key, mask.prompt_hash) in self._ordering_state: + return None + self._ensure_cache_has_capacity() + if self._send_to_cpu: + mask = mask.to(device=torch.device("cpu")) + self._ordering_state[(key, mask.prompt_hash)] = True + self._cache_state[key].append(mask) + + def _ensure_cache_has_capacity(self): + if self._size_limit is None: + return + if self._size_limit < 1: + raise EnvironmentConfigurationError( + message=f"In memory cache size for SAM3 low resolution masks was set to invalid value. " + f"If you are running inference locally - adjust settings of your deployment. If you see this " + f"error running on Roboflow platform - contact us to get help.", + help_url="https://todo", + ) + while len(self._ordering_state) > self._size_limit: + image_key, prompt_hash = self._ordering_state.popitem(last=False) + entries_for_image = self._cache_state[image_key] + to_remove_idx = None + for i, element in enumerate(entries_for_image): + if element.prompt_hash == prompt_hash: + to_remove_idx = i + break + if to_remove_idx is not None: + del entries_for_image[to_remove_idx] diff --git a/inference_models/inference_models/models/sam3/entities.py b/inference_models/inference_models/models/sam3/entities.py new file mode 100644 index 0000000000..fce751eda0 --- /dev/null +++ b/inference_models/inference_models/models/sam3/entities.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch + + +@dataclass(frozen=True) +class SAM3ImageEmbeddings: + image_hash: str + image_size_hw: Tuple[int, int] + embeddings: Dict[str, Any] + + def to(self, device: torch.device) -> "SAM3ImageEmbeddings": + def _move_to_device(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + return obj.to(device=device) + elif isinstance(obj, dict): + return {k: _move_to_device(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_move_to_device(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(_move_to_device(item) for item in obj) + return obj + + return SAM3ImageEmbeddings( + image_hash=self.image_hash, + image_size_hw=self.image_size_hw, + embeddings=_move_to_device(self.embeddings), + ) + + +@dataclass(frozen=True) +class SAM3Prediction: + masks: torch.Tensor + scores: torch.Tensor + logits: torch.Tensor + + +@dataclass(frozen=True) +class SAM3MaskCacheEntry: + prompt_hash: str + serialized_prompt: List[dict] + mask: torch.Tensor + + def to(self, device: torch.device) -> "SAM3MaskCacheEntry": + return SAM3MaskCacheEntry( + prompt_hash=self.prompt_hash, + serialized_prompt=self.serialized_prompt, + mask=self.mask.to(device=device), + ) diff --git a/inference_models/inference_models/models/sam3/sam3_torch.py b/inference_models/inference_models/models/sam3/sam3_torch.py new file mode 100644 index 0000000000..05b9c124ca --- /dev/null +++ b/inference_models/inference_models/models/sam3/sam3_torch.py @@ -0,0 +1,925 @@ +import hashlib +import json +from copy import copy +from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union + +import numpy as np +import torch +from PIL import Image +from sam3 import build_sam3_image_model +from sam3.eval.postprocessors import PostProcessImage +from sam3.model.sam3_image_processor import Sam3Processor +from sam3.model.utils.misc import copy_data_to_device +from sam3.train.data.collator import collate_fn_api +from sam3.train.data.sam3_image_dataset import Datapoint as Sam3Datapoint +from sam3.train.data.sam3_image_dataset import FindQueryLoaded +from sam3.train.data.sam3_image_dataset import Image as Sam3ImageDP +from sam3.train.data.sam3_image_dataset import InferenceMetadata +from sam3.train.transforms.basic_for_api import ( + ComposeAPI, + NormalizeAPI, + RandomResizeAPI, + ToTensorAPI, +) + +from inference_models.configuration import DEFAULT_DEVICE +from inference_models.errors import ( + CorruptedModelPackageError, + ModelInputError, +) +from inference_models.models.common.model_packages import get_model_package_contents +from inference_models.models.sam3.cache import ( + Sam3ImageEmbeddingsCache, + Sam3ImageEmbeddingsCacheNullObject, + Sam3LowResolutionMasksCache, + Sam3LowResolutionMasksCacheNullObject, +) +from inference_models.models.sam3.entities import ( + SAM3ImageEmbeddings, + SAM3MaskCacheEntry, + SAM3Prediction, +) +from inference_models.utils.file_system import read_json + +ArrayOrTensor = Union[np.ndarray, torch.Tensor] +T = TypeVar("T") + +MAX_SAM3_BATCH_SIZE = 8 +DEFAULT_SAM3_IMAGE_SIZE = 1024 + +SUPPORTED_VERSIONS = { + "sam3_final", +} + + +class SAM3Torch: + @classmethod + def from_pretrained( + cls, + model_name_or_path: str, + device: torch.device = DEFAULT_DEVICE, + max_batch_size: int = MAX_SAM3_BATCH_SIZE, + image_size: int = DEFAULT_SAM3_IMAGE_SIZE, + sam3_image_embeddings_cache: Optional[Sam3ImageEmbeddingsCache] = None, + sam3_low_resolution_masks_cache: Optional[Sam3LowResolutionMasksCache] = None, + compile_model: bool = False, + enable_inst_interactivity: bool = True, + **kwargs, + ) -> "SAM3Torch": + if sam3_image_embeddings_cache is None: + sam3_image_embeddings_cache = Sam3ImageEmbeddingsCacheNullObject() + if sam3_low_resolution_masks_cache is None: + sam3_low_resolution_masks_cache = Sam3LowResolutionMasksCacheNullObject() + + model_package_content = get_model_package_contents( + model_package_dir=model_name_or_path, + elements=[ + "weights.pt", + "bpe_simple_vocab_16e6.txt.gz", + ], + ) + + try: + config_content = get_model_package_contents( + model_package_dir=model_name_or_path, + elements=["sam_configuration.json"], + ) + version = decode_sam_version( + config_path=config_content["sam_configuration.json"] + ) + if version not in SUPPORTED_VERSIONS: + raise CorruptedModelPackageError( + message=f"Detected unsupported version of SAM3 model: {version}. Supported versions: " + f"are {SUPPORTED_VERSIONS}. If you run inference locally, verify the correctness of " + f"SAM3 model package. If you see the error running on Roboflow platform - " + "contact us to get help.", + help_url="https://todo", + ) + except KeyError: + pass + + device_str = "cuda" if device.type == "cuda" else "cpu" + sam3_model = build_sam3_image_model( + bpe_path=model_package_content["bpe_simple_vocab_16e6.txt.gz"], + checkpoint_path=model_package_content["weights.pt"], + device=device_str, + load_from_HF=False, + compile=compile_model, + enable_inst_interactivity=enable_inst_interactivity, + ) + + transform = ComposeAPI( + transforms=[ + RandomResizeAPI( + sizes=image_size, + max_size=image_size, + square=True, + consistent_transform=False, + ), + ToTensorAPI(), + NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + + return cls( + model=sam3_model, + transform=transform, + device=device, + max_batch_size=max_batch_size, + image_size=image_size, + sam3_image_embeddings_cache=sam3_image_embeddings_cache, + sam3_low_resolution_masks_cache=sam3_low_resolution_masks_cache, + enable_inst_interactivity=enable_inst_interactivity, + ) + + def __init__( + self, + model, + transform: ComposeAPI, + device: torch.device, + max_batch_size: int, + image_size: int, + sam3_image_embeddings_cache: Sam3ImageEmbeddingsCache, + sam3_low_resolution_masks_cache: Sam3LowResolutionMasksCache, + enable_inst_interactivity: bool = True, + ): + self._model = model + self._transform = transform + self._device = device + self._max_batch_size = max_batch_size + self._image_size = image_size + self._sam3_image_embeddings_cache = sam3_image_embeddings_cache + self._sam3_low_resolution_masks_cache = sam3_low_resolution_masks_cache + self._enable_inst_interactivity = enable_inst_interactivity + + def embed_images( + self, + images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]], + use_embeddings_cache: bool = True, + **kwargs, + ) -> List[SAM3ImageEmbeddings]: + images_list = maybe_wrap_in_list(images) + if images_list is None: + raise ModelInputError( + message="No images provided to embed_images()", + help_url="https://todo", + ) + + image_hashes = [compute_image_hash(img) for img in images_list] + original_sizes = [get_image_size(img) for img in images_list] + + embeddings_from_cache: Dict[int, SAM3ImageEmbeddings] = {} + images_to_compute, indices_to_compute = [], [] + + for idx, (image, image_hash) in enumerate(zip(images_list, image_hashes)): + cache_content = None + if use_embeddings_cache: + cache_content = self._sam3_image_embeddings_cache.retrieve_embeddings( + key=image_hash + ) + if cache_content is not None: + cache_content = cache_content.to(device=self._device) + embeddings_from_cache[idx] = cache_content + else: + images_to_compute.append(image) + indices_to_compute.append(idx) + + computed_embeddings = [] + if len(images_to_compute) > 0: + for batch_start in range(0, len(images_to_compute), self._max_batch_size): + batch_end = min( + batch_start + self._max_batch_size, len(images_to_compute) + ) + batch_images = images_to_compute[batch_start:batch_end] + batch_indices = indices_to_compute[batch_start:batch_end] + + batch_embeddings = self._forward_image_embeddings( + images=batch_images, + image_hashes=[image_hashes[i] for i in batch_indices], + original_sizes=[original_sizes[i] for i in batch_indices], + ) + computed_embeddings.extend(batch_embeddings) + + result_embeddings = [] + computed_idx = 0 + for i in range(len(images_list)): + if i in embeddings_from_cache: + result_embeddings.append(embeddings_from_cache[i]) + else: + result_embeddings.append(computed_embeddings[computed_idx]) + computed_idx += 1 + + if use_embeddings_cache: + for embeddings in result_embeddings: + self._sam3_image_embeddings_cache.save_embeddings( + key=embeddings.image_hash, embeddings=embeddings + ) + + return result_embeddings + + @torch.inference_mode() + def _forward_image_embeddings( + self, + images: List[Union[np.ndarray, torch.Tensor]], + image_hashes: List[str], + original_sizes: List[Tuple[int, int]], + ) -> List[SAM3ImageEmbeddings]: + result_embeddings = [] + + for image, image_hash, size in zip(images, image_hashes, original_sizes): + if isinstance(image, torch.Tensor): + np_image = image.cpu().numpy() + if np_image.shape[0] == 3: + np_image = np_image.transpose(1, 2, 0) + np_image = ( + (np_image * 255).astype(np.uint8) + if np_image.max() <= 1 + else np_image + ) + else: + np_image = image + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + processor = Sam3Processor(self._model) + state = processor.set_image(torch.from_numpy(np_image).permute(2, 0, 1)) + + result_embeddings.append( + SAM3ImageEmbeddings( + image_hash=image_hash, + image_size_hw=size, + embeddings=state, + ) + ) + + return result_embeddings + + def segment_images( + self, + images: Optional[ + Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]] + ] = None, + embeddings: Optional[ + Union[List[SAM3ImageEmbeddings], SAM3ImageEmbeddings] + ] = None, + point_coordinates: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None, + point_labels: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None, + boxes: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None, + mask_input: Optional[Union[List[ArrayOrTensor], ArrayOrTensor]] = None, + multi_mask_output: bool = True, + return_logits: bool = False, + load_from_mask_input_cache: bool = False, + save_to_mask_input_cache: bool = False, + use_embeddings_cache: bool = True, + **kwargs, + ) -> List[SAM3Prediction]: + if images is None and embeddings is None: + raise ModelInputError( + message="Attempted to use SAM3 model segment_images(...) method not providing valid input - " + "neither `images` nor `embeddings` parameter is given. If you run inference locally, " + "verify your integration making sure that the model interface is used correctly. Running " + "on Roboflow platform - contact us to get help.", + help_url="https://todo", + ) + + if images is not None: + embeddings_list = self.embed_images( + images=images, + use_embeddings_cache=use_embeddings_cache, + **kwargs, + ) + else: + embeddings_list = maybe_wrap_in_list(embeddings) + + image_hashes = [e.image_hash for e in embeddings_list] + original_image_sizes = [e.image_size_hw for e in embeddings_list] + + point_coordinates = maybe_wrap_in_list(point_coordinates) + point_labels = maybe_wrap_in_list(point_labels) + boxes = maybe_wrap_in_list(boxes) + mask_input = maybe_wrap_in_list(mask_input) + + point_coordinates, point_labels, boxes, mask_input = equalize_batch_size( + embeddings_batch_size=len(embeddings_list), + point_coordinates=point_coordinates, + point_labels=point_labels, + boxes=boxes, + mask_input=mask_input, + ) + + predictions = [] + for idx, embedding in enumerate(embeddings_list): + image_point_coords = point_coordinates[idx] if point_coordinates else None + image_point_labels = point_labels[idx] if point_labels else None + image_boxes = boxes[idx] if boxes else None + image_mask_input = mask_input[idx] if mask_input else None + image_hash = image_hashes[idx] + original_size = original_image_sizes[idx] + + serialized_prompt, prompt_hash = None, None + if save_to_mask_input_cache or load_from_mask_input_cache: + serialized_prompt = serialize_prompt( + point_coordinates=image_point_coords, + point_labels=image_point_labels, + boxes=image_boxes, + ) + prompt_hash = hash_serialized_prompt(serialized_prompt) + + if image_mask_input is None and load_from_mask_input_cache: + image_mask_input = attempt_load_image_mask_from_cache( + image_hash=image_hash, + serialized_prompt_hash=prompt_hash, + serialized_prompt=serialized_prompt, + sam3_low_resolution_masks_cache=self._sam3_low_resolution_masks_cache, + device=self._device, + ) + + prediction = self._predict_for_single_image( + embeddings=embedding, + original_image_size=original_size, + point_coordinates=image_point_coords, + point_labels=image_point_labels, + boxes=image_boxes, + mask_input=image_mask_input, + multi_mask_output=multi_mask_output, + return_logits=return_logits, + ) + + if save_to_mask_input_cache and len(prediction.masks.shape) >= 2: + max_score_id = torch.argmax(prediction.scores).item() + mask_entry = SAM3MaskCacheEntry( + prompt_hash=prompt_hash, + serialized_prompt=serialized_prompt, + mask=prediction.logits[max_score_id].unsqueeze(dim=0), + ) + self._sam3_low_resolution_masks_cache.save_mask( + key=image_hash, + mask=mask_entry, + ) + + predictions.append(prediction) + + return predictions + + @torch.inference_mode() + def _predict_for_single_image( + self, + embeddings: SAM3ImageEmbeddings, + original_image_size: Tuple[int, int], + point_coordinates: Optional[ArrayOrTensor] = None, + point_labels: Optional[ArrayOrTensor] = None, + boxes: Optional[ArrayOrTensor] = None, + mask_input: Optional[ArrayOrTensor] = None, + multi_mask_output: bool = True, + return_logits: bool = False, + ) -> SAM3Prediction: + args = {} + + if point_coordinates is not None and point_labels is not None: + if isinstance(point_coordinates, np.ndarray): + point_coordinates = point_coordinates.tolist() + elif isinstance(point_coordinates, torch.Tensor): + point_coordinates = point_coordinates.cpu().tolist() + + if isinstance(point_labels, np.ndarray): + point_labels = point_labels.tolist() + elif isinstance(point_labels, torch.Tensor): + point_labels = point_labels.cpu().tolist() + + args["point_coords"] = point_coordinates + args["point_labels"] = point_labels + + if boxes is not None: + if isinstance(boxes, np.ndarray): + boxes_list = boxes.tolist() + elif isinstance(boxes, torch.Tensor): + boxes_list = boxes.cpu().tolist() + else: + boxes_list = boxes + if len(boxes_list) > 0 and isinstance(boxes_list[0], (int, float)): + args["box"] = boxes_list + else: + args["box"] = boxes_list[0] if boxes_list else None + + args = pad_points(args) + if not any(args.values()): + args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None} + + mask_input_tensor = None + if mask_input is not None: + if isinstance(mask_input, np.ndarray): + mask_input_tensor = torch.from_numpy(mask_input).to(self._device) + elif isinstance(mask_input, torch.Tensor): + mask_input_tensor = mask_input.to(self._device) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + masks, scores, low_res_logits = self._model.predict_inst( + embeddings.embeddings, + mask_input=mask_input_tensor, + multimask_output=multi_mask_output, + return_logits=True, + normalize_coords=True, + **args, + ) + + masks, scores, low_res_logits = choose_most_confident_prediction( + masks=masks, + scores=scores, + low_resolution_logits=low_res_logits, + ) + + masks_tensor = ( + torch.from_numpy(masks) if isinstance(masks, np.ndarray) else masks + ) + scores_tensor = ( + torch.from_numpy(scores) if isinstance(scores, np.ndarray) else scores + ) + logits_tensor = ( + torch.from_numpy(low_res_logits) + if isinstance(low_res_logits, np.ndarray) + else low_res_logits + ) + + if not return_logits: + masks_tensor = masks_tensor > 0 + + return SAM3Prediction( + masks=masks_tensor, + scores=scores_tensor, + logits=logits_tensor, + ) + + def segment_with_text( + self, + images: Union[np.ndarray, List[np.ndarray]], + prompts: List[Dict], + output_prob_thresh: float = 0.5, + **kwargs, + ) -> List[Dict]: + images_list = maybe_wrap_in_list(images) + if images_list is None: + raise ModelInputError( + message="No images provided to segment_with_text()", + help_url="https://todo", + ) + + results = [] + for image in images_list: + np_image = image.cpu().numpy() if isinstance(image, torch.Tensor) else image + if np_image.shape[0] == 3: + np_image = np_image.transpose(1, 2, 0) + if np_image.max() <= 1: + np_image = np_image * 255 + + h, w = np_image.shape[:2] + pil_image = Image.fromarray(np_image.astype(np.uint8)) + + with torch.inference_mode(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + datapoint = Sam3Datapoint( + find_queries=[], + images=[Sam3ImageDP(data=pil_image, objects=[], size=(h, w))], + ) + + prompt_ids = [] + for idx, p in enumerate(prompts): + if p.get("boxes"): + q = _build_visual_query( + coco_id=idx, + h=h, + w=w, + boxes=p["boxes"], + labels=p.get("box_labels", []), + text=p.get("text"), + ) + else: + q = _build_text_query( + coco_id=idx, + h=h, + w=w, + text=p.get("text"), + ) + datapoint.find_queries.append(q) + prompt_ids.append(idx) + + datapoint = self._transform(datapoint) + batch = collate_fn_api(batch=[datapoint], dict_key="dummy")["dummy"] + batch = copy_data_to_device( + batch, + self._device, + non_blocking=True, + ) + + output = self._model(batch) + + post = PostProcessImage( + max_dets_per_img=-1, + iou_type="segm", + use_original_sizes_box=True, + use_original_sizes_mask=True, + convert_mask_to_rle=False, + detection_threshold=float(output_prob_thresh), + to_cpu=True, + ) + processed = post.process_results(output, batch.find_metadatas) + + image_results = [] + for idx, coco_id in enumerate(prompt_ids): + masks = processed[coco_id].get("masks") + scores = processed[coco_id].get("scores", []) + + if masks is not None: + if hasattr(masks, "detach"): + masks = masks.detach().cpu().numpy() + masks = np.array(masks) + else: + masks = np.zeros((0, h, w), dtype=np.uint8) + + image_results.append( + { + "prompt_index": idx, + "masks": masks, + "scores": list(scores), + } + ) + + results.append(image_results) + + return results + + +def decode_sam_version(config_path: str) -> str: + config = read_json(path=config_path) + version = config["version"] + if not isinstance(version, str): + raise ValueError("Could not decode SAM model version") + return version + + +def compute_image_hash(image: Union[torch.Tensor, np.ndarray]) -> str: + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + return hashlib.sha1(image.tobytes()).hexdigest() + + +def get_image_size(image: Union[torch.Tensor, np.ndarray]) -> Tuple[int, int]: + if isinstance(image, torch.Tensor): + if len(image.shape) == 3: + if image.shape[0] == 3: + return (image.shape[1], image.shape[2]) + else: + return (image.shape[0], image.shape[1]) + return (image.shape[-2], image.shape[-1]) + return (image.shape[0], image.shape[1]) + + +def maybe_wrap_in_list(value: Optional[Union[T, List[T]]]) -> Optional[List[T]]: + if value is None: + return None + if isinstance(value, list): + return value + return [value] + + +def equalize_batch_size( + embeddings_batch_size: int, + point_coordinates: Optional[List[ArrayOrTensor]], + point_labels: Optional[List[ArrayOrTensor]], + boxes: Optional[List[ArrayOrTensor]], + mask_input: Optional[List[ArrayOrTensor]], +) -> Tuple[ + Optional[List[ArrayOrTensor]], + Optional[List[ArrayOrTensor]], + Optional[List[ArrayOrTensor]], + Optional[List[ArrayOrTensor]], +]: + if ( + point_coordinates is not None + and len(point_coordinates) != embeddings_batch_size + ): + if len(point_coordinates) == 1: + point_coordinates = point_coordinates * embeddings_batch_size + else: + raise ModelInputError( + message=f"point_coordinates batch size ({len(point_coordinates)}) doesn't match " + f"embeddings batch size ({embeddings_batch_size})", + help_url="https://todo", + ) + + if point_labels is not None and len(point_labels) != embeddings_batch_size: + if len(point_labels) == 1: + point_labels = point_labels * embeddings_batch_size + else: + raise ModelInputError( + message=f"point_labels batch size ({len(point_labels)}) doesn't match " + f"embeddings batch size ({embeddings_batch_size})", + help_url="https://todo", + ) + + if boxes is not None and len(boxes) != embeddings_batch_size: + if len(boxes) == 1: + boxes = boxes * embeddings_batch_size + else: + raise ModelInputError( + message=f"boxes batch size ({len(boxes)}) doesn't match " + f"embeddings batch size ({embeddings_batch_size})", + help_url="https://todo", + ) + + if mask_input is not None and len(mask_input) != embeddings_batch_size: + if len(mask_input) == 1: + mask_input = mask_input * embeddings_batch_size + else: + raise ModelInputError( + message=f"mask_input batch size ({len(mask_input)}) doesn't match " + f"embeddings batch size ({embeddings_batch_size})", + help_url="https://todo", + ) + + return point_coordinates, point_labels, boxes, mask_input + + +def pad_points(args: Dict) -> Dict: + args = copy(args) + if args.get("point_coords") is not None: + point_labels = args.get("point_labels") + if ( + not isinstance(point_labels, list) + or len(point_labels) > 0 + and any(not isinstance(p, list) for p in point_labels) + ): + raise ModelInputError( + message="point_labels must be a nested list (e.g., [[1, 0, 1]]). " + "Each inner list should contain labels for points in a single prompt.", + help_url="https://todo", + ) + max_len = max(max(len(prompt) for prompt in args["point_coords"]), 1) + for prompt in args["point_coords"]: + for _ in range(max_len - len(prompt)): + prompt.append([0, 0]) + for label in args["point_labels"]: + for _ in range(max_len - len(label)): + label.append(-1) + return args + + +def choose_most_confident_prediction( + masks: np.ndarray, + scores: np.ndarray, + low_resolution_logits: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + if len(masks.shape) == 3: + masks = np.expand_dims(masks, axis=0) + scores = np.expand_dims(scores, axis=0) + low_resolution_logits = np.expand_dims(low_resolution_logits, axis=0) + + selected_masks, selected_scores, selected_logits = [], [], [] + for mask, score, logit in zip(masks, scores, low_resolution_logits): + max_idx = np.argmax(score) + selected_masks.append(mask[max_idx]) + selected_scores.append(score[max_idx]) + selected_logits.append(logit[max_idx]) + + return ( + np.asarray(selected_masks), + np.asarray(selected_scores), + np.asarray(selected_logits), + ) + + +def serialize_prompt( + point_coordinates: Optional[ArrayOrTensor], + point_labels: Optional[ArrayOrTensor], + boxes: Optional[ArrayOrTensor], +) -> List[dict]: + if point_coordinates is None and point_labels is None and boxes is None: + return [] + + result = {"points": [], "box": None} + + if point_coordinates is not None and point_labels is not None: + if isinstance(point_coordinates, torch.Tensor): + coords_list = point_coordinates.cpu().tolist() + elif isinstance(point_coordinates, np.ndarray): + coords_list = point_coordinates.tolist() + else: + coords_list = point_coordinates + + if isinstance(point_labels, torch.Tensor): + labels_list = point_labels.cpu().tolist() + elif isinstance(point_labels, np.ndarray): + labels_list = point_labels.tolist() + else: + labels_list = point_labels + + for coord, label in zip(coords_list, labels_list): + result["points"].append( + { + "x": coord[0] if isinstance(coord, (list, tuple)) else coord, + "y": coord[1] if isinstance(coord, (list, tuple)) else 0, + "positive": bool(label), + } + ) + + if boxes is not None: + if isinstance(boxes, torch.Tensor): + result["box"] = boxes.cpu().tolist() + elif isinstance(boxes, np.ndarray): + result["box"] = boxes.tolist() + else: + result["box"] = boxes + + return [result] + + +def hash_serialized_prompt(serialized_prompt: List[dict]) -> str: + serialized = json.dumps(serialized_prompt, sort_keys=True, separators=(",", ":")) + return hashlib.sha1(serialized.encode("utf-8")).hexdigest() + + +def attempt_load_image_mask_from_cache( + image_hash: str, + serialized_prompt_hash: str, + serialized_prompt: List[dict], + sam3_low_resolution_masks_cache: Sam3LowResolutionMasksCache, + device: torch.device, +) -> Optional[torch.Tensor]: + all_masks = sam3_low_resolution_masks_cache.retrieve_all_masks_for_image( + key=image_hash + ) + if not all_masks: + return None + if len(serialized_prompt) == 0: + return None + + return find_prior_prompt_in_cache( + serialized_prompt_hash=serialized_prompt_hash, + serialized_prompt=serialized_prompt, + matching_cache_entries=all_masks, + device=device, + ) + + +def find_prior_prompt_in_cache( + serialized_prompt_hash: str, + serialized_prompt: List[dict], + matching_cache_entries: List[SAM3MaskCacheEntry], + device: torch.device, +) -> Optional[torch.Tensor]: + maxed_size = 0 + best_match: Optional[SAM3MaskCacheEntry] = None + num_points = ( + 0 if not serialized_prompt else len(serialized_prompt[0].get("points", [])) + ) + if num_points <= 1: + return None # there is only 1 point, hence no prior prompt can be found + desired_size = num_points - 1 + + for cache_entry in matching_cache_entries[::-1]: + is_viable = is_prompt_strict_subset( + assumed_sub_set_prompt=( + cache_entry.prompt_hash, + cache_entry.serialized_prompt, + ), + assumed_super_set_prompt=(serialized_prompt_hash, serialized_prompt), + ) + if not is_viable: + continue + + cached_prompt = cache_entry.serialized_prompt + current_size = ( + 0 if not cached_prompt else len(cached_prompt[0].get("points", [])) + ) + if current_size == desired_size: + return cache_entry.mask.to(device=device) + if current_size >= maxed_size: + maxed_size = current_size + best_match = cache_entry + + if best_match is not None: + return best_match.mask.to(device=device) + return None + + +def is_prompt_strict_subset( + assumed_sub_set_prompt: Tuple[str, List[dict]], + assumed_super_set_prompt: Tuple[str, List[dict]], +) -> bool: + if assumed_sub_set_prompt[0] == assumed_super_set_prompt[0]: + return False + + super_set_copy = copy(assumed_super_set_prompt[1]) + for sub_element in assumed_sub_set_prompt[1]: + found_match = False + for super_element in super_set_copy: + boxes_match = sub_element.get("box") == super_element.get("box") + if not boxes_match: + continue + + sub_points = { + json.dumps(p, sort_keys=True) for p in sub_element.get("points", []) + } + super_points = { + json.dumps(p, sort_keys=True) for p in super_element.get("points", []) + } + if sub_points <= super_points: + super_set_copy.remove(super_element) + found_match = True + break + + if not found_match: + return False + + return True + + +def _build_text_query( + coco_id: int, + h: int, + w: int, + text: Optional[str], +) -> FindQueryLoaded: + return FindQueryLoaded( + query_text=text if text is not None else "visual", + image_id=0, + object_ids_output=[], + is_exhaustive=True, + query_processing_order=0, + input_bbox=None, + input_bbox_label=None, + input_points=None, + semantic_target=None, + is_pixel_exhaustive=None, + inference_metadata=InferenceMetadata( + coco_image_id=coco_id, + original_image_id=coco_id, + original_category_id=1, + original_size=(h, w), + object_id=0, + frame_index=0, + ), + ) + + +def _build_visual_query( + coco_id: int, + h: int, + w: int, + boxes: Optional[List], + labels: Optional[List], + text: Optional[str], +) -> FindQueryLoaded: + xyxy_pixels: List[List[float]] = [] + for b in boxes or []: + if isinstance(b, dict): + if "x" in b: + x0 = float(b["x"]) + y0 = float(b["y"]) + x1 = x0 + float(b["width"]) + y1 = y0 + float(b["height"]) + else: + x0 = float(b["x0"]) + y0 = float(b["y0"]) + x1 = float(b["x1"]) + y1 = float(b["y1"]) + elif hasattr(b, "x"): + x0 = float(b.x) + y0 = float(b.y) + x1 = x0 + float(b.width) + y1 = y0 + float(b.height) + elif hasattr(b, "x0"): + x0 = float(b.x0) + y0 = float(b.y0) + x1 = float(b.x1) + y1 = float(b.y1) + elif isinstance(b, (list, tuple)) and len(b) == 4: + x0, y0, x1, y1 = [float(v) for v in b] + else: + continue + xyxy_pixels.append([x0, y0, x1, y1]) + + labels_bool = [bool(int(v)) for v in (labels or [])] + + return FindQueryLoaded( + query_text=text if text is not None else "visual", + image_id=0, + object_ids_output=[], + is_exhaustive=True, + query_processing_order=0, + input_bbox=( + torch.tensor(xyxy_pixels, dtype=torch.float32) if xyxy_pixels else None + ), + input_bbox_label=( + torch.tensor(labels_bool, dtype=torch.bool) if labels_bool else None + ), + input_points=None, + semantic_target=None, + is_pixel_exhaustive=None, + inference_metadata=InferenceMetadata( + coco_image_id=coco_id, + original_image_id=coco_id, + original_category_id=1, + original_size=(h, w), + object_id=0, + frame_index=0, + ), + ) diff --git a/inference_models/pyproject.toml b/inference_models/pyproject.toml index 06c77747a4..1fdd4474a5 100644 --- a/inference_models/pyproject.toml +++ b/inference_models/pyproject.toml @@ -36,6 +36,8 @@ dependencies = [ "pybase64~=1.0.0", "rf-segment-anything==1.0", "rf-sam-2==1.0.2", + "sam3==0.1.2; sys_platform != 'darwin'", + "triton<4.0.0; sys_platform != 'darwin'", "argon2-cffi>=25.1.0,<26.0.0", ] diff --git a/inference_models/tests/integration_tests/models/conftest.py b/inference_models/tests/integration_tests/models/conftest.py index 5d7bbcddf4..6aca7cb810 100644 --- a/inference_models/tests/integration_tests/models/conftest.py +++ b/inference_models/tests/integration_tests/models/conftest.py @@ -163,6 +163,9 @@ SAM2_PACKAGE_URL = ( "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/sam2.zip" ) +SAM3_PACKAGE_URL = ( + "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/sam3.zip" +) @pytest.fixture(scope="module") @@ -1273,3 +1276,11 @@ def sam2_package() -> str: model_package_zip_url=SAM2_PACKAGE_URL, package_name="sam2", ) + + +@pytest.fixture(scope="module") +def sam3_package() -> str: + return download_model_package( + model_package_zip_url=SAM3_PACKAGE_URL, + package_name="sam3", + ) diff --git a/inference_models/tests/integration_tests/models/test_sam3_predictions.py b/inference_models/tests/integration_tests/models/test_sam3_predictions.py new file mode 100644 index 0000000000..d6eb265008 --- /dev/null +++ b/inference_models/tests/integration_tests/models/test_sam3_predictions.py @@ -0,0 +1,510 @@ +import numpy as np +import pytest +import torch + +from inference_models.configuration import DEFAULT_DEVICE +from inference_models.errors import ModelInputError +from inference_models.models.sam3.sam3_torch import SAM3Torch + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_embeddings_numpy( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when + results = model.embed_images(truck_image_numpy) + + # then + assert len(results) == 1 + assert results[0].embeddings is not None + assert results[0].image_size_hw == truck_image_numpy.shape[:2] + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_embeddings_torch( + sam3_package: str, truck_image_torch: torch.Tensor +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when + results = model.embed_images(truck_image_torch) + + # then + assert len(results) == 1 + assert results[0].embeddings is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_embeddings_batch_numpy( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when + results = model.embed_images([truck_image_numpy, truck_image_numpy]) + + # then + assert len(results) == 2 + assert results[0].embeddings is not None + assert results[1].embeddings is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_embeddings_caching( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when - first call computes embeddings + results1 = model.embed_images(truck_image_numpy) + # second call should retrieve from cache + results2 = model.embed_images(truck_image_numpy) + + # then - hashes should match (same image) + assert results1[0].image_hash == results2[0].image_hash + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_without_prompting_numpy( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when + results = model.segment_images(truck_image_numpy) + + # then + assert len(results) == 1 + assert results[0].masks is not None + assert results[0].scores is not None + assert results[0].logits is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_without_prompting_batch_numpy( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when + results = model.segment_images([truck_image_numpy, truck_image_numpy]) + + # then + assert len(results) == 2 + assert results[0].masks is not None + assert results[1].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_point_prompting( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + # when + results = model.segment_images( + truck_image_numpy, + point_coordinates=input_point, + point_labels=input_label, + ) + + # then + assert len(results) == 1 + assert results[0].masks is not None + assert results[0].scores is not None + # With a positive point prompt, we should get a mask with some area + mask_sum = ( + results[0].masks.sum() + if isinstance(results[0].masks, torch.Tensor) + else results[0].masks.sum() + ) + assert mask_sum > 0 + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_multiple_points( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_points = np.array([[500, 375], [600, 400], [450, 350]]) + input_labels = np.array([1, 1, 1]) + + # when + results = model.segment_images( + truck_image_numpy, + point_coordinates=input_points, + point_labels=input_labels, + ) + + # then + assert len(results) == 1 + assert results[0].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_embeddings( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + # when + embeddings = model.embed_images(truck_image_numpy) + results = model.segment_images( + embeddings=[embeddings[0], embeddings[0]], + point_coordinates=input_point, + point_labels=input_label, + ) + + # then + assert len(results) == 2 + assert results[0].masks is not None + assert results[1].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_box_prompting( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_box = np.array([425, 600, 700, 875]) + + # when + results = model.segment_images( + truck_image_numpy, + boxes=input_box, + ) + + # then + assert len(results) == 1 + assert results[0].masks is not None + assert results[0].scores is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_box_prompting_and_embeddings( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_box = np.array([425, 600, 700, 875]) + + # when + embeddings = model.embed_images(truck_image_numpy) + results = model.segment_images( + embeddings=[embeddings[0], embeddings[0]], + boxes=input_box, + ) + + # then + assert len(results) == 2 + assert results[0].masks is not None + assert results[1].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_combined_prompting( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_box = np.array([425, 600, 700, 875]) + input_point = np.array([[575, 750]]) + input_label = np.array([0]) # negative point + + # when + results = model.segment_images( + truck_image_numpy, + point_coordinates=[input_point], + point_labels=[input_label], + boxes=input_box, + ) + + # then + assert len(results) == 1 + assert results[0].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_mask_prompting( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + # when - first pass to get a mask + first_results = model.segment_images( + images=truck_image_numpy, + point_coordinates=input_point, + point_labels=input_label, + ) + # second pass using the logits as mask input + second_results = model.segment_images( + images=truck_image_numpy, + mask_input=first_results[0].logits, + point_coordinates=input_point, + point_labels=input_label, + ) + + # then + assert len(second_results) == 1 + assert second_results[0].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_raises_on_missing_input( + sam3_package: str, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when / then + with pytest.raises(ModelInputError): + _ = model.segment_images() + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_with_misaligned_batch_sizes( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_point = np.array([[575, 750]]) + input_label = np.array([0]) + + # when / then - misaligned point_labels + with pytest.raises(ModelInputError): + _ = model.segment_images( + truck_image_numpy, + point_coordinates=[input_point], + point_labels=[input_label, input_label], # 2 labels for 1 image + ) + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_with_text_single_prompt( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + prompts = [{"text": "truck"}] + + # when + results = model.segment_with_text( + images=truck_image_numpy, + prompts=prompts, + output_prob_thresh=0.3, + ) + + # then + assert len(results) == 1 + assert len(results[0]) == 1 # one prompt result + assert results[0][0]["prompt_index"] == 0 + assert "masks" in results[0][0] + assert "scores" in results[0][0] + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_with_text_multiple_prompts( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + prompts = [ + {"text": "truck"}, + {"text": "wheel"}, + {"text": "sky"}, + ] + + # when + results = model.segment_with_text( + images=truck_image_numpy, + prompts=prompts, + output_prob_thresh=0.3, + ) + + # then + assert len(results) == 1 + assert len(results[0]) == 3 # three prompt results + for i, prompt_result in enumerate(results[0]): + assert prompt_result["prompt_index"] == i + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_with_text_visual_prompt( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + prompts = [ + { + "text": "vehicle", + "boxes": [[425, 600, 700, 875]], # XYXY format + "box_labels": [1], + } + ] + + # when + results = model.segment_with_text( + images=truck_image_numpy, + prompts=prompts, + output_prob_thresh=0.3, + ) + + # then + assert len(results) == 1 + assert len(results[0]) == 1 + assert results[0][0]["masks"] is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_with_text_batch_images( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + prompts = [{"text": "truck"}] + + # when + results = model.segment_with_text( + images=[truck_image_numpy, truck_image_numpy], + prompts=prompts, + output_prob_thresh=0.3, + ) + + # then + assert len(results) == 2 # two images + assert len(results[0]) == 1 # one prompt per image + assert len(results[1]) == 1 + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_multi_mask_output( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + # when - with multi_mask_output=True (default) + results_multi = model.segment_images( + truck_image_numpy, + point_coordinates=input_point, + point_labels=input_label, + multi_mask_output=True, + ) + + # when - with multi_mask_output=False + results_single = model.segment_images( + truck_image_numpy, + point_coordinates=input_point, + point_labels=input_label, + multi_mask_output=False, + ) + + # then - both should return results + assert len(results_multi) == 1 + assert len(results_single) == 1 + assert results_multi[0].masks is not None + assert results_single[0].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_segment_images_return_logits( + sam3_package: str, + truck_image_numpy: np.ndarray, +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + + # when - with return_logits=True + results_logits = model.segment_images( + truck_image_numpy, + point_coordinates=input_point, + point_labels=input_label, + return_logits=True, + ) + + # when - with return_logits=False (default) + results_binary = model.segment_images( + truck_image_numpy, + point_coordinates=input_point, + point_labels=input_label, + return_logits=False, + ) + + # then + assert len(results_logits) == 1 + assert len(results_binary) == 1 + # Logits should have floating point values, binary should be 0/1 or True/False + assert results_logits[0].masks is not None + assert results_binary[0].masks is not None + + +@pytest.mark.slow +@pytest.mark.torch_models +def test_sam3_caching_disabled( + sam3_package: str, truck_image_numpy: np.ndarray +) -> None: + # given + model = SAM3Torch.from_pretrained(sam3_package, device=DEFAULT_DEVICE) + + # when - with caching disabled + results1 = model.embed_images(truck_image_numpy, use_embeddings_cache=False) + results2 = model.embed_images(truck_image_numpy, use_embeddings_cache=False) + + # then - both should succeed + assert len(results1) == 1 + assert len(results2) == 1 + assert results1[0].embeddings is not None + assert results2[0].embeddings is not None