Skip to content
275 changes: 88 additions & 187 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,21 @@
) -> torch.nn.Module:
"""Trace ``forward_lower`` with ``make_fx`` and compile with ``torch.compile``.

Uses symbolic tracing (``tracing_mode="symbolic"``) so the resulting
FX graph captures shape-polymorphic operations. The graph is then
compiled with ``torch.compile(dynamic=True)`` and the inductor
backend, which automatically pads tensor shapes for efficient kernel
execution (``shape_padding=True``).

Parameters
----------
model : torch.nn.Module
The (uncompiled) model. Temporarily set to eval mode for tracing.
The (uncompiled) model.
ext_coord, ext_atype, nlist, mapping, fparam, aparam
Sample tensors (already padded to the desired max_nall).
Sample tensors used to drive the symbolic trace.
compile_opts : dict
Options forwarded to ``torch.compile`` (excluding ``dynamic``).
Options forwarded to ``torch.compile``. Keys ``dynamic`` and
``backend`` are set internally and ignored if provided.

Returns
-------
Expand Down Expand Up @@ -197,84 +204,57 @@
aparam=aparam,
)

# Use default tracing_mode="real" (concrete shapes) for best
# runtime performance. If data-dependent intermediate shapes
# change at runtime, the caller catches the error and retraces.
traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)
# Symbolic tracing captures shape-polymorphic ops, pairing with
# dynamic=True in torch.compile to handle varying nall without
# manual padding or recompilation.
traced_lower = make_fx(
fn,
tracing_mode="symbolic",
_allow_non_fake_inputs=True,
)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)

if not was_training:
model.eval()

# The inductor backend does not propagate gradients through the
# make_fx-decomposed autograd.grad ops (second-order gradients for
# force training). Use "aot_eager" which correctly preserves the
# gradient chain while still benefiting from make_fx decomposition.
if "backend" not in compile_opts:
compile_opts["backend"] = "aot_eager"
compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts)
# Override backend and dynamic — the inductor backend with
# dynamic=True handles varying shapes automatically.
compile_opts.pop("dynamic", None)
compile_opts.pop("backend", None)
if "options" not in compile_opts:
compile_opts["options"] = {}
opts = compile_opts["options"]
opts.setdefault("max_autotune", False)
opts.setdefault("epilogue_fusion", False)
opts.setdefault("triton.cudagraphs", False)
opts.setdefault("shape_padding", True)
opts.setdefault("max_fusion_size", 8)

compiled_lower = torch.compile(
traced_lower,
backend="inductor",
dynamic=True,
**compile_opts,
)
return compiled_lower


class _CompiledModel(torch.nn.Module):
"""Coord extension (eager) -> pad nall -> compiled forward_lower.
"""Coord extension (eager) -> compiled forward_lower.

If a batch's ``nall`` exceeds the current ``max_nall``, the model is
automatically re-traced and recompiled with a larger pad size.
Coord extension and neighbor list construction involve data-dependent
control flow and are kept in eager mode. The compiled ``forward_lower``
handles varying ``nall`` via ``dynamic=True`` — no manual padding or
recompilation needed.
"""

def __init__(
self,
original_model: torch.nn.Module,
compiled_forward_lower: torch.nn.Module,
max_nall: int,
compile_opts: dict[str, Any],
) -> None:
super().__init__()
self.original_model = original_model
self.compiled_forward_lower = compiled_forward_lower
self._max_nall = max_nall
self._compile_opts = compile_opts

def _recompile(
self,
ext_coord: torch.Tensor,
ext_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
fparam: torch.Tensor | None,
aparam: torch.Tensor | None,
new_max_nall: int,
) -> None:
"""Re-trace and recompile for the given inputs.

If *new_max_nall* differs from the current ``_max_nall``, the
inputs are padded (or already padded by the caller).
"""
# Pad if the caller provides unpadded tensors (nall growth case)
actual_nall = ext_coord.shape[1]
pad_n = new_max_nall - actual_nall
if pad_n > 0:
ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n))
ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n))
mapping = torch.nn.functional.pad(mapping, (0, pad_n))

