Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ def main_parser() -> argparse.ArgumentParser:
help="(Supported backend: PyTorch) Model branch chosen for fine-tuning if multi-task. If not specified, it will re-init the fitting net.",
)
# Keep ``--force-load`` registered: the PyTorch and Paddle entry points still
# read ``FLAGS.force_load``, so dropping the option would break ``train``.
parser_train.add_argument(
    "--force-load",
    action="store_true",
    help="(Supported backend: PyTorch) Force load from ckpt, other missing tensors will init from scratch",
)
# ``--allow-ref`` is an independent opt-in switch (``store_true`` => defaults
# to False); it must be its own argument rather than replacing
# ``--force-load``.
parser_train.add_argument(
    "--allow-ref",
    action="store_true",
    help="Allow loading external JSON/YAML snippets through `$ref`. Disabled by default for security.",
)

# * freeze script ******************************************************************
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,16 @@ def train(
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
allow_ref: bool = False,
) -> None:
"""Train a model with Paddle backend.

Parameters
----------
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
"""
log.info("Configuration path: %s", input_file)
if LOCAL_RANK == 0:
SummaryPrinter()()
Expand Down Expand Up @@ -292,7 +301,7 @@ def train(

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
config = normalize(config, multi_task=multi_task, allow_ref=allow_ref)

# do neighbor stat
min_nbor_dist = None
Expand Down Expand Up @@ -600,6 +609,7 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
allow_ref=FLAGS.allow_ref,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,16 @@ def train(
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
allow_ref: bool = False,
) -> None:
"""Train a model with PyTorch backend.

Parameters
----------
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
"""
log.info("Configuration path: %s", input_file)
env.CUSTOM_OP_USE_JIT = True
if LOCAL_RANK == 0:
Expand Down Expand Up @@ -325,7 +334,7 @@ def train(

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
config = normalize(config, multi_task=multi_task, allow_ref=allow_ref)

# do neighbor stat
min_nbor_dist = None
Expand Down Expand Up @@ -578,6 +587,7 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
allow_ref=FLAGS.allow_ref,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand Down
6 changes: 5 additions & 1 deletion deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def train(
skip_neighbor_stat: bool = False,
finetune: str | None = None,
use_pretrain_script: bool = False,
allow_ref: bool = False,
**kwargs: Any,
) -> None:
"""Run DeePMD model training.
Expand Down Expand Up @@ -101,6 +102,9 @@ def train(
use_pretrain_script : bool
Whether to use model script in pretrained model when doing init-model or init-frz-model.
Note that this option is true and unchangeable for fine-tuning.
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
**kwargs
additional arguments

Expand Down Expand Up @@ -168,7 +172,7 @@ def train(

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
jdata = normalize(jdata, allow_ref=allow_ref)

if not is_compress and not skip_neighbor_stat:
jdata = update_sel(jdata)
Expand Down
27 changes: 24 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3821,10 +3821,31 @@ def gen_json_schema(multi_task: bool = False) -> str:
return json.dumps(generate_json_schema(arg))


def normalize(
    data: dict[str, Any],
    multi_task: bool = False,
    allow_ref: bool = False,
) -> dict[str, Any]:
    """Normalize and validate a DeePMD input config.

    The argument schema is built with ``gen_args`` and wrapped in a single
    ``base`` argument; the input is first normalized against it and the
    normalized result is then checked in strict mode.

    Parameters
    ----------
    data : dict[str, Any]
        Input training configuration.
    multi_task : bool, default=False
        Whether to use the multi-task argument schema.
    allow_ref : bool, default=False
        Whether to allow loading external JSON/YAML snippets via ``$ref``.
        Disabled by default for security.

    Returns
    -------
    dict[str, Any]
        Normalized and validated configuration.
    """
    base = Argument("base", dict, gen_args(multi_task=multi_task))
    # Forward the same ``allow_ref`` to both passes so normalization and
    # validation agree on whether ``$ref`` entries are permitted.
    # NOTE(review): ``trim_pattern="_*"`` presumably drops underscore-prefixed
    # keys before checking -- confirm against the dargs documentation.
    data = base.normalize_value(data, trim_pattern="_*", allow_ref=allow_ref)
    base.check_value(data, strict=True, allow_ref=allow_ref)

    return data

Expand Down
2 changes: 2 additions & 0 deletions doc/train/training-advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ An explanation will be provided

**`--skip-neighbor-stat`** will skip calculating neighbor statistics if one is concerned about performance. Some features will be disabled.

**`--allow-ref`** enables loading external JSON/YAML snippets via `$ref` during input validation. This option is disabled by default for security.

To maximize the performance, one should follow [FAQ: How to control the parallelism of a job](../troubleshooting/howtoset_num_nodes.md) to control the number of threads.
See [Runtime environment variables](../env.md) for all runtime environment variables.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
'numpy>=1.21',
'scipy',
'pyyaml',
'dargs >= 0.4.7',
'dargs >= 0.5.0',
'typing_extensions>=4.0.0',
'importlib_metadata>=1.4; python_version < "3.8"',
'h5py',
Expand Down
17 changes: 7 additions & 10 deletions source/tests/common/test_argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,13 @@ def test_parser_train_finetune(self) -> None:

self.run_test(command="train", mapping=ARGS)

def test_parser_train_wrong_subcommand(self) -> None:
    """Expect ``SystemExit`` when ``train`` gets both ``--init-model`` and ``--restart``.

    NOTE(review): presumably the two options are mutually exclusive in the
    parser -- confirm against ``main_parser``.
    """
    optional_str = (str, type(None))
    conflicting_args = {
        "INPUT": {"type": str, "value": "INFILE"},
        "--init-model": {"type": optional_str, "value": "SYSTEM_DIR"},
        "--restart": {"type": optional_str, "value": "RESTART"},
        "--output": {"type": str, "value": "OUTPUT"},
    }
    with self.assertRaises(SystemExit):
        self.run_test(command="train", mapping=conflicting_args)
def test_parser_train_allow_ref(self) -> None:
    """Check that ``train --allow-ref`` sets ``allow_ref`` and that it is off by default."""
    base_cmd = ["train", "INFILE"]

    # Flag given explicitly: the parsed namespace must report it as enabled.
    with_flag = parse_args([*base_cmd, "--allow-ref"])
    self.assertTrue(with_flag.allow_ref)

    # Flag omitted: the option must fall back to its disabled default.
    without_flag = parse_args(base_cmd)
    self.assertFalse(without_flag.allow_ref)

def test_parser_freeze(self) -> None:
"""Test freeze subparser."""
Expand Down
Loading