Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion docs/design-docs/checkpointing.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,39 @@
# Exporting Checkpoints to Hugging Face Format
# Checkpointing

## Resume Behavior

Training uses the root `checkpointing.checkpoint_dir` as the active checkpoint namespace.
By default, if root `step_<N>` checkpoints already exist, NeMo RL resumes from the latest one:

```yaml
checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: results/sft
```

If you want to reuse the same `checkpoint_dir` for a fresh run instead, set:

```yaml
checkpointing:
enabled: true
resume_if_exists: false
checkpoint_dir: results/sft
```

In that mode, existing root checkpoints are archived under `run_<N>/` and the new run starts from scratch:

```text
results/sft/
run_0/
step_1/
step_2/
step_1/
```

Archived checkpoints remain fully valid; you can still inspect them or convert them later by referencing their path under the `run_<N>/` directory.

## Exporting Checkpoints to Hugging Face Format

NeMo RL provides two checkpoint formats for Hugging Face models: Torch distributed and Hugging Face format. Torch distributed is used by default for efficiency, and Hugging Face format is provided for compatibility with Hugging Face's `AutoModel.from_pretrained` API. Note that Hugging Face format checkpoints save only the model weights, ignoring the optimizer states. It is recommended to use Torch distributed format to save intermediate checkpoints and to save a Hugging Face checkpoint only at the end of training.

Expand Down
1 change: 1 addition & 0 deletions examples/configs/distillation_math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ loss_fn:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "checkpoints/distillation-${policy.model_name}"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
Expand Down
1 change: 1 addition & 0 deletions examples/configs/distillation_math_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ defaults: distillation_math.yaml

checkpointing:
checkpoint_dir: "checkpoints/distillation-megatron-${policy.model_name}"
resume_if_exists: true

policy: &POLICY_BASE
model_name: "Qwen/Qwen3-1.7B-Base"
Expand Down
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dpo:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/dpo"
metric_name: "val:validation-default_loss"
higher_is_better: false
Expand Down
1 change: 1 addition & 0 deletions examples/configs/gdpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:

checkpointing:
checkpoint_dir: "results/gdpo"
resume_if_exists: true

policy:
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ loss_fn:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/grpo"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ loss_fn:

checkpointing:
enabled: false
resume_if_exists: true
checkpoint_dir: "results/grpo_megatron"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ grpo:

checkpointing:
enabled: false
resume_if_exists: true
checkpoint_dir: "results/grpo_8b_megatron"
checkpoint_must_save_by: null

Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_sliding_puzzle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ grpo:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/grpo-sliding-puzzle"
metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
higher_is_better: true
Expand Down
1 change: 1 addition & 0 deletions examples/configs/prorlv2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ loss_fn:
# ============================================================================
checkpointing:
checkpoint_dir: "results/prorl"
resume_if_exists: true

logger:
log_dir: "logs/prorl"
1 change: 1 addition & 0 deletions examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rm:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/rm"
metric_name: "val:validation-default_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ sft:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/sft"
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ sft:
val_micro_batch_size: 2
checkpointing:
checkpoint_dir: results/sft_openmathinstruct2
resume_if_exists: true
keep_top_k: 100
save_period: 500
policy:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft_vlm_3B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ policy:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/sft_${policy.model_name}"
metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
higher_is_better: false
Expand Down
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:

checkpointing:
checkpoint_dir: results/clevr_grpo_${policy.model_name}
resume_if_exists: true

policy:
model_name: Qwen/Qwen2.5-VL-3B-Instruct
Expand Down
1 change: 1 addition & 0 deletions examples/nemo_gym/grpo_nanov3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ loss_fn:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/grpo"
metric_name: "val:total_reward/mean"
higher_is_better: true
Expand Down
1 change: 1 addition & 0 deletions examples/nemo_gym/grpo_qwen3_30ba3b_instruct.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,6 @@ checkpointing:
# 1. For this config Qwen 3 30BA3B on math with 32k context length, the validation could take up to 10 mins.
# 3. The step time for this config on 32 nodes takes around 30 mins.
# 4. The checkpoint time for this model is around 10 mins.
resume_if_exists: true
checkpoint_must_save_by: "00:03:30:00"
save_period: 1
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ loss_fn:

checkpointing:
enabled: true
resume_if_exists: true
checkpoint_dir: "results/grpo"
metric_name: "val:accuracy"
higher_is_better: true
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def setup(
# Checkpointing
# ==========================
checkpointer = CheckpointManager(master_config["checkpointing"])
last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
last_checkpoint_path = checkpointer.resolve_training_start_checkpoint()
distillation_save_state: Optional[DistillationSaveState] = cast(
Optional[DistillationSaveState],
checkpointer.load_training_info(last_checkpoint_path),
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def setup(
# Checkpointing
# ==========================
checkpointer = CheckpointManager(master_config["checkpointing"])
last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
last_checkpoint_path = checkpointer.resolve_training_start_checkpoint()
dpo_save_state: Optional[DPOSaveState] = cast(
Optional[DPOSaveState], checkpointer.load_training_info(last_checkpoint_path)
)
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def setup(
# Checkpointing
# ==========================
checkpointer = CheckpointManager(master_config["checkpointing"])
last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
last_checkpoint_path = checkpointer.resolve_training_start_checkpoint()
grpo_save_state: Optional[GRPOSaveState] = cast(
Optional[GRPOSaveState], checkpointer.load_training_info(last_checkpoint_path)
)
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def setup(
# Checkpointing
# ==========================
checkpointer = CheckpointManager(master_config["checkpointing"])
last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
last_checkpoint_path = checkpointer.resolve_training_start_checkpoint()
rm_save_state: Optional[RMSaveState] = checkpointer.load_training_info(
last_checkpoint_path
)
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def setup(
# Checkpointing
# ==========================
checkpointer = CheckpointManager(master_config["checkpointing"])
last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
last_checkpoint_path = checkpointer.resolve_training_start_checkpoint()
sft_save_state: Optional[SFTSaveState] = cast(
Optional[SFTSaveState], checkpointer.load_training_info(last_checkpoint_path)
)
Expand Down
74 changes: 61 additions & 13 deletions nemo_rl/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class CheckpointingConfig(TypedDict):

Attributes:
enabled (bool): Whether checkpointing is enabled.
resume_if_exists (bool): Whether to resume from an existing root checkpoint if one exists.
If False, existing root step_* checkpoints are moved under run_<N>/ and training
starts from scratch.
checkpoint_dir (PathLike): Directory where checkpoints will be saved.
metric_name (str | None): Name of the metric to use for determining best checkpoints.
Must be of the form "val:<metric_name>" or "train:<metric_name>" to indicate whether
Expand All @@ -53,6 +56,7 @@ class CheckpointingConfig(TypedDict):
"""

enabled: bool
resume_if_exists: bool
checkpoint_dir: PathLike
metric_name: str | None
higher_is_better: bool
Expand Down Expand Up @@ -86,6 +90,9 @@ class CheckpointManager:
...
step_1/
...
run_0/
step_0/
...
```

Attributes: Derived from the CheckpointingConfig.
Expand All @@ -97,6 +104,8 @@ def __init__(self, config: CheckpointingConfig):
Args:
config (CheckpointingConfig)
"""
self.enabled = config["enabled"]
self.resume_if_exists = config["resume_if_exists"]
self.checkpoint_dir = Path(config["checkpoint_dir"])
self.metric_name: str | None = config["metric_name"]
self.higher_is_better = config["higher_is_better"]
Expand All @@ -110,6 +119,45 @@ def __init__(self, config: CheckpointingConfig):
self.model_repo_id = config.get("model_repo_id", "")
self.is_peft = config.get("is_peft", False)

def _get_next_run_archive_dir(self) -> Path:
"""Get the next run_<N> directory for archiving an old root checkpoint set."""
next_run_idx = 0
for path in self.checkpoint_dir.glob("run_*"):
match = re.fullmatch(r"run_(\d+)", path.name)
if match:
next_run_idx = max(next_run_idx, int(match.group(1)) + 1)
return self.checkpoint_dir / f"run_{next_run_idx}"

def _archive_root_checkpoints(self, step_dirs: list[Path]) -> Path:
"""Archive the active root checkpoints under the next run_<N> directory."""
archive_dir = self._get_next_run_archive_dir()
archive_dir.mkdir(parents=True, exist_ok=False)

for step_dir in step_dirs:
step_dir.rename(archive_dir / step_dir.name)

return archive_dir

def resolve_training_start_checkpoint(self) -> Optional[str]:
    """Decide which checkpoint, if any, training should start from.

    Returns:
        Optional[str]: Path of the latest root ``step_<N>`` checkpoint when
        checkpointing is enabled, ``resume_if_exists`` is True, and at least
        one root checkpoint exists; otherwise ``None``. When
        ``resume_if_exists`` is False, existing root checkpoints are first
        archived under a fresh ``run_<N>`` directory so the run starts clean.
    """
    if not self.enabled:
        return None

    existing = _get_root_step_dirs(self.checkpoint_dir)
    if not existing:
        return None

    if self.resume_if_exists:
        # _get_root_step_dirs sorts ascending by step, so the last entry is newest.
        return str(existing[-1])

    archive_dir = self._archive_root_checkpoints(existing)
    print(
        f"Archived {len(existing)} checkpoint(s) from {self.checkpoint_dir} to {archive_dir} "
        "because checkpointing.resume_if_exists is false. Starting from scratch.",
        flush=True,
    )
    return None

@staticmethod
def get_resume_paths(
last_checkpoint_path: Optional[PathLike],
Expand Down Expand Up @@ -319,13 +367,7 @@ def get_latest_checkpoint_path(self) -> Optional[str]:
Returns:
Optional[str]: Path to the latest checkpoint, or None if no checkpoints exist.
"""
# find checkpoint directory with highest step number
step_dirs = [
x
for x in glob.glob(str(self.checkpoint_dir / "step_*"))
if re.fullmatch(r"step_\d+", Path(x).name)
]
step_dirs.sort(key=lambda x: int(Path(x).name.split("_")[1]))
step_dirs = _get_root_step_dirs(self.checkpoint_dir)
if len(step_dirs) == 0:
return None
return str(step_dirs[-1])
Expand Down Expand Up @@ -363,12 +405,7 @@ def _load_checkpoint_history(
"""
checkpoint_history: list[tuple[int, PathLike, dict[str, Any]]] = []

# Find all step directories
step_dirs = [
x
for x in glob.glob(str(checkpoint_dir / "step_*"))
if re.fullmatch(r"step_\d+", Path(x).name)
]
step_dirs = _get_root_step_dirs(checkpoint_dir)

for step_dir in step_dirs:
info_file = Path(step_dir) / "training_info.json"
Expand All @@ -379,3 +416,14 @@ def _load_checkpoint_history(
checkpoint_history.append((step, step_dir, info))

return checkpoint_history


def _get_root_step_dirs(checkpoint_dir: Path) -> list[Path]:
"""Get the active root step_<N> checkpoint directories sorted by step number."""
step_dirs = [
Path(x)
for x in glob.glob(str(checkpoint_dir / "step_*"))
if re.fullmatch(r"step_\d+", Path(x).name)
]
step_dirs.sort(key=lambda path: int(path.name.split("_")[1]))
return step_dirs
Loading
Loading