ext_coord = ext_coord.detach()

self.compiled_forward_lower = _trace_and_compile(
self.original_model,
ext_coord,
ext_atype,
nlist,
mapping,
fparam,
aparam,
self._compile_opts,
)
self._max_nall = new_max_nall
log.info(
"Recompiled model with max_nall=%d.",
new_max_nall,
)

def forward(
self,
Expand Down Expand Up @@ -318,54 +298,29 @@
distinguish_types=False,
)
ext_coord = ext_coord.reshape(nframes, -1, 3)

# Grow max_nall if needed (retrace + recompile)
actual_nall = ext_coord.shape[1]
if actual_nall > self._max_nall:
new_max_nall = ((int(actual_nall * 1.2) + 7) // 8) * 8
log.info(
"nall=%d exceeds max_nall=%d; recompiling with max_nall=%d.",
actual_nall,
self._max_nall,
new_max_nall,
)
self._recompile(
ext_coord, ext_atype, nlist, mapping, fparam, aparam, new_max_nall
)

# Pad to max_nall so compiled graph sees a fixed shape
pad_n = self._max_nall - actual_nall
if pad_n > 0:
ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n))
ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n))
mapping = torch.nn.functional.pad(mapping, (0, pad_n))
ext_coord = ext_coord.detach().requires_grad_(True)

result = self.compiled_forward_lower(
ext_coord, ext_atype, nlist, mapping, fparam, aparam
)

# Translate forward_lower keys -> forward keys.

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable actual_nall is not used.
# ``extended_force`` lives on all extended atoms (nf, nall, 3).
# Ghost-atom forces must be scatter-summed back to local atoms
# via ``mapping`` — the same operation ``communicate_extended_output``
# performs in the uncompiled path.
actual_nall = ext_coord.shape[1]
out: dict[str, torch.Tensor] = {}
out["atom_energy"] = result["atom_energy"]
out["energy"] = result["energy"]
if "extended_force" in result:
ext_force = result["extended_force"] # (nf, nall_padded, 3)
# mapping may be padded; only use actual_nall entries
map_actual = mapping[:, :actual_nall] # (nf, actual_nall)
ext_force_actual = ext_force[:, :actual_nall, :] # (nf, actual_nall, 3)
ext_force = result["extended_force"] # (nf, nall, 3)
# scatter-sum extended forces onto local atoms
idx = map_actual.unsqueeze(-1).expand_as(
ext_force_actual
) # (nf, actual_nall, 3)
idx = mapping.unsqueeze(-1).expand_as(ext_force) # (nf, nall, 3)
force = torch.zeros(
nframes, nloc, 3, dtype=ext_force.dtype, device=ext_force.device
)
force.scatter_add_(1, idx, ext_force_actual)
force.scatter_add_(1, idx, ext_force)
out["force"] = force
if "virial" in result:
out["virial"] = result["virial"]
Expand Down Expand Up @@ -642,21 +597,19 @@
def _compile_model(self, compile_opts: dict[str, Any]) -> None:
"""Replace ``self.model`` with a compiled version.

The model's ``forward`` uses ``torch.autograd.grad`` (for force
computation) with ``create_graph=True``, which creates a "double
backward" that ``torch.compile`` cannot handle.

Solution: use ``make_fx`` to trace ``forward_lower``, decomposing
``torch.autograd.grad`` into primitive ops. The coord extension +
nlist build (data-dependent control flow) are kept outside the
compiled region.

To avoid the overhead of symbolic tracing and dynamic shapes, the
extended-atom dimension (nall) is padded to a fixed maximum
estimated from the training data. This allows concrete-shape
tracing and ``dynamic=False``. If a batch exceeds the current
max_nall at runtime, the model is automatically re-traced and
recompiled with a larger pad size.
The model's ``forward`` uses ``torch.autograd.grad`` (for forces)
with ``create_graph=True``, which creates a "double backward" that
``torch.compile`` cannot handle.

Solution: use ``make_fx`` with ``tracing_mode="symbolic"`` to trace
``forward_lower``, decomposing ``torch.autograd.grad`` into
primitive ops with symbolic shapes. The traced graph is compiled
with ``torch.compile(dynamic=True, backend="inductor")`` so
varying ``nall`` across batches is handled automatically — no
manual padding or recompilation needed.

Coord extension + nlist build (data-dependent control flow) are
kept outside the compiled region.
"""
from deepmd.dpmodel.utils.nlist import (
build_neighbor_list,
Expand All @@ -668,105 +621,53 @@

model = self.model

# --- Estimate max_nall by sampling multiple batches ---
n_sample = 20
max_nall = 0
best_sample: (
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None
) = None

for _ii in range(n_sample):
inp, _ = self.get_data(is_train=True)
coord = inp["coord"].detach()
atype = inp["atype"].detach()
box = inp.get("box")
if box is not None:
box = box.detach()

nframes, nloc = atype.shape[:2]
coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3)
atype_np = atype.cpu().numpy()
box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None

