Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model:

if isinstance(self.model, str) and not self._is_jumpstart_model_id():
# Configure HuggingFace model for DJL
self.env_vars.update({"HF_MODEL_ID": self.model})
self.env_vars.setdefault("HF_MODEL_ID", self.model)

# Get model configuration for DJL optimization
self.hf_model_config = _get_model_config_properties_from_hf(
Expand All @@ -345,7 +345,9 @@ def _build_for_djl(self) -> Model:
"SERVING_MAX_WORKERS": "1",
"OPTION_MODEL_LOADING_TIMEOUT": "240",
"OPTION_PREDICT_TIMEOUT": "60",
"TENSOR_PARALLEL_DEGREE": "1" # Default, will be overridden below
"TENSOR_PARALLEL_DEGREE": "1", # Default, will be overridden below
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential duplicate/conflicting env var setting: HF_HOME and HUGGINGFACE_HUB_CACHE are set here (inside the not self._is_jumpstart_model_id() branch, lines 348-349), and then again in the else branch at lines 375-376 (non-local mode). This means that for non-local, non-JumpStart models, these values get set twice (which is harmless but redundant). However, for JumpStart models or local modes, the behavior differs:

  • Local mode: The if self.mode in LOCAL_MODES branch sets HF_HUB_OFFLINE but does NOT set HF_HOME/HUGGINGFACE_HUB_CACHE. If the model is not a JumpStart model, these were already set at line 348. But if it IS a JumpStart model, they won't be set at all in local mode. Is that intentional?
  • Non-local JumpStart models: They'll get the env vars from lines 375-376 but not from 348-349.

Consider consolidating the HF cache env var setting to a single location (e.g., always set them regardless of JumpStart status and mode) to make the logic clearer and avoid subtle gaps.

"HF_HOME": "/tmp",
"HUGGINGFACE_HUB_CACHE": "/tmp",
}

# Add HuggingFace authentication
Expand All @@ -370,6 +372,9 @@ def _build_for_djl(self) -> Model:
# Cache management based on mode
if self.mode in LOCAL_MODES:
self.env_vars.update({"HF_HUB_OFFLINE": "1"})
else:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using setdefault here too: For consistency with the HF_MODEL_ID change, consider using self.env_vars.setdefault("HF_HOME", "/tmp") and self.env_vars.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp") so that if a user explicitly provides these env vars (e.g., pointing to a different writable directory), their values are preserved. The same applies to lines 348-349.

self.env_vars["HF_HOME"] = "/tmp"
self.env_vars["HUGGINGFACE_HUB_CACHE"] = "/tmp"

# GPU-based tensor parallel calculation for SAGEMAKER_ENDPOINT mode
if self.mode == Mode.SAGEMAKER_ENDPOINT:
Expand Down
Empty file.
318 changes: 318 additions & 0 deletions sagemaker-serve/tests/unit/servers/test_djl_hf_cache_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
"""Tests for DJL builder HF cache environment variables and HF_MODEL_ID handling.

Verifies that _build_for_djl() correctly:
- Sets HF_HOME and HUGGINGFACE_HUB_CACHE to /tmp for writable cache
- Preserves user-provided HF_MODEL_ID values (uses setdefault)
- Sets HF_MODEL_ID when not provided by user
- Sets HF_HUB_OFFLINE in local modes
"""

import unittest
from unittest.mock import Mock, patch, MagicMock
import tempfile
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused/unnecessary imports: MagicMock is imported but never used. os and shutil are used only for temp dir cleanup, which pytest's tmp_path fixture would handle automatically. Remove MagicMock, and consider switching to tmp_path so that os and shutil can be dropped as well.

import os
import shutil

from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.core.resources import Model


def _mock_sagemaker_session():
"""Create a mock SageMaker session."""
session = Mock()
session.boto_region_name = "us-east-1"
session.sagemaker_config = {}
session.default_bucket.return_value = "mock-bucket"
session.upload_data.return_value = "s3://mock-bucket/model.tar.gz"
return session


