Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sagemaker-core/src/sagemaker/core/training/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ class SourceCode(BaseConfig):
ignore_patterns: (Optional[List[str]]) :
The ignore patterns to ignore specific files/folders when uploading to S3. If not specified,
default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
dependencies (Optional[List[str]]):
A list of paths to local directories or files (absolute or relative) containing
additional libraries that will be copied into the training container and made
importable via PYTHONPATH. Each path must be a valid local directory or file.
"""

source_dir: Optional[StrPipeVar] = None
Expand All @@ -123,6 +127,7 @@ class SourceCode(BaseConfig):
".cache",
".ipynb_checkpoints",
]
dependencies: Optional[List[str]] = None

class OutputDataConfig(shapes.OutputDataConfig):
"""OutputDataConfig.
Expand Down
3 changes: 3 additions & 0 deletions sagemaker-train/src/sagemaker/train/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
"amazon.nova-pro-v1:0": ["us-east-1"]
}

# Name of the input data channel that carries user-supplied dependencies.
SM_DEPENDENCIES = "sm_dependencies"
# Location of that channel inside the container, derived from the channel name
# so the two values cannot drift apart.
SM_DEPENDENCIES_CONTAINER_PATH = f"/opt/ml/input/data/{SM_DEPENDENCIES}"

SM_RECIPE = "recipe"
SM_RECIPE_YAML = "recipe.yaml"
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"
33 changes: 33 additions & 0 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
SM_CODE_CONTAINER_PATH,
SM_DRIVERS,
SM_DRIVERS_LOCAL_PATH,
SM_DEPENDENCIES,
SM_DEPENDENCIES_CONTAINER_PATH,
SM_RECIPE,
SM_RECIPE_YAML,
SM_RECIPE_CONTAINER_PATH,
Expand All @@ -99,6 +101,7 @@
EXECUTE_BASIC_SCRIPT_DRIVER,
INSTALL_AUTO_REQUIREMENTS,
INSTALL_REQUIREMENTS,
INSTALL_DEPENDENCIES,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
Expand Down Expand Up @@ -484,6 +487,13 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
f"Invalid 'entry_script': {entry_script}. "
"Must be a valid file within the 'source_dir'.",
)
if source_code.dependencies:
for dep_path in source_code.dependencies:
if not _is_valid_path(dep_path):
raise ValueError(
f"Invalid dependency path: {dep_path}. "
"Each dependency must be a valid local directory or file path."
)

@staticmethod
def _validate_and_fetch_hyperparameters_file(hyperparameters_file: str):
Expand Down Expand Up @@ -654,6 +664,24 @@ def _create_training_job_args(
)
final_input_data_config.append(source_code_channel)

# If dependencies are provided, create a channel for the dependencies
# The dependencies will be mounted at /opt/ml/input/data/sm_dependencies
if self.source_code.dependencies:
deps_tmp_dir = TemporaryDirectory()
for dep_path in self.source_code.dependencies:
dep_basename = os.path.basename(os.path.normpath(dep_path))
dest_path = os.path.join(deps_tmp_dir.name, dep_basename)
if os.path.isdir(dep_path):
shutil.copytree(dep_path, dest_path, dirs_exist_ok=True)
else:
shutil.copy2(dep_path, dest_path)
dependencies_channel = self.create_input_data_channel(
channel_name=SM_DEPENDENCIES,
data_source=deps_tmp_dir.name,
key_prefix=input_data_key_prefix,
)
final_input_data_config.append(dependencies_channel)

self._prepare_train_script(
tmp_dir=self._temp_code_dir,
source_code=self.source_code,
Expand Down Expand Up @@ -1010,6 +1038,10 @@ def _prepare_train_script(
base_command = source_code.command.split()
base_command = " ".join(base_command)

install_dependencies = ""
if source_code.dependencies:
install_dependencies = INSTALL_DEPENDENCIES

install_requirements = ""
if source_code.requirements:
if self._jumpstart_config and source_code.requirements == "auto":
Expand Down Expand Up @@ -1049,6 +1081,7 @@ def _prepare_train_script(

train_script = TRAIN_SCRIPT_TEMPLATE.format(
working_dir=working_dir,
install_dependencies=install_dependencies,
install_requirements=install_requirements,
execute_driver=execute_driver,
)
Expand Down
15 changes: 15 additions & 0 deletions sagemaker-train/src/sagemaker/train/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@
$SM_PIP_CMD install -r {requirements_file}
"""

# Shell snippet that puts the contents of the dependencies channel on PYTHONPATH.
#
# Every immediate sub-directory of the channel is prepended, followed by the
# channel root itself so that single-file modules placed there are importable.
#
# NOTE: this snippet contains literal shell braces (`${...}`), so it must be
# inserted into the train script verbatim and never passed through
# `str.format` on its own.
INSTALL_DEPENDENCIES = """
echo "Setting up additional dependencies"
if [ -d /opt/ml/input/data/sm_dependencies ]; then
    for dep_dir in /opt/ml/input/data/sm_dependencies/*/; do
        if [ -d "$dep_dir" ]; then
            echo "Adding $dep_dir to PYTHONPATH"
            # Only append the old value when it is non-empty: a trailing ':'
            # would add an empty entry, which Python resolves to the current
            # working directory.
            export PYTHONPATH="$dep_dir${PYTHONPATH:+:$PYTHONPATH}"
        fi
    done
    # Also add the root dependencies dir in case of single files
    export PYTHONPATH="/opt/ml/input/data/sm_dependencies${PYTHONPATH:+:$PYTHONPATH}"
fi
"""

EXEUCTE_DISTRIBUTED_DRIVER = """
echo "Running {driver_name} Driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script}
Expand Down Expand Up @@ -95,6 +109,7 @@
set -x

{working_dir}
{install_dependencies}
{install_requirements}
{execute_driver}

Expand Down
Loading
Loading