Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions scripts/performance/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,14 @@ def parse_cli_args():
help="Comma separated string of environment variables",
default={},
)
slurm_args.add_argument(
"--container_env",
type=list_of_strings,
metavar="KEY[,KEY2,...]",
help="Comma-separated list of environment variable names that should override same-named "
"values from the container image. Use -E/--env or -ce/--custom_env_vars to set the value explicitly.",
default=[],
)
slurm_args.add_argument(
"-E",
"--env",
Expand Down
3 changes: 3 additions & 0 deletions scripts/performance/setup_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def main(
dgxc_pvc_mount_path: str,
config_variant: str = "v1",
gres: Optional[str] = None,
container_env: Optional[List[str]] = None,
):
"""Sets up the experiment and runs it."""
if (
Expand Down Expand Up @@ -329,6 +330,7 @@ def main(
nemo_home=nemo_home,
additional_slurm_params=additional_slurm_params,
wandb_key=wandb_key,
container_env=container_env or [],
)
else:
executor = dgxc_executor(
Expand Down Expand Up @@ -668,4 +670,5 @@ def main(
dgxc_pvc_mount_path=args.dgxc_pvc_mount_path,
config_variant=config_variant,
gres=args.gres,
container_env=args.container_env,
)
18 changes: 11 additions & 7 deletions scripts/performance/utils/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def slurm_executor(
custom_mounts: List[str] = [],
custom_env_vars: Dict[str, str] = {},
custom_srun_args: List[str] = [],
container_env: List[str] = [],
hf_token: str = None,
nemo_home: str = DEFAULT_NEMO_HOME,
wandb_key: str = None,
Expand Down Expand Up @@ -96,20 +97,22 @@ def slurm_executor(
f"Logs will be written to {get_nemorun_home()}, which is probably not desired. export NEMORUN_HOME in your shell environment or use the --log_dir argument"
)

perf_env = PERF_ENV_VARS.copy()

if wandb_key is not None:
PERF_ENV_VARS["WANDB_API_KEY"] = wandb_key
perf_env["WANDB_API_KEY"] = wandb_key

if gpu.lower() == "gb200":
PERF_ENV_VARS["NCCL_NET_GDR_LEVEL"] = "PHB" # For NCCL 2.25
PERF_ENV_VARS["NCCL_NET_GDR_C2C"] = "1" # For NCCL 2.26
perf_env["NCCL_NET_GDR_LEVEL"] = "PHB" # For NCCL 2.25
perf_env["NCCL_NET_GDR_C2C"] = "1" # For NCCL 2.26

if nemo_home != DEFAULT_NEMO_CACHE_HOME: # DO NOT change this to 'DEFAULT_NEMO_HOME'/'NEMO_HOME'
PERF_ENV_VARS["NEMO_HOME"] = nemo_home
perf_env["NEMO_HOME"] = nemo_home
mounts.extend([f"{nemo_home}:{nemo_home}"])
if hf_token is not None:
PERF_ENV_VARS.update({"HF_TOKEN": hf_token, "TRANSFORMERS_OFFLINE": "0"})
perf_env.update({"HF_TOKEN": hf_token, "TRANSFORMERS_OFFLINE": "0"})

PERF_ENV_VARS.update(custom_env_vars)
perf_env.update(custom_env_vars)
mounts.extend(custom_mounts)

# add --segment flag to sbatch if job uses GB200.
Expand Down Expand Up @@ -143,7 +146,8 @@ def slurm_executor(
gres=gres,
container_image=container_image,
container_mounts=mounts,
env_vars=PERF_ENV_VARS,
env_vars=perf_env,
container_env=sorted(set(perf_env.keys()) | set(container_env)),
srun_args=srun_args,
time=time_limit,
mem="0",
Expand Down
70 changes: 70 additions & 0 deletions tests/unit_tests/scripts/performance/test_executors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for scripts/performance/utils/executors.py — container_env on SlurmExecutor."""

import sys
from pathlib import Path

import pytest

# The performance scripts live outside any installed package, so their
# directory is prepended to sys.path; this lets ``utils.executors`` be imported
# with the same module path the scripts themselves use.
_PERF_SCRIPTS_DIR = Path(__file__).resolve().parents[4] / "scripts" / "performance"
if str(_PERF_SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(_PERF_SCRIPTS_DIR))

# Probe for the optional nemo_run dependency. The executor helpers are only
# imported when the probe succeeds; an ImportError raised by that second import
# is deliberately NOT swallowed (the ``else`` clause is outside the handler).
try:
    import nemo_run  # noqa: F401
except ImportError:
    HAS_NEMO_RUN = False
else:
    HAS_NEMO_RUN = True
    from utils.executors import PERF_ENV_VARS, slurm_executor


@pytest.mark.skipif(not HAS_NEMO_RUN, reason="nemo_run not installed")
def test_container_env_includes_perf_vars(tmp_path):
    """Every key of PERF_ENV_VARS must be listed in the executor's container_env.

    Builds a Slurm executor with no custom variables and checks that none of
    the default performance variable names is absent from ``container_env``.
    """
    slurm = slurm_executor(
        gpu="h100",
        account="test",
        partition="test",
        log_dir=str(tmp_path),
        nodes=1,
        num_gpus_per_node=8,
    )
    forwarded = slurm.container_env
    assert forwarded is not None, "container_env is None — was the field removed from the executor?"

    # Anything left over here is a default perf variable that was not forwarded.
    missing = set(PERF_ENV_VARS).difference(forwarded)
    assert not missing, f"PERF_ENV_VARS keys missing from container_env: {missing}"


@pytest.mark.skipif(not HAS_NEMO_RUN, reason="nemo_run not installed")
def test_custom_env_vars_in_container_env(tmp_path):
    """Names supplied through ``custom_env_vars`` must show up in container_env."""
    extra_env = {"MY_CUSTOM_VAR": "1"}
    slurm = slurm_executor(
        gpu="h100",
        account="test",
        partition="test",
        log_dir=str(tmp_path),
        nodes=1,
        num_gpus_per_node=8,
        custom_env_vars=extra_env,
    )
    # Only the variable *name* is expected in container_env, not its value.
    for name in extra_env:
        assert name in slurm.container_env


@pytest.mark.skipif(not HAS_NEMO_RUN, reason="nemo_run not installed")
def test_container_env_param_forwarded(tmp_path):
    """A name passed via the ``container_env`` argument must reach the executor."""
    requested = ["UPSTREAM_SET_VAR"]
    slurm = slurm_executor(
        gpu="h100",
        account="test",
        partition="test",
        log_dir=str(tmp_path),
        nodes=1,
        num_gpus_per_node=8,
        container_env=requested,
    )
    assert "UPSTREAM_SET_VAR" in slurm.container_env
Loading