# Placeholder IAM role ARN passed as ``role_arn`` to the ModelBuilder instances
# constructed in these tests.
MOCK_ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole"
# DJL inference image URI passed as ``image_uri`` to ModelBuilder.
MOCK_IMAGE_URI = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.36.0-lmi22.0.0-cu129"
# Model config used as the return value of the patched
# ``_get_model_config_properties_from_hf`` and assigned to ``builder.hf_model_config``.
MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoded AWS account ID in mock: MOCK_ROLE_ARN contains 123456789012 and MOCK_IMAGE_URI contains a real ECR registry ID (763104351884). While these are mocks, using a clearly fake ECR URI (e.g., 000000000000.dkr.ecr.us-east-1.amazonaws.com/djl-inference:latest) would be more consistent with test standards that avoid real account/region references.


class TestDjlHfCacheEnv(unittest.TestCase):
"""Test DJL builder HF cache environment variable handling."""

def setUp(self):
    """Prepare per-test fixtures: a mocked session and a scratch directory."""
    session = _mock_sagemaker_session()
    scratch_dir = tempfile.mkdtemp()
    self.mock_session = session
    # Passed to ModelBuilder as ``model_path`` by the tests in this class.
    self.temp_dir = scratch_dir

def tearDown(self):
    """Delete the scratch directory stored in ``self.temp_dir`` if it still exists."""
    scratch_dir = self.temp_dir
    if not os.path.exists(scratch_dir):
        # Nothing left to clean up.
        return
    shutil.rmtree(scratch_dir)

@patch('sagemaker.serve.model_builder_servers._get_gpu_info')
@patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree')
@patch('sagemaker.serve.model_builder.ModelBuilder._create_model')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
@patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri')
@patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data')
@patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id')
@patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf')
@patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations')
@patch('sagemaker.serve.model_builder_servers._get_nb_instance')
def test_build_for_djl_sets_hf_home_to_tmp(
self, mock_nb, mock_djl_config, mock_hf_config, mock_is_js,
mock_validate, mock_auto_detect, mock_prepare, mock_create,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Massive test duplication: Nearly every test method has the same ~30 lines of mock setup and builder construction copied verbatim. Extract the common mock setup and builder creation into a pytest.fixture (or at minimum a helper method). This would reduce the file by ~60% and make it much easier to maintain. For example:

@pytest.fixture
def djl_builder(mock_session, temp_dir):
    # common mock patches and builder setup
    ...
    return builder

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed! Only a few tests would suffice. We should strive for test quality over quantity!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests test_build_for_djl_sets_hf_home_to_tmp and test_build_for_djl_sets_huggingface_hub_cache_to_tmp are redundant with test_build_for_djl_with_source_code_and_hf_model_id: The last test already asserts both HF_HOME and HUGGINGFACE_HUB_CACHE. Consider consolidating these three tests into one that checks both env vars, following the "one logical assertion per test" guideline (checking two related env vars from the same operation is one logical assertion).

mock_tp_degree, mock_gpu_info
):
"""Verify HF_HOME=/tmp is set in SAGEMAKER_ENDPOINT mode."""
mock_nb.return_value = None
mock_is_js.return_value = False
mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG
mock_djl_config.return_value = ({}, 256)
mock_create.return_value = Mock(spec=Model)
mock_prepare.return_value = ("s3://bucket/model", None)
mock_gpu_info.return_value = 4
mock_tp_degree.return_value = 4

builder = ModelBuilder(
model="chromadb/context-1",
role_arn=MOCK_ROLE_ARN,
sagemaker_session=self.mock_session,
model_path=self.temp_dir,
mode=Mode.SAGEMAKER_ENDPOINT,
image_uri=MOCK_IMAGE_URI,
model_server=ModelServer.DJL_SERVING,
instance_type="ml.g6e.12xlarge",
)
builder.schema_builder = Mock()
builder.schema_builder.sample_input = {"inputs": "Hello"}
builder._optimizing = False
builder.hf_model_config = MOCK_HF_MODEL_CONFIG

builder._build_for_djl()

self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp")

