Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ Then, read on for a brief overview of the usage of DeePMD-kit. You may start wit
dp
```

### `$ref` support in training input (secure by default)

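DeePMD-kit can load external JSON/YAML snippets referenced through `$ref` keys in the training input; the references are resolved during input validation (requires `dargs>=0.5.0`).
Because a `$ref` pulls in content from outside the input file, this feature is **disabled by default** for security reasons and must be enabled explicitly: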

- CLI: use `dp train ... --allow-ref`
- Python API (`deepmd.utils.argcheck.normalize`): pass the keyword argument `allow_ref=True`

## Code structure

The code is organized as follows:
Expand Down
5 changes: 5 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ def main_parser() -> argparse.ArgumentParser:
action="store_true",
help="(Supported backend: PyTorch) Force load from ckpt, other missing tensors will init from scratch",
)
parser_train.add_argument(
"--allow-ref",
action="store_true",
help="Allow loading external JSON/YAML snippets through `$ref`. Disabled by default for security.",
)

# * freeze script ******************************************************************
parser_frz = subparsers.add_parser(
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,16 @@ def train(
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
allow_ref: bool = False,
) -> None:
"""Train a model with Paddle backend.

Parameters
----------
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
"""
log.info("Configuration path: %s", input_file)
if LOCAL_RANK == 0:
SummaryPrinter()()
Expand Down Expand Up @@ -292,7 +301,7 @@ def train(

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
config = normalize(config, multi_task=multi_task, allow_ref=allow_ref)

# do neighbor stat
min_nbor_dist = None
Expand Down Expand Up @@ -600,6 +609,7 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
allow_ref=FLAGS.allow_ref,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,16 @@ def train(
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
allow_ref: bool = False,
) -> None:
"""Train a model with PyTorch backend.

Parameters
----------
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
"""
log.info("Configuration path: %s", input_file)
env.CUSTOM_OP_USE_JIT = True
if LOCAL_RANK == 0:
Expand Down Expand Up @@ -325,7 +334,7 @@ def train(

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
config = normalize(config, multi_task=multi_task, allow_ref=allow_ref)

# do neighbor stat
min_nbor_dist = None
Expand Down Expand Up @@ -578,6 +587,7 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
allow_ref=FLAGS.allow_ref,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand Down
27 changes: 24 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3821,10 +3821,31 @@ def gen_json_schema(multi_task: bool = False) -> str:
return json.dumps(generate_json_schema(arg))


def normalize(
    data: dict[str, Any],
    multi_task: bool = False,
    allow_ref: bool = False,
) -> dict[str, Any]:
    """Normalize and validate DeePMD input config.

    Parameters
    ----------
    data : dict[str, Any]
        Input training configuration.
    multi_task : bool, default=False
        Whether to use multi-task argument schema.
    allow_ref : bool, default=False
        Whether to allow loading external JSON/YAML snippets via ``$ref``.
        Disabled by default for security.

    Returns
    -------
    dict[str, Any]
        Normalized and validated configuration.
    """
    base = Argument("base", dict, gen_args(multi_task=multi_task))
    # The same ``allow_ref`` value is forwarded to both the normalization and
    # the strict check, so the two passes always agree on whether ``$ref``
    # entries are permitted. There must be no earlier pass without it.
    # NOTE(review): ``trim_pattern="_*"`` presumably drops underscore-prefixed
    # keys before the strict check -- confirm against the dargs documentation.
    data = base.normalize_value(data, trim_pattern="_*", allow_ref=allow_ref)
    base.check_value(data, strict=True, allow_ref=allow_ref)

    return data

Expand Down
2 changes: 2 additions & 0 deletions doc/train/training-advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ An explanation will be provided

**`--skip-neighbor-stat`** will skip calculating neighbor statistics if one is concerned about performance. Some features will be disabled.

**`--allow-ref`** enables loading external JSON/YAML snippets via `$ref` during input validation. This option is disabled by default for security.

To maximize the performance, one should follow [FAQ: How to control the parallelism of a job](../troubleshooting/howtoset_num_nodes.md) to control the number of threads.
See [Runtime environment variables](../env.md) for all runtime environment variables.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
'numpy>=1.21',
'scipy',
'pyyaml',
'dargs >= 0.4.7',
'dargs >= 0.5.0',
'typing_extensions>=4.0.0',
'importlib_metadata>=1.4; python_version < "3.8"',
'h5py',
Expand Down
Loading