diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 774cf33d72..5881b3a0b3 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -345,6 +345,14 @@ def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + def get_nsel(self) -> int: """Returns the number of selected atoms in the cut-off radius.""" return sum(self.sel) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 8f32ca660c..c44dc07652 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -153,14 +153,21 @@ def _trace_and_compile( ) -> torch.nn.Module: """Trace ``forward_lower`` with ``make_fx`` and compile with ``torch.compile``. + Uses symbolic tracing (``tracing_mode="symbolic"``) so the resulting + FX graph captures shape-polymorphic operations. The graph is then + compiled with ``torch.compile(dynamic=True)`` and the inductor + backend, which automatically pads tensor shapes for efficient kernel + execution (``shape_padding=True``). + Parameters ---------- model : torch.nn.Module - The (uncompiled) model. Temporarily set to eval mode for tracing. + The (uncompiled) model. ext_coord, ext_atype, nlist, mapping, fparam, aparam - Sample tensors (already padded to the desired max_nall). + Sample tensors used to drive the symbolic trace. compile_opts : dict - Options forwarded to ``torch.compile`` (excluding ``dynamic``). + Options forwarded to ``torch.compile``. Keys ``dynamic`` and + ``backend`` are set internally and ignored if provided. Returns ------- @@ -197,84 +204,57 @@ def fn( aparam=aparam, ) - # Use default tracing_mode="real" (concrete shapes) for best - # runtime performance. If data-dependent intermediate shapes - # change at runtime, the caller catches the error and retraces. - traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam) + # Symbolic tracing captures shape-polymorphic ops, pairing with + # dynamic=True in torch.compile to handle varying nall without + # manual padding or recompilation. + traced_lower = make_fx( + fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(ext_coord, ext_atype, nlist, mapping, fparam, aparam) if not was_training: model.eval() - # The inductor backend does not propagate gradients through the - # make_fx-decomposed autograd.grad ops (second-order gradients for - # force training). Use "aot_eager" which correctly preserves the - # gradient chain while still benefiting from make_fx decomposition. - if "backend" not in compile_opts: - compile_opts["backend"] = "aot_eager" - compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts) + # Override backend and dynamic — the inductor backend with + # dynamic=True handles varying shapes automatically. + compile_opts.pop("dynamic", None) + compile_opts.pop("backend", None) + if "options" not in compile_opts: + compile_opts["options"] = {} + opts = compile_opts["options"] + opts.setdefault("max_autotune", False) + opts.setdefault("epilogue_fusion", False) + opts.setdefault("triton.cudagraphs", False) + opts.setdefault("shape_padding", True) + opts.setdefault("max_fusion_size", 8) + + compiled_lower = torch.compile( + traced_lower, + backend="inductor", + dynamic=True, + **compile_opts, + ) return compiled_lower class _CompiledModel(torch.nn.Module): - """Coord extension (eager) -> pad nall -> compiled forward_lower. + """Coord extension (eager) -> compiled forward_lower. - If a batch's ``nall`` exceeds the current ``max_nall``, the model is - automatically re-traced and recompiled with a larger pad size. + Coord extension and neighbor list construction involve data-dependent + control flow and are kept in eager mode. The compiled ``forward_lower`` + handles varying ``nall`` via ``dynamic=True`` — no manual padding or + recompilation needed. """ def __init__( self, original_model: torch.nn.Module, compiled_forward_lower: torch.nn.Module, - max_nall: int, - compile_opts: dict[str, Any], ) -> None: super().__init__() self.original_model = original_model self.compiled_forward_lower = compiled_forward_lower - self._max_nall = max_nall - self._compile_opts = compile_opts - - def _recompile( - self, - ext_coord: torch.Tensor, - ext_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor, - fparam: torch.Tensor | None, - aparam: torch.Tensor | None, - new_max_nall: int, - ) -> None: - """Re-trace and recompile for the given inputs. - - If *new_max_nall* differs from the current ``_max_nall``, the - inputs are padded (or already padded by the caller). - """ - # Pad if the caller provides unpadded tensors (nall growth case) - actual_nall = ext_coord.shape[1] - pad_n = new_max_nall - actual_nall - if pad_n > 0: - ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) - ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) - mapping = torch.nn.functional.pad(mapping, (0, pad_n)) - - ext_coord = ext_coord.detach() - - self.compiled_forward_lower = _trace_and_compile( - self.original_model, - ext_coord, - ext_atype, - nlist, - mapping, - fparam, - aparam, - self._compile_opts, - ) - self._max_nall = new_max_nall - log.info( - "Recompiled model with max_nall=%d.", - new_max_nall, - ) def forward( self, @@ -318,27 +298,6 @@ def forward( distinguish_types=False, ) ext_coord = ext_coord.reshape(nframes, -1, 3) - - # Grow max_nall if needed (retrace + recompile) - actual_nall = ext_coord.shape[1] - if actual_nall > self._max_nall: - new_max_nall = ((int(actual_nall * 1.2) + 7) // 8) * 8 - log.info( - "nall=%d exceeds max_nall=%d; recompiling with max_nall=%d.", - actual_nall, - self._max_nall, - new_max_nall, - ) - self._recompile( - ext_coord, ext_atype, nlist, mapping, fparam, aparam, new_max_nall - ) - - # Pad to max_nall so compiled graph sees a fixed shape - pad_n = self._max_nall - actual_nall - if pad_n > 0: - ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) - ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) - mapping = torch.nn.functional.pad(mapping, (0, pad_n)) ext_coord = ext_coord.detach().requires_grad_(True) result = self.compiled_forward_lower( @@ -354,18 +313,13 @@ def forward( out["atom_energy"] = result["atom_energy"] out["energy"] = result["energy"] if "extended_force" in result: - ext_force = result["extended_force"] # (nf, nall_padded, 3) - # mapping may be padded; only use actual_nall entries - map_actual = mapping[:, :actual_nall] # (nf, actual_nall) - ext_force_actual = ext_force[:, :actual_nall, :] # (nf, actual_nall, 3) + ext_force = result["extended_force"] # (nf, nall, 3) # scatter-sum extended forces onto local atoms - idx = map_actual.unsqueeze(-1).expand_as( - ext_force_actual - ) # (nf, actual_nall, 3) + idx = mapping.unsqueeze(-1).expand_as(ext_force) # (nf, nall, 3) force = torch.zeros( nframes, nloc, 3, dtype=ext_force.dtype, device=ext_force.device ) - force.scatter_add_(1, idx, ext_force_actual) + force.scatter_add_(1, idx, ext_force) out["force"] = force if "virial" in result: out["virial"] = result["virial"] @@ -642,21 +596,19 @@ def get_sample() -> list[dict[str, np.ndarray]]: def _compile_model(self, compile_opts: dict[str, Any]) -> None: """Replace ``self.model`` with a compiled version. - The model's ``forward`` uses ``torch.autograd.grad`` (for force - computation) with ``create_graph=True``, which creates a "double - backward" that ``torch.compile`` cannot handle. - - Solution: use ``make_fx`` to trace ``forward_lower``, decomposing - ``torch.autograd.grad`` into primitive ops. The coord extension + - nlist build (data-dependent control flow) are kept outside the - compiled region. - - To avoid the overhead of symbolic tracing and dynamic shapes, the - extended-atom dimension (nall) is padded to a fixed maximum - estimated from the training data. This allows concrete-shape - tracing and ``dynamic=False``. If a batch exceeds the current - max_nall at runtime, the model is automatically re-traced and - recompiled with a larger pad size. + The model's ``forward`` uses ``torch.autograd.grad`` (for forces) + with ``create_graph=True``, which creates a "double backward" that + ``torch.compile`` cannot handle. + + Solution: use ``make_fx`` with ``tracing_mode="symbolic"`` to trace + ``forward_lower``, decomposing ``torch.autograd.grad`` into + primitive ops with symbolic shapes. The traced graph is compiled + with ``torch.compile(dynamic=True, backend="inductor")`` so + varying ``nall`` across batches is handled automatically — no + manual padding or recompilation needed. + + Coord extension + nlist build (data-dependent control flow) are + kept outside the compiled region. """ from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, @@ -668,105 +620,53 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: model = self.model - # --- Estimate max_nall by sampling multiple batches --- - n_sample = 20 - max_nall = 0 - best_sample: ( - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None - ) = None - - for _ii in range(n_sample): - inp, _ = self.get_data(is_train=True) - coord = inp["coord"].detach() - atype = inp["atype"].detach() - box = inp.get("box") - if box is not None: - box = box.detach() - - nframes, nloc = atype.shape[:2] - coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3) - atype_np = atype.cpu().numpy() - box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None - - if box_np is not None: - coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3)) - else: - coord_norm = coord_np + # --- Get one sample batch to drive the symbolic trace --- + inp, _ = self.get_data(is_train=True) + coord = inp["coord"].detach() + atype = inp["atype"].detach() + box = inp.get("box") + if box is not None: + box = box.detach() - ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts( - coord_norm, atype_np, box_np, model.get_rcut() - ) - nlist_np = build_neighbor_list( - ext_coord_np, - ext_atype_np, - nloc, - model.get_rcut(), - model.get_sel(), - distinguish_types=False, - ) - ext_coord_np = ext_coord_np.reshape(nframes, -1, 3) - nall = ext_coord_np.shape[1] - if nall > max_nall: - max_nall = nall - best_sample = ( - ext_coord_np, - ext_atype_np, - mapping_np, - nlist_np, - nloc, - inp, - ) + nframes, nloc = atype.shape[:2] + coord_3d = coord.reshape(nframes, nloc, 3) + box_flat = box.reshape(nframes, 9) if box is not None else None - # Add 20 % margin and round up to a multiple of 8. - max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8 - log.info( - "Estimated max_nall=%d for compiled model (sampled %d batches).", - max_nall, - n_sample, - ) + if box_flat is not None: + coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3)) + else: + coord_norm = coord_3d - # --- Pad the largest sample to max_nall and trace --- - assert best_sample is not None - ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( - best_sample + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_norm, atype, box_flat, model.get_rcut() ) - nframes = ext_coord_np.shape[0] - actual_nall = ext_coord_np.shape[1] - pad_n = max_nall - actual_nall - - if pad_n > 0: - ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0))) - ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n))) - mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n))) - - ext_coord = torch.tensor( - ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + nlist_t = build_neighbor_list( + ext_coord, + ext_atype, + nloc, + model.get_rcut(), + model.get_sel(), + distinguish_types=False, ) - ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE) - nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE) - mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE) - fparam = sample_input.get("fparam") - aparam = sample_input.get("aparam") + ext_coord = ext_coord.reshape(nframes, -1, 3) - compile_opts.pop("dynamic", None) # always False for padded approach + fparam = inp.get("fparam") + aparam = inp.get("aparam") compiled_lower = _trace_and_compile( model, ext_coord, ext_atype, nlist_t, - mapping_t, + mapping, fparam, aparam, compile_opts, ) - self.wrapper.model = _CompiledModel( - model, compiled_lower, max_nall, compile_opts - ) + self.wrapper.model = _CompiledModel(model, compiled_lower) log.info( - "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).", - max_nall, + "Model compiled (tracing_mode=symbolic, dynamic=True, backend=inductor).", ) # ------------------------------------------------------------------ diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 1629ecb83a..adef443de9 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import math from typing import ( Any, ClassVar, @@ -182,6 +183,14 @@ def _torch_activation(x: torch.Tensor, name: str) -> torch.Tensor: return torch.sigmoid(x) elif name == "silu": return torch.nn.functional.silu(x) + elif name.startswith("silut") or name.startswith("custom_silu"): + threshold = float(name.split(":")[-1]) if ":" in name else 3.0 + sig_t = 1.0 / (1.0 + math.exp(-threshold)) + slope = sig_t + threshold * sig_t * (1.0 - sig_t) + const = threshold * sig_t + silu = x * torch.sigmoid(x) + tanh_branch = torch.tanh(slope * (x - threshold)) + const + return torch.where(x < threshold, silu, tanh_branch) elif name in ("none", "linear"): return x else: diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py index 7867fee874..af58d12790 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa2.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -10,6 +10,9 @@ RepformerArgs, RepinitArgs, ) +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers, +) from ...seed import ( GLOBAL_SEED, @@ -69,3 +72,36 @@ def test_self_consistency( for ii in [0, 1, 2, 3, 4]: np.testing.assert_equal(mm0[ii].shape, desired_shape[ii]) np.testing.assert_allclose(mm0[ii], mm1[ii]) + + +class TestDescrptBlockRepformersAccessors(unittest.TestCase): + def test_get_rcut_smth(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + ) + self.assertEqual(block.get_rcut_smth(), 5.0) + + def test_get_env_protection(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + env_protection=1.0, + ) + self.assertEqual(block.get_env_protection(), 1.0) + + def test_get_env_protection_default(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + ) + self.assertEqual(block.get_env_protection(), 0.0) diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 31351d4a9d..b46319e338 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, parameterized, ) @@ -29,6 +30,13 @@ from deepmd.pt.utils.utils import ( to_torch_tensor, ) +if INSTALLED_PT_EXPT: + import torch + + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE + from deepmd.pt_expt.utils.network import ( + _torch_activation, + ) if INSTALLED_TF: from deepmd.tf.common import get_activation_func as get_activation_fn_tf from deepmd.tf.env import ( @@ -98,3 +106,54 @@ def test_pd_consistent_with_ref(self): ActivationFn_pd(self.activation)(to_paddle_tensor(self.random_input)) ) np.testing.assert_allclose(self.ref, test, atol=1e-10) + + @unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed") + def test_pt_expt_consistent_with_ref(self) -> None: + if INSTALLED_PT_EXPT: + x = torch.tensor( + self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE + ) + test = _torch_activation(x, self.activation).detach().numpy() + np.testing.assert_allclose(self.ref, test, atol=1e-10) + + +@parameterized( + ( + "silut", # default threshold 3.0 + "silut:3.0", # explicit threshold 3.0 + "silut:10.0", # large threshold + "custom_silu:5.0", # alias + ), +) +class TestSilutVariantsConsistent(unittest.TestCase): + """Cross-backend consistency for silut with different thresholds.""" + + def setUp(self) -> None: + (self.activation,) = self.param + # Parse threshold to build input that covers both branches + threshold = ( + float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 + ) + rng = np.random.default_rng(GLOBAL_SEED) + # Values below threshold (silu branch) and above threshold (tanh branch) + below = rng.uniform(-threshold - 5, threshold - 0.1, size=(5, 10)) + above = rng.uniform(threshold + 0.1, threshold + 20, size=(5, 10)) + self.random_input = np.concatenate([below, above], axis=0) + self.ref = get_activation_fn_dp(self.activation)(self.random_input) + + @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") + def test_pt_consistent_with_ref(self) -> None: + if INSTALLED_PT: + test = torch_to_numpy( + ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input)) + ) + np.testing.assert_allclose(self.ref, test, atol=1e-10) + + @unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed") + def test_pt_expt_consistent_with_ref(self) -> None: + if INSTALLED_PT_EXPT: + x = torch.tensor( + self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE + ) + test = _torch_activation(x, self.activation).detach().numpy() + np.testing.assert_allclose(self.ref, test, atol=1e-10) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 3b3ab247bb..d2026d3e93 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -164,8 +164,8 @@ def test_training_loop_compiled(self) -> None: self._run_training(config) -class TestCompiledRecompile(unittest.TestCase): - """Test that _CompiledModel recompiles when nall exceeds max_nall.""" +class TestCompiledDynamicShapes(unittest.TestCase): + """Test that compiled model handles varying nall via dynamic shapes.""" @classmethod def setUpClass(cls) -> None: @@ -174,8 +174,8 @@ def setUpClass(cls) -> None: raise unittest.SkipTest(f"Example data not found: {data_dir}") cls.data_dir = data_dir - def test_nall_growth_triggers_recompile(self) -> None: - """Shrink max_nall to force a recompile, then verify training works.""" + def test_compiled_handles_varying_nall(self) -> None: + """Run multiple training steps — nall may vary across batches.""" from deepmd.pt_expt.train.training import ( _CompiledModel, ) @@ -185,7 +185,7 @@ def test_nall_growth_triggers_recompile(self) -> None: config = update_deepmd_input(config, warning=False) config = normalize(config) - tmpdir = tempfile.mkdtemp(prefix="pt_expt_recompile_") + tmpdir = tempfile.mkdtemp(prefix="pt_expt_dynamic_") try: old_cwd = os.getcwd() os.chdir(tmpdir) @@ -196,36 +196,19 @@ def test_nall_growth_triggers_recompile(self) -> None: compiled_model = trainer.wrapper.model self.assertIsInstance(compiled_model, _CompiledModel) - original_max_nall = compiled_model._max_nall - self.assertGreater(original_max_nall, 0) - - # Artificially shrink max_nall to 1 so the next batch - # will certainly exceed it and trigger recompilation. - compiled_model._max_nall = 1 - old_compiled_lower = compiled_model.compiled_forward_lower - - # Run one training step — should trigger recompile + # Run several training steps — each may have different nall trainer.wrapper.train() - trainer.optimizer.zero_grad(set_to_none=True) - inp, lab = trainer.get_data(is_train=True) - lr = trainer.scheduler.get_last_lr()[0] - _, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab) - loss.backward() - trainer.optimizer.step() - - # max_nall should have grown beyond 1 - new_max_nall = compiled_model._max_nall - self.assertGreater(new_max_nall, 1) - - # compiled_forward_lower should be a new object - self.assertIsNot( - compiled_model.compiled_forward_lower, - old_compiled_lower, - ) - - # Loss should be a finite scalar - self.assertFalse(torch.isnan(loss)) - self.assertFalse(torch.isinf(loss)) + for _ in range(3): + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer._optimizer_step() + + # Loss should be a finite scalar + self.assertFalse(torch.isnan(loss)) + self.assertFalse(torch.isinf(loss)) finally: os.chdir(old_cwd) finally: diff --git a/source/tests/pt_expt/utils/test_activation.py b/source/tests/pt_expt/utils/test_activation.py new file mode 100644 index 0000000000..23550d3315 --- /dev/null +++ b/source/tests/pt_expt/utils/test_activation.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.utils.network import ( + get_activation_fn, +) +from deepmd.pt_expt.utils.network import ( + _torch_activation, +) + + +class TestSilutActivation: + """Tests for silut activation in _torch_activation.""" + + def setup_method(self) -> None: + # x values spanning both branches: below threshold and above + self.x_np = np.array( + [-5.0, -1.0, 0.0, 1.0, 2.5, 3.0, 5.0, 10.0, 15.0, 20.0], + dtype=np.float64, + ) + self.x_torch = torch.tensor(self.x_np, dtype=torch.float64) + + def test_silut_with_threshold(self) -> None: + """silut:10.0 matches dpmodel numerically.""" + result = _torch_activation(self.x_torch, "silut:10.0") + dp_fn = get_activation_fn("silut:10.0") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_default_threshold(self) -> None: + """Silut without parameter uses default threshold 3.0.""" + result = _torch_activation(self.x_torch, "silut") + dp_fn = get_activation_fn("silut") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_custom_silu_alias(self) -> None: + """custom_silu:5.0 is an alias for silut:5.0.""" + result = _torch_activation(self.x_torch, "custom_silu:5.0") + dp_fn = get_activation_fn("custom_silu:5.0") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_gradient(self) -> None: + """Gradient flows through both branches of silut.""" + x = self.x_torch.clone().requires_grad_(True) + y = _torch_activation(x, "silut:3.0") + loss = y.sum() + loss.backward() + grad = x.grad + assert grad is not None + # gradient should be finite everywhere + assert torch.all(torch.isfinite(grad)) + # gradient should be non-zero for non-zero inputs + nonzero_mask = self.x_np != 0.0 + assert torch.all(grad[nonzero_mask] != 0.0) + + def test_silut_make_fx(self) -> None: + """make_fx can trace through silut activation.""" + + def fn(x: torch.Tensor) -> torch.Tensor: + return _torch_activation(x, "silut:10.0") + + traced = make_fx(fn)(self.x_torch) + result = traced(self.x_torch) + expected = _torch_activation(self.x_torch, "silut:10.0") + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12 + ) + + def test_silut_below_threshold_is_silu(self) -> None: + """Below threshold, silut equals silu exactly.""" + x_below = torch.tensor([-5.0, 0.0, 1.0, 5.0, 9.9], dtype=torch.float64) + result = _torch_activation(x_below, "silut:10.0") + silu = x_below * torch.sigmoid(x_below) + np.testing.assert_allclose( + result.detach().numpy(), silu.detach().numpy(), rtol=1e-14, atol=1e-14 + ) + + def test_silut_above_threshold_is_tanh_branch(self) -> None: + """Above threshold, silut equals tanh(slope*(x-T))+const.""" + import math + + threshold = 3.0 + sig_t = 1.0 / (1.0 + math.exp(-threshold)) + slope = sig_t + threshold * sig_t * (1.0 - sig_t) + const = threshold * sig_t + + x_above = torch.tensor([3.5, 5.0, 10.0, 20.0], dtype=torch.float64) + result = _torch_activation(x_above, "silut:3.0") + expected = torch.tanh(slope * (x_above - threshold)) + const + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-14, atol=1e-14 + ) + + def test_silut_export(self) -> None: + """torch.export.export can trace through silut activation.""" + + class SilutModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _torch_activation(x, "silut:10.0") + + mod = SilutModule() + exported = torch.export.export(mod, (self.x_torch,)) + result = exported.module()(self.x_torch) + expected = _torch_activation(self.x_torch, "silut:10.0") + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12 + )