@patch('sagemaker.serve.model_builder_servers._get_gpu_info')
@patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree')
@patch('sagemaker.serve.model_builder.ModelBuilder._create_model')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
@patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri')
@patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data')
@patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id')
@patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf')
@patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations')
@patch('sagemaker.serve.model_builder_servers._get_nb_instance')
def test_build_for_djl_sets_huggingface_hub_cache_to_tmp(
    self,
    mock_nb,
    mock_djl_config,
    mock_hf_config,
    mock_is_js,
    mock_validate,
    mock_auto_detect,
    mock_prepare,
    mock_create,
    mock_tp_degree,
    mock_gpu_info,
):
    """Check that the DJL build sets HUGGINGFACE_HUB_CACHE to /tmp.

    The builder is configured for SAGEMAKER_ENDPOINT mode with a string
    model ID that the patched JumpStart check reports as non-JumpStart.
    """
    # ``patch`` decorators inject mocks bottom-up, so the parameter order
    # mirrors the decorator list in reverse. ``mock_validate`` and
    # ``mock_auto_detect`` only need to exist; their defaults are sufficient.
    stubbed_returns = (
        (mock_nb, None),
        (mock_is_js, False),
        (mock_hf_config, MOCK_HF_MODEL_CONFIG),
        (mock_djl_config, ({}, 256)),
        (mock_create, Mock(spec=Model)),
        (mock_prepare, ("s3://bucket/model", None)),
        (mock_gpu_info, 4),
        (mock_tp_degree, 4),
    )
    for patched, value in stubbed_returns:
        patched.return_value = value

    builder_kwargs = {
        "model": "chromadb/context-1",
        "role_arn": MOCK_ROLE_ARN,
        "sagemaker_session": self.mock_session,
        "model_path": self.temp_dir,
        "mode": Mode.SAGEMAKER_ENDPOINT,
        "image_uri": MOCK_IMAGE_URI,
        "model_server": ModelServer.DJL_SERVING,
        "instance_type": "ml.g6e.12xlarge",
    }
    builder = ModelBuilder(**builder_kwargs)
    builder.schema_builder = Mock()
    builder.schema_builder.sample_input = {"inputs": "Hello"}
    builder._optimizing = False
    builder.hf_model_config = MOCK_HF_MODEL_CONFIG

    builder._build_for_djl()

    hub_cache = builder.env_vars.get("HUGGINGFACE_HUB_CACHE")
    self.assertEqual(hub_cache, "/tmp")

@patch('sagemaker.serve.model_builder_servers._get_gpu_info')
@patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree')
@patch('sagemaker.serve.model_builder.ModelBuilder._create_model')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
@patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri')
@patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data')
@patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id')
@patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf')
@patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations')
@patch('sagemaker.serve.model_builder_servers._get_nb_instance')
def test_build_for_djl_preserves_user_provided_hf_model_id(
    self,
    mock_nb,
    mock_djl_config,
    mock_hf_config,
    mock_is_js,
    mock_validate,
    mock_auto_detect,
    mock_prepare,
    mock_create,
    mock_tp_degree,
    mock_gpu_info,
):
    """Check that an HF_MODEL_ID supplied via ``env_vars`` survives the DJL build.

    The builder receives ``env_vars={"HF_MODEL_ID": "/opt/ml/model"}`` together
    with a different ``model`` string; after ``_build_for_djl()`` the env var
    must still hold the caller's value.
    """
    # ``patch`` decorators inject mocks bottom-up, so the parameter order
    # mirrors the decorator list in reverse. ``mock_validate`` and
    # ``mock_auto_detect`` only need to exist; their defaults are sufficient.
    stubbed_returns = (
        (mock_nb, None),
        (mock_is_js, False),
        (mock_hf_config, MOCK_HF_MODEL_CONFIG),
        (mock_djl_config, ({}, 256)),
        (mock_create, Mock(spec=Model)),
        (mock_prepare, ("s3://bucket/model", None)),
        (mock_gpu_info, 4),
        (mock_tp_degree, 4),
    )
    for patched, value in stubbed_returns:
        patched.return_value = value

    builder_kwargs = {
        "model": "chromadb/context-1",
        "role_arn": MOCK_ROLE_ARN,
        "sagemaker_session": self.mock_session,
        "model_path": self.temp_dir,
        "mode": Mode.SAGEMAKER_ENDPOINT,
        "image_uri": MOCK_IMAGE_URI,
        "model_server": ModelServer.DJL_SERVING,
        "instance_type": "ml.g6e.12xlarge",
        "env_vars": {"HF_MODEL_ID": "/opt/ml/model"},
    }
    builder = ModelBuilder(**builder_kwargs)
    builder.schema_builder = Mock()
    builder.schema_builder.sample_input = {"inputs": "Hello"}
    builder._optimizing = False
    builder.hf_model_config = MOCK_HF_MODEL_CONFIG

    builder._build_for_djl()

    # The caller's value must win over the ``model`` argument.
    resolved_model_id = builder.env_vars["HF_MODEL_ID"]
    self.assertEqual(resolved_model_id, "/opt/ml/model")

