Draft
11 changes: 6 additions & 5 deletions sagemaker-train/src/sagemaker/train/tuner.py
Collaborator Author

Can you fix the unit tests based on the CI failures?

$context sagemaker-train/tests/unit/train/test_tuner.py
$context sagemaker-train/tests/unit/train/test_tuner_driver_channels.py

Collaborator Author

You can ignore the integration test failures!

@@ -1504,6 +1504,11 @@ def _build_training_job_definition(self, inputs):
model_trainer.stopping_condition.max_wait_time_in_seconds
)

Collaborator

getattr(model_trainer, "environment", None) suggests environment might not exist on ModelTrainer, but it's a defined attribute on the class. Using model_trainer.environment directly would be more idiomatic and consistent with how other attributes (e.g., model_trainer.role, model_trainer.compute) are accessed in this same method. If environment is always defined on ModelTrainer (even if None), prefer:

env = model_trainer.environment
if not env or not isinstance(env, dict):
    env = None

Alternatively, if there's a concern about backward compatibility with mock objects in tests, that's a test issue, not a production code concern.
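
To illustrate the mock concern, here is a minimal sketch using only `unittest.mock`. It assumes nothing about the real `ModelTrainer` beyond the attribute name `environment` taken from the diff; `mock_trainer` is a stand-in. A bare `MagicMock` fabricates any attribute on access, so the `getattr` default is never used, and only the `isinstance` guard keeps a mock object out of the definition:

```python
from unittest.mock import MagicMock

# Stand-in for a mocked ModelTrainer; not the real SageMaker class.
mock_trainer = MagicMock()

# MagicMock creates attributes on access, so the default None is never returned.
env = getattr(mock_trainer, "environment", None)
env_is_mock = isinstance(env, MagicMock)

# The guard from the diff is what turns the fabricated attribute into None.
if not env or not isinstance(env, dict):
    env = None

# Setting the attribute explicitly in the test makes direct access safe.
mock_trainer.environment = {"FOO": "bar"}
direct_env = mock_trainer.environment
```

If the tests always set `environment` on the mock, as the new tests in this PR do, direct attribute access behaves the same as `getattr` here.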

Collaborator

The PR description states: "Similarly, for the multi-trainer dict path (_build_training_job_definitions), environment is also not propagated. The fix is to read model_trainer.environment in both _build_training_job_definition and _build_training_job_definitions methods." However, this diff only modifies _build_training_job_definition (singular). The multi-trainer path _build_training_job_definitions (plural) does not appear to be fixed. Is this an oversight, or was the description inaccurate? If the multi-trainer path also has this bug, it should be fixed in this PR as well.
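
One way to keep both paths consistent could be a small shared helper. The sketch below is only an assumption about how that might look: `resolve_environment` and `FakeTrainer` are hypothetical names, not part of the SDK, and the guard simply mirrors the one in this diff.

```python
from typing import Any, Dict, Optional


def resolve_environment(model_trainer: Any) -> Optional[Dict[str, str]]:
    """Hypothetical helper: return the trainer's environment dict, or None."""
    env = getattr(model_trainer, "environment", None)
    # Mirror the guard in the diff: anything but a non-empty dict becomes None.
    if not env or not isinstance(env, dict):
        return None
    return env


class FakeTrainer:
    """Stand-in for ModelTrainer, used only for this illustration."""

    def __init__(self, environment=None):
        self.environment = environment


# Single-trainer path.
single_env = resolve_environment(FakeTrainer({"FOO": "bar"}))

# Multi-trainer dict path: the same helper applied to each trainer.
trainers = {"a": FakeTrainer({"SEED": "42"}), "b": FakeTrainer(None)}
multi_envs = {name: resolve_environment(t) for name, t in trainers.items()}
```

Calling the same helper from `_build_training_job_definition` and `_build_training_job_definitions` would keep the normalization rule in one place, whichever rule the team settles on.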

# Get environment variables from model_trainer
Collaborator

The truthiness check not env evaluates to True for an empty dict {}, so an empty dict is replaced with None. While the test covers this, silently converting a user-provided {} to None could be surprising. Consider whether an empty dict should be passed through as-is (the API would accept it), or whether the normalization is intentional. If it is intentional, a brief comment explaining why empty dicts are normalized to None would help future maintainers.

Also, the isinstance(env, dict) check is defensive — if ModelTrainer.environment has a type annotation of dict | None, Pydantic validation should already enforce this. Is this guard necessary?
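
To make the difference concrete, here is a small standalone sketch. `normalize_env` mirrors the guard in this diff; `passthrough_env` is a hypothetical alternative that keeps an explicit empty dict. Neither function exists in the SDK.

```python
from typing import Dict, Optional


def normalize_env(env: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
    """Behaviour in the diff: None and {} both become None."""
    if not env or not isinstance(env, dict):
        return None
    return env


def passthrough_env(env: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
    """Hypothetical alternative: only non-dict values become None; {} survives."""
    return env if isinstance(env, dict) else None


# An empty dict is falsy, which is why the current guard drops it.
empty_is_falsy = not {}
normalized = normalize_env({})
passed_through = passthrough_env({})
```

With the first variant the definition cannot tell "no environment given" apart from "explicitly empty environment"; the second keeps that distinction at the cost of sending an empty mapping to the API.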

env = getattr(model_trainer, "environment", None)
if not env or not isinstance(env, dict):
    env = None

definition = HyperParameterTrainingJobDefinition(
algorithm_specification=algorithm_spec,
role_arn=model_trainer.role,
@@ -1513,13 +1518,9 @@ def _build_training_job_definition(self, inputs):
stopping_condition=stopping_condition,
static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
environment=env,
)

# Pass through environment variables from model_trainer
env = getattr(model_trainer, "environment", None)
if env and isinstance(env, dict):
    definition.environment = env

# Pass through VPC config from model_trainer
networking = getattr(model_trainer, "networking", None)
if networking and hasattr(networking, "_to_vpc_config"):
58 changes: 58 additions & 0 deletions sagemaker-train/tests/unit/train/test_tuner.py
@@ -596,3 +596,61 @@ def test_build_training_job_definition_includes_spot_params(self):
assert isinstance(
definition.stopping_condition.max_wait_time_in_seconds, int
), "Max wait time should be set"

Collaborator

Good test coverage for the happy path, None, and empty dict cases. However, consider adding a test for the multi-trainer path (_build_training_job_definitions) as well, since the PR description mentions it should also be fixed. If that method is not being changed, the test would at least document the current (potentially broken) behavior.

def test_build_training_job_definition_includes_environment_variables(self):
"""Test that _build_training_job_definition includes environment variables.

This test verifies the fix for GitHub issue #5613 where tuning jobs were
missing environment variables that were set on the ModelTrainer.
"""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {
"FOO": "bar",
"RANDOM_STATE": "42",
}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is not None, "Environment should not be None"
assert definition.environment == {
"FOO": "bar",
"RANDOM_STATE": "42",
}, "Environment variables should match those set on ModelTrainer"

def test_build_training_job_definition_with_none_environment(self):
"""Test that _build_training_job_definition handles None environment gracefully."""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = None

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is None, "Environment should be None when not set"

def test_build_training_job_definition_with_empty_environment(self):
"""Test that _build_training_job_definition handles empty environment gracefully."""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is None, (
Collaborator

Nit: The assertion message says "Environment should be None when empty dict is provided" — this documents the behavior but consider whether this is actually the desired UX. A user who explicitly sets environment={} might not expect it to be silently dropped. If this is intentional, it's fine, but worth confirming with the team.

"Environment should be None when empty dict is provided"
)