if box_np is not None:
coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3))
else:
coord_norm = coord_np
# --- Get one sample batch to drive the symbolic trace ---
inp, _ = self.get_data(is_train=True)
coord = inp["coord"].detach()
atype = inp["atype"].detach()
box = inp.get("box")
if box is not None:
box = box.detach()

ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts(
coord_norm, atype_np, box_np, model.get_rcut()
)
nlist_np = build_neighbor_list(
ext_coord_np,
ext_atype_np,
nloc,
model.get_rcut(),
model.get_sel(),
distinguish_types=False,
)
ext_coord_np = ext_coord_np.reshape(nframes, -1, 3)
nall = ext_coord_np.shape[1]
if nall > max_nall:
max_nall = nall
best_sample = (
ext_coord_np,
ext_atype_np,
mapping_np,
nlist_np,
nloc,
inp,
)
nframes, nloc = atype.shape[:2]
coord_3d = coord.reshape(nframes, nloc, 3)
box_flat = box.reshape(nframes, 9) if box is not None else None

# Add 20 % margin and round up to a multiple of 8.
max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8
log.info(
"Estimated max_nall=%d for compiled model (sampled %d batches).",
max_nall,
n_sample,
)
if box_flat is not None:
coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3))
else:
coord_norm = coord_3d

# --- Pad the largest sample to max_nall and trace ---
assert best_sample is not None
ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = (
best_sample
ext_coord, ext_atype, mapping = extend_coord_with_ghosts(
coord_norm, atype, box_flat, model.get_rcut()
)
nframes = ext_coord_np.shape[0]
actual_nall = ext_coord_np.shape[1]
pad_n = max_nall - actual_nall

if pad_n > 0:
ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0)))
ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n)))
mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n)))

ext_coord = torch.tensor(
ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
nlist_t = build_neighbor_list(
ext_coord,
ext_atype,
nloc,
model.get_rcut(),
model.get_sel(),
distinguish_types=False,
)
ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE)
nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE)
mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE)
fparam = sample_input.get("fparam")
aparam = sample_input.get("aparam")
ext_coord = ext_coord.reshape(nframes, -1, 3)

compile_opts.pop("dynamic", None) # always False for padded approach
fparam = inp.get("fparam")
aparam = inp.get("aparam")

compiled_lower = _trace_and_compile(
model,
ext_coord,
ext_atype,
nlist_t,
mapping_t,
mapping,
fparam,
aparam,
compile_opts,
)

self.wrapper.model = _CompiledModel(
model, compiled_lower, max_nall, compile_opts
)
self.wrapper.model = _CompiledModel(model, compiled_lower)
log.info(
"Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).",
max_nall,
"Model compiled (tracing_mode=symbolic, dynamic=True, backend=inductor).",
)

# ------------------------------------------------------------------
Expand Down
Loading
Loading