Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _build_for_torchserve(self) -> Model:
if isinstance(self.model, str):
# Configure HuggingFace model support
if not self._is_jumpstart_model_id():
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)

# Add HuggingFace token if available
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
Expand Down Expand Up @@ -212,7 +212,7 @@ def _build_for_tgi(self) -> Model:

if isinstance(self.model, str) and not self._is_jumpstart_model_id():
# Configure HuggingFace model for TGI
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)

self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
Expand Down Expand Up @@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model:

if isinstance(self.model, str) and not self._is_jumpstart_model_id():
# Configure HuggingFace model for DJL
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)

# Get model configuration for DJL optimization
self.hf_model_config = _get_model_config_properties_from_hf(
Expand Down Expand Up @@ -426,7 +426,7 @@ def _build_for_triton(self) -> Model:
self.env_vars.update({"HF_TASK": model_task})

# Configure HuggingFace authentication
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")

Expand Down Expand Up @@ -532,7 +532,7 @@ def _build_for_tei(self) -> Model:

if isinstance(self.model, str) and not self._is_jumpstart_model_id():
# Configure HuggingFace model for TEI
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)

self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
Expand Down Expand Up @@ -676,7 +676,7 @@ def _build_for_transformers(self) -> Model:
if self.inference_spec is not None:
hf_model_id = self.inference_spec.get_model()
if isinstance(hf_model_id, str): # Only if it's a valid HF model ID
self.env_vars.update({"HF_MODEL_ID": hf_model_id})
self.env_vars.setdefault("HF_MODEL_ID", hf_model_id)
# Get HF config only for string model IDs
if hasattr(self.env_vars, "HF_API_TOKEN"):
self.hf_model_config = _get_model_config_properties_from_hf(
Expand All @@ -687,7 +687,7 @@ def _build_for_transformers(self) -> Model:
hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)
# Get HF config for string model IDs
if hasattr(self.env_vars, "HF_API_TOKEN"):
self.hf_model_config = _get_model_config_properties_from_hf(
Expand Down
275 changes: 275 additions & 0 deletions sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
"""Unit tests to verify HF_MODEL_ID is not overwritten when user provides it."""
import unittest
from unittest.mock import Mock, patch, MagicMock, PropertyMock

from sagemaker.serve.model_builder_servers import _ModelBuilderServers
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode


def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using @pytest.fixture for the mock builder creation instead of a plain helper function. This would be more idiomatic pytest and allow parameterization:

@pytest.fixture
def mock_builder():
    """Create a mock builder with common attributes set."""
    ...

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing type annotations on the helper function. Per SDK conventions, public/utility functions should have type hints:

def _create_mock_builder(
    env_vars: dict[str, str] | None = None,
    model: str = "Qwen/Qwen3-VL-4B-Instruct",
) -> MagicMock:

"""Create a mock builder with common attributes set."""
builder = MagicMock(spec=_ModelBuilderServers)
builder.model = model
builder.env_vars = env_vars if env_vars is not None else {}
builder.model_path = "/tmp/test_model_path"
builder.mode = Mode.SAGEMAKER_ENDPOINT
builder.model_server = ModelServer.DJL_SERVING
builder.secret_key = ""
builder.s3_upload_path = None
builder.s3_model_data_url = None
builder.shared_libs = []
builder.dependencies = {}
builder.image_uri = "test-image-uri"
builder.instance_type = "ml.g5.2xlarge"
builder.sagemaker_session = Mock()
builder.schema_builder = MagicMock()
builder.schema_builder.sample_input = {"inputs": "Hello", "parameters": {}}
builder.inference_spec = None
builder.hf_model_config = {}
builder.model_data_download_timeout = None
builder._user_provided_instance_type = True
builder._is_jumpstart_model_id = Mock(return_value=False)
builder._auto_detect_image_uri = Mock()
builder._prepare_for_mode = Mock(return_value=("s3://model-data", None))
builder._create_model = Mock(return_value=Mock())
builder._optimizing = False
builder._validate_djl_serving_sample_data = Mock()
builder._validate_tgi_serving_sample_data = Mock()
builder._validate_for_triton = Mock()
builder.get_huggingface_model_metadata = Mock(return_value={"pipeline_tag": "text-generation"})
builder.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"
return builder


class TestDjlPreservesHfModelId(unittest.TestCase):
"""Test that _build_for_djl preserves user-provided HF_MODEL_ID."""

@patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
@patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations")
@patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
@patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1)
@patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1)
def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config):
"""User-provided S3 URI for HF_MODEL_ID should not be overwritten."""
mock_hf_config.return_value = {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using @pytest.mark.parametrize to reduce duplication across the test classes. Many tests follow the same pattern (preserve vs. set default) across different server types. For example:

@pytest.mark.parametrize("build_method,server_type,patches", [
    ("_build_for_djl", ModelServer.DJL_SERVING, [...]),
    ("_build_for_tgi", ModelServer.TGI, [...]),
    ...
])
def test_preserves_user_provided_hf_model_id(build_method, server_type, patches):
    ...

This would significantly reduce the test file size while maintaining coverage.

mock_djl_config.return_value = ({}, 256)

s3_path = "s3://my-bucket/models/Qwen/"
builder = _create_mock_builder(env_vars={"HF_MODEL_ID": s3_path})

with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"):
_ModelBuilderServers._build_for_djl(builder)

self.assertEqual(builder.env_vars["HF_MODEL_ID"], s3_path)

@patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
@patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations")
@patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line exceeds 100 characters (the SDK's line length limit). Several other decorator lines in this file also exceed the limit (lines 69, 97, 98, etc.). Please wrap long lines to stay within 100 characters.

@patch(
    "sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree",
    return_value=1,
)

@patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1)
@patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long function signature exceeds 100 characters. Please wrap parameters across multiple lines.

def test_sets_hf_model_id_when_not_provided(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config):
"""HF_MODEL_ID should be set from self.model when user doesn't provide it."""
mock_hf_config.return_value = {}
mock_djl_config.return_value = ({}, 256)

builder = _create_mock_builder(env_vars={})

with patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"):
_ModelBuilderServers._build_for_djl(builder)

self.assertEqual(builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")


class TestTgiPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_tgi`` treats the ``HF_MODEL_ID`` environment variable."""

    @staticmethod
    def _run_tgi_build(initial_env):
        """Run ``_build_for_tgi`` on a mock TGI builder and return its env vars.

        Args:
            initial_env: Dict handed to the mock builder as ``env_vars``.

        Returns:
            The builder's ``env_vars`` as read after the build call.
        """
        tgi_builder = _create_mock_builder(env_vars=initial_env)
        tgi_builder.model_server = ModelServer.TGI

        # Patched for the duration of the build call only.
        with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"):
            _ModelBuilderServers._build_for_tgi(tgi_builder)

        return tgi_builder.env_vars

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1)
    @patch(
        "sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree",
        return_value=1,
    )
    def test_preserves_user_provided_s3_uri(
        self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config
    ):
        """An S3 URI already stored under HF_MODEL_ID must survive the build."""
        mock_tgi_config.return_value = ({}, 256)
        mock_hf_config.return_value = {}
        user_uri = "s3://my-bucket/models/Qwen/"

        resulting_env = self._run_tgi_build({"HF_MODEL_ID": user_uri})

        self.assertEqual(resulting_env["HF_MODEL_ID"], user_uri)

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1)
    @patch(
        "sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree",
        return_value=1,
    )
    def test_sets_hf_model_id_when_not_provided(
        self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config
    ):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_tgi_config.return_value = ({}, 256)
        mock_hf_config.return_value = {}

        resulting_env = self._run_tgi_build({})

        self.assertEqual(resulting_env["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")


class TestTeiPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_tei`` treats the ``HF_MODEL_ID`` environment variable."""

    @staticmethod
    def _run_tei_build(initial_env):
        """Run ``_build_for_tei`` on a mock TEI builder and return its env vars.

        Args:
            initial_env: Dict handed to the mock builder as ``env_vars``.

        Returns:
            The builder's ``env_vars`` as read after the build call.
        """
        tei_builder = _create_mock_builder(env_vars=initial_env)
        tei_builder.model_server = ModelServer.TEI

        # NOTE(review): the patch target is the *tgi* prepare module even though this
        # exercises the TEI build path -- confirm that is the intended target.
        with patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure"):
            _ModelBuilderServers._build_for_tei(tei_builder)

        return tei_builder.env_vars

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    def test_preserves_user_provided_s3_uri(self, mock_nb, mock_hf_config):
        """An S3 URI already stored under HF_MODEL_ID must survive the build."""
        mock_hf_config.return_value = {}
        user_uri = "s3://my-bucket/models/embedding-model/"

        resulting_env = self._run_tei_build({"HF_MODEL_ID": user_uri})

        self.assertEqual(resulting_env["HF_MODEL_ID"], user_uri)

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    def test_sets_hf_model_id_when_not_provided(self, mock_nb, mock_hf_config):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_hf_config.return_value = {}

        resulting_env = self._run_tei_build({})

        self.assertEqual(resulting_env["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")


class TestTorchservePreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_torchserve`` treats the ``HF_MODEL_ID`` environment variable."""

    @staticmethod
    def _run_torchserve_build(initial_env):
        """Run ``_build_for_torchserve`` on a mock builder and return its env vars.

        Args:
            initial_env: Dict handed to the mock builder as ``env_vars``.

        Returns:
            The builder's ``env_vars`` as read after the build call.
        """
        ts_builder = _create_mock_builder(env_vars=initial_env)
        ts_builder._save_model_inference_spec = Mock()
        ts_builder.mode = Mode.SAGEMAKER_ENDPOINT
        ts_builder.model_server = ModelServer.TORCHSERVE

        _ModelBuilderServers._build_for_torchserve(ts_builder)

        return ts_builder.env_vars

    def test_preserves_user_provided_s3_uri(self):
        """An S3 URI already stored under HF_MODEL_ID must survive the build."""
        user_uri = "s3://my-bucket/models/my-model/"

        resulting_env = self._run_torchserve_build({"HF_MODEL_ID": user_uri})

        self.assertEqual(resulting_env["HF_MODEL_ID"], user_uri)

    def test_sets_hf_model_id_when_not_provided(self):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        resulting_env = self._run_torchserve_build({})

        self.assertEqual(resulting_env["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")


class TestTritonPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_triton`` treats the ``HF_MODEL_ID`` environment variable."""

    @staticmethod
    def _run_triton_build(initial_env):
        """Run ``_build_for_triton`` on a mock builder and return its env vars.

        Args:
            initial_env: Dict handed to the mock builder as ``env_vars``.

        Returns:
            The builder's ``env_vars`` as read after the build call.
        """
        triton_builder = _create_mock_builder(env_vars=initial_env)
        triton_builder.model_server = ModelServer.TRITON
        # Replace the Triton-specific collaborators with plain mocks.
        for collaborator in (
            "_save_inference_spec",
            "_prepare_for_triton",
            "_auto_detect_image_for_triton",
        ):
            setattr(triton_builder, collaborator, Mock())

        _ModelBuilderServers._build_for_triton(triton_builder)

        return triton_builder.env_vars

    def test_preserves_user_provided_s3_uri(self):
        """An S3 URI already stored under HF_MODEL_ID must survive the build."""
        user_uri = "s3://my-bucket/models/my-model/"

        resulting_env = self._run_triton_build({"HF_MODEL_ID": user_uri})

        self.assertEqual(resulting_env["HF_MODEL_ID"], user_uri)

    def test_sets_hf_model_id_when_not_provided(self):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        resulting_env = self._run_triton_build({})

        self.assertEqual(resulting_env["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")


class TestTransformersPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_transformers`` treats the ``HF_MODEL_ID`` environment variable."""

    # Patched around every build call in this class.
    _PREPARE_TARGET = (
        "sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"
    )

    @staticmethod
    def _new_mms_builder(initial_env):
        """Return a mock builder configured for the MMS server on a SageMaker endpoint.

        Args:
            initial_env: Dict handed to the mock builder as ``env_vars``.
        """
        mms_builder = _create_mock_builder(env_vars=initial_env)
        mms_builder.model_data_download_timeout = None
        mms_builder.mode = Mode.SAGEMAKER_ENDPOINT
        mms_builder.model_server = ModelServer.MMS
        return mms_builder

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    def test_preserves_user_provided_s3_uri_with_model_string(self, mock_nb, mock_hf_config):
        """With a string model, an S3 URI already in HF_MODEL_ID must survive the build."""
        mock_hf_config.return_value = {}
        user_uri = "s3://my-bucket/models/my-model/"
        mms_builder = self._new_mms_builder({"HF_MODEL_ID": user_uri})

        with patch(self._PREPARE_TARGET):
            _ModelBuilderServers._build_for_transformers(mms_builder)

        self.assertEqual(mms_builder.env_vars["HF_MODEL_ID"], user_uri)

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    def test_sets_hf_model_id_when_not_provided_with_model_string(
        self, mock_nb, mock_hf_config
    ):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_hf_config.return_value = {}
        mms_builder = self._new_mms_builder({})

        with patch(self._PREPARE_TARGET):
            _ModelBuilderServers._build_for_transformers(mms_builder)

        self.assertEqual(mms_builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")

    @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
    @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
    @patch("sagemaker.serve.model_builder_servers.save_pkl")
    def test_preserves_user_provided_hf_model_id_with_inference_spec(
        self, mock_pkl, mock_nb, mock_hf_config
    ):
        """A supplied HF_MODEL_ID must win over the ID reported by the inference spec."""
        mock_hf_config.return_value = {}
        user_uri = "s3://my-bucket/models/my-model/"

        mms_builder = self._new_mms_builder({"HF_MODEL_ID": user_uri})
        mms_builder.model = None  # No model string, using inference_spec
        mms_builder._is_jumpstart_model_id = Mock(return_value=False)
        # The spec reports a model ID that differs from the user-supplied one.
        spec = Mock()
        spec.get_model.return_value = "some-hf-model-id"
        mms_builder.inference_spec = spec

        with patch(self._PREPARE_TARGET), patch("os.makedirs"):
            _ModelBuilderServers._build_for_transformers(mms_builder)

        self.assertEqual(mms_builder.env_vars["HF_MODEL_ID"], user_uri)


if __name__ == "__main__":
unittest.main()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the `if __name__ == "__main__": unittest.main()` block — the SDK runs tests via pytest, not unittest's runner.

Loading