-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) #5725
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1504,6 +1504,11 @@ def _build_training_job_definition(self, inputs): | |
| model_trainer.stopping_condition.max_wait_time_in_seconds | ||
| ) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
env = model_trainer.environment
if not env or not isinstance(env, dict):
    env = None

Alternatively, if there's a concern about backward compatibility with mock objects in tests, that's a test issue, not a production code concern.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The PR description states: "Similarly, for the multi-trainer dict path (… [comment truncated]
||
| # Get environment variables from model_trainer. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The old code used `getattr(model_trainer, "environment", None)`. However, the PR description mentions that … [comment truncated]
||
| # environment is a defined attribute on ModelTrainer (dict | None). | ||
| # We pass it through as-is; even an empty dict is valid for the API. | ||
| env = model_trainer.environment | ||
|
|
||
| definition = HyperParameterTrainingJobDefinition( | ||
| algorithm_specification=algorithm_spec, | ||
| role_arn=model_trainer.role, | ||
|
|
@@ -1513,13 +1518,9 @@ def _build_training_job_definition(self, inputs): | |
| stopping_condition=stopping_condition, | ||
| static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {}, | ||
| enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training, | ||
| environment=env, | ||
| ) | ||
|
|
||
| # Pass through environment variables from model_trainer | ||
| env = getattr(model_trainer, "environment", None) | ||
| if env and isinstance(env, dict): | ||
| definition.environment = env | ||
|
|
||
| # Pass through VPC config from model_trainer | ||
| networking = getattr(model_trainer, "networking", None) | ||
| if networking and hasattr(networking, "_to_vpc_config"): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -596,3 +596,65 @@ def test_build_training_job_definition_includes_spot_params(self): | |
| assert isinstance( | ||
| definition.stopping_condition.max_wait_time_in_seconds, int | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ), "Max wait time should be set" | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good test coverage for the happy path, … [comment truncated]
||
| def test_build_training_job_definition_includes_environment_variables(self): | ||
| """Test that _build_training_job_definition includes environment variables. | ||
|
|
||
| This test verifies the fix for GitHub issue #5613 where tuning jobs were | ||
| missing environment variables that were set on the ModelTrainer. | ||
| """ | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = { | ||
| "FOO": "bar", | ||
| "RANDOM_STATE": "42", | ||
| } | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
| assert definition.environment is not None, "Environment should not be None" | ||
| assert definition.environment == { | ||
| "FOO": "bar", | ||
| "RANDOM_STATE": "42", | ||
| }, "Environment variables should match those set on ModelTrainer" | ||
|
|
||
| def test_build_training_job_definition_with_none_environment(self): | ||
| """Test that _build_training_job_definition handles None environment gracefully.""" | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = None | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
| assert definition.environment is None, "Environment should be None when not set" | ||
|
|
||
| def test_build_training_job_definition_with_empty_environment(self): | ||
| """Test that _build_training_job_definition passes through empty environment. | ||
|
|
||
| An empty dict is valid for the SageMaker API, so we pass it through as-is | ||
| rather than silently converting it to None. | ||
| """ | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = {} | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
| assert definition.environment == {}, ( | ||
| "Empty dict environment should be passed through as-is" | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you fix the unit tests based on the CI failures?
$context sagemaker-train/tests/unit/train/test_tuner.py
$context sagemaker-train/tests/unit/train/test_tuner_driver_channels.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can ignore the integration test failures!