@patch('sagemaker.serve.model_builder_servers._get_gpu_info')
@patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree')
@patch('sagemaker.serve.model_builder.ModelBuilder._create_model')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
@patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri')
@patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data')
@patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id')
@patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf')
@patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations')
@patch('sagemaker.serve.model_builder_servers._get_nb_instance')
def test_build_for_djl_sets_hf_model_id_when_not_provided(
    self,
    mock_nb,
    mock_djl_config,
    mock_hf_config,
    mock_is_js,
    mock_validate,
    mock_auto_detect,
    mock_prepare,
    mock_create,
    mock_tp_degree,
    mock_gpu_info,
):
    """Check that HF_MODEL_ID falls back to the ``model`` argument.

    No ``env_vars`` are passed to the builder, so after ``_build_for_djl()``
    the env var is expected to equal the ``model`` string.
    """
    # ``patch`` decorators inject mocks bottom-up, so the parameter order
    # mirrors the decorator list in reverse. ``mock_validate`` and
    # ``mock_auto_detect`` only need to exist; their defaults are sufficient.
    stubbed_returns = (
        (mock_nb, None),
        (mock_is_js, False),
        (mock_hf_config, MOCK_HF_MODEL_CONFIG),
        (mock_djl_config, ({}, 256)),
        (mock_create, Mock(spec=Model)),
        (mock_prepare, ("s3://bucket/model", None)),
        (mock_gpu_info, 4),
        (mock_tp_degree, 4),
    )
    for patched, value in stubbed_returns:
        patched.return_value = value

    builder_kwargs = {
        "model": "chromadb/context-1",
        "role_arn": MOCK_ROLE_ARN,
        "sagemaker_session": self.mock_session,
        "model_path": self.temp_dir,
        "mode": Mode.SAGEMAKER_ENDPOINT,
        "image_uri": MOCK_IMAGE_URI,
        "model_server": ModelServer.DJL_SERVING,
        "instance_type": "ml.g6e.12xlarge",
    }
    builder = ModelBuilder(**builder_kwargs)
    builder.schema_builder = Mock()
    builder.schema_builder.sample_input = {"inputs": "Hello"}
    builder._optimizing = False
    builder.hf_model_config = MOCK_HF_MODEL_CONFIG

    builder._build_for_djl()

    # With nothing supplied by the caller, the model string becomes the ID.
    resolved_model_id = builder.env_vars["HF_MODEL_ID"]
    self.assertEqual(resolved_model_id, "chromadb/context-1")

