Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions examples/converters/convert_dcp_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import argparse

import os
import yaml

from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf
Expand Down Expand Up @@ -50,13 +50,15 @@ def main():
config = yaml.safe_load(f)

model_name_or_path = config["policy"]["model_name"]
# TODO: After the following PR gets merged:
# https://github.com/NVIDIA-NeMo/RL/pull/148/files
# tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
# We can expose an arg at the top level --tokenizer_path to plumb that through.
# This is more stable than relying on the current NeMo-RL get_tokenizer() which can
# change release to release.
tokenizer_name_or_path = config["policy"]["model_name"]

# Some algorithms may change the tokenizer property at runtime.
# The train loop ensures dcp_ckpt_path is policy/weights/ and tokenizer files live under policy/tokenizer.
if os.path.exists(tokenizer_path := os.path.join(args.dcp_ckpt_path, "..", "tokenizer")):
print(f"Using local tokenizer path at {tokenizer_path} for HF conversion")
tokenizer_name_or_path = tokenizer_path
else:
print(f"WARNING: No local tokenizer path found at {tokenizer_path}. Falling back to loading the vanilla tokenizer based on the config file. Please ensure this is what you want.")
tokenizer_name_or_path = config["policy"]["tokenizer"]["name"]
hf_overrides = config["policy"].get("hf_overrides", {}) or {}

hf_ckpt = convert_dcp_to_hf(
Expand Down
Loading