@patch('sagemaker.serve.model_builder_servers._get_gpu_info')
@patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree')
@patch('sagemaker.serve.model_builder.ModelBuilder._create_model')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
@patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri')
@patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data')
@patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id')
@patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf')
@patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations')
@patch('sagemaker.serve.model_builder_servers._get_nb_instance')
def test_build_for_djl_with_source_code_and_hf_model_id(
    self,
    mock_nb,
    mock_djl_config,
    mock_hf_config,
    mock_is_js,
    mock_validate,
    mock_auto_detect,
    mock_prepare,
    mock_create,
    mock_tp_degree,
    mock_gpu_info,
):
    """Check that both HF cache env vars are redirected to /tmp.

    Motivation: the HF cache must point at /tmp so it does not land in a
    read-only /opt/ml/model.

    NOTE(review): the test name refers to ``source_code``, but no
    ``source_code`` argument is passed to ModelBuilder here; the builder is
    configured like the other SAGEMAKER_ENDPOINT tests. Confirm whether the
    scenario should be set up explicitly.
    """
    # ``patch`` decorators inject mocks bottom-up, so the parameter order
    # mirrors the decorator list in reverse. ``mock_validate`` and
    # ``mock_auto_detect`` only need to exist; their defaults are sufficient.
    stubbed_returns = (
        (mock_nb, None),
        (mock_is_js, False),
        (mock_hf_config, MOCK_HF_MODEL_CONFIG),
        (mock_djl_config, ({}, 256)),
        (mock_create, Mock(spec=Model)),
        (mock_prepare, ("s3://bucket/model", None)),
        (mock_gpu_info, 4),
        (mock_tp_degree, 4),
    )
    for patched, value in stubbed_returns:
        patched.return_value = value

    builder_kwargs = {
        "model": "chromadb/context-1",
        "role_arn": MOCK_ROLE_ARN,
        "sagemaker_session": self.mock_session,
        "model_path": self.temp_dir,
        "mode": Mode.SAGEMAKER_ENDPOINT,
        "image_uri": MOCK_IMAGE_URI,
        "model_server": ModelServer.DJL_SERVING,
        "instance_type": "ml.g6e.12xlarge",
    }
    builder = ModelBuilder(**builder_kwargs)
    builder.schema_builder = Mock()
    builder.schema_builder.sample_input = {"inputs": "Hello"}
    builder._optimizing = False
    builder.hf_model_config = MOCK_HF_MODEL_CONFIG

    builder._build_for_djl()

    # Both cache locations are checked, in the same order as before.
    hf_home = builder.env_vars.get("HF_HOME")
    self.assertEqual(hf_home, "/tmp")
    hub_cache = builder.env_vars.get("HUGGINGFACE_HUB_CACHE")
    self.assertEqual(hub_cache, "/tmp")

@patch('sagemaker.serve.model_builder.ModelBuilder._create_model')
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
@patch('sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri')
@patch('sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data')
@patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id')
@patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf')
@patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations')
@patch('sagemaker.serve.model_builder_servers._get_nb_instance')
def test_build_for_djl_local_mode_sets_hf_hub_offline(
    self,
    mock_nb,
    mock_djl_config,
    mock_hf_config,
    mock_is_js,
    mock_validate,
    mock_auto_detect,
    mock_prepare,
    mock_create,
):
    """Check that the DJL build sets HF_HUB_OFFLINE to "1" in LOCAL_CONTAINER mode."""
    # ``patch`` decorators inject mocks bottom-up, so the parameter order
    # mirrors the decorator list in reverse. ``mock_validate``,
    # ``mock_auto_detect`` and ``mock_prepare`` keep their default behaviour.
    stubbed_returns = (
        (mock_nb, None),
        (mock_is_js, False),
        (mock_hf_config, MOCK_HF_MODEL_CONFIG),
        (mock_djl_config, ({}, 256)),
        (mock_create, Mock(spec=Model)),
    )
    for patched, value in stubbed_returns:
        patched.return_value = value

    builder_kwargs = {
        "model": "chromadb/context-1",
        "role_arn": MOCK_ROLE_ARN,
        "sagemaker_session": self.mock_session,
        "model_path": self.temp_dir,
        "mode": Mode.LOCAL_CONTAINER,
        "image_uri": MOCK_IMAGE_URI,
        "model_server": ModelServer.DJL_SERVING,
    }
    builder = ModelBuilder(**builder_kwargs)
    builder.schema_builder = Mock()
    builder.schema_builder.sample_input = {"inputs": "Hello"}
    builder._optimizing = False
    builder.hf_model_config = MOCK_HF_MODEL_CONFIG

    builder._build_for_djl()

    offline_flag = builder.env_vars.get("HF_HUB_OFFLINE")
    self.assertEqual(offline_flag, "1")


# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading