diff --git a/modelforge/potential/__init__.py b/modelforge/potential/__init__.py
index ffe63235..7348f941 100644
--- a/modelforge/potential/__init__.py
+++ b/modelforge/potential/__init__.py
@@ -14,6 +14,7 @@
     SAKEParameters,
     SchNetParameters,
     TensorNetParameters,
+    DimeNetParameters,
 )
 from .processing import FromAtomToMoleculeReduction
 from .representation import (
@@ -31,6 +32,7 @@
 from .tensornet import TensorNetCore
 from .aimnet2 import AimNet2Core
 from .ani import ANI2xCore
+from .dimenet import DimeNetCore
 
 
 class _Implemented_NNP_Parameters(Enum):
@@ -41,6 +43,7 @@ class _Implemented_NNP_Parameters(Enum):
     PHYSNET_PARAMETERS = PhysNetParameters
     SAKE_PARAMETERS = SAKEParameters
     AIMNET2_PARAMETERS = AimNet2Parameters
+    DIMENET_PARAMETERS = DimeNetParameters
 
     @classmethod
     def get_neural_network_parameter_class(cls, neural_network_name: str):
@@ -63,6 +66,7 @@ class _Implemented_NNPs(Enum):
     PAINN = PaiNNCore
     SAKE = SAKECore
     AIMNET2 = AimNet2Core
+    DIMENET = DimeNetCore
 
     @classmethod
     def get_neural_network_class(cls, neural_network_name: str):
diff --git a/modelforge/potential/dimenet.py b/modelforge/potential/dimenet.py
new file mode 100644
index 00000000..fdc34b07
--- /dev/null
+++ b/modelforge/potential/dimenet.py
@@ -0,0 +1,370 @@
+"""
+This module contains the dimenet++ implementation based on
+"Directional Message Passing for Molecular Graphs" (ICLR 2020)
+and "Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules" (NeurIPS-W 2020)
+"""
+
+import torch
+import torch.nn as nn
+from loguru import logger as log
+
+from typing import Dict, List
+
+from modelforge.dataset.dataset import NNPInput
+from modelforge.potential.neighbors import PairlistData
+
+
+class EmbeddingBlock(nn.Module):
+    """
+    Embedding block for the DimeNet++ model.
+
+    Parameters
+    ----------
+    embedding_size : int
+        Embedding size.
+    number_of_radial_bessel_functions : int
+        Number of radial Bessel basis functions featurizing each distance.
+    activation_function : torch.nn.Module
+        Activation function used inside the dense layers.
+
+    Notes
+    -----
+    This module computes the embedding for atom pairs based on their atomic
+    numbers and radial basis functions. It uses trainable embeddings for atomic
+    numbers up to 94 (Plutonium) and applies two dense layers.
+    """
+
+    def __init__(
+        self,
+        embedding_size: int,
+        number_of_radial_bessel_functions: int,
+        activation_function: torch.nn.Module,
+    ):
+        super().__init__()
+        self.embedding_size = embedding_size
+        import math
+        from modelforge.potential.utils import Dense
+
+        num_embeddings = 95  # Elements up to atomic number 94 (Pu)
+
+        # Initialize embeddings with Uniform(-sqrt(3), sqrt(3))
+        self.embeddings = nn.Embedding(num_embeddings, embedding_size)
+        emb_init_range = math.sqrt(3)
+        nn.init.uniform_(self.embeddings.weight, -emb_init_range, emb_init_range)
+
+        # Dense layer for radial basis functions
+        self.dense_rbf = Dense(
+            number_of_radial_bessel_functions,
+            embedding_size,
+            bias=True,
+            activation_function=activation_function,
+        )
+
+        # Final dense layer
+        self.dense = Dense(
+            3 * embedding_size,
+            embedding_size,
+            bias=True,
+            activation_function=activation_function,
+        )
+
+    def forward(
+        self,
+        inputs: NNPInput,
+        pairlist_output: PairlistData,
+        f_ij: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the EmbeddingBlock.
+
+        Parameters
+        ----------
+        inputs : NNPInput
+            Input data including atomic numbers, positions, etc.
+        pairlist_output : PairlistData
+            Output from the pairlist module, containing pair indices and
+            distances.
+        f_ij : torch.Tensor
+            Radial basis features per pair, shape
+            (nr_of_pairs, number_of_radial_bessel_functions).
+
+        Returns
+        -------
+        x : torch.Tensor
+            Output tensor of shape (nr_of_pairs, emb_size).
+        """
+
+        # Transform radial basis functions
+        # rbf: (nr_of_pairs, num_radial) -> (nr_of_pairs, emb_size)
+        rbf = self.dense_rbf(f_ij)
+
+        # Gather atomic numbers for neighbor pairs
+        # Z_i and Z_j have shape (nr_of_pairs)
+        Z_i = inputs.atomic_numbers[pairlist_output.pair_indices[0]]
+        Z_j = inputs.atomic_numbers[pairlist_output.pair_indices[1]]
+
+        # Get embeddings for atomic numbers
+        # x_i and x_j have shape (E, emb_size)
+        x_i = self.embeddings(Z_i)
+        x_j = self.embeddings(Z_j)
+
+        # Concatenate embeddings and transformed rbf
+        # x has shape (E, 3 * emb_size)
+        x = torch.cat([x_i, x_j, rbf], dim=-1)
+
+        # Final transformation
+        # x: (E, 3 * emb_size) -> (E, emb_size)
+        x = self.dense(x)
+
+        return x
+
+
+class Envelope(nn.Module):
+    """
+    Polynomial envelope ensuring a smooth cutoff (DimeNet, Eq. 8).
+
+    Following the reference DimeNet implementation, the polynomial degree is
+    p = exponent + 1 and the returned value includes the 1/x factor:
+
+        u(x) = 1/x + a * x^(p-1) + b * x^p + c * x^(p+1)  for x < 1, else 0
+
+    where x = d_ij / radial_cutoff.
+    """
+
+    def __init__(self, exponent: int, radial_cutoff: float):
+        super().__init__()
+        self.exponent = exponent
+
+        # FIX: the reference implementation uses p = exponent + 1; using the
+        # raw exponent (and omitting the 1/x term below) produced values
+        # inconsistent with the paper and with test_envelope's expectations.
+        p = torch.tensor(exponent + 1, dtype=torch.int32)
+        self.register_buffer("p", p)
+        self.register_buffer("a", -((p + 1) * (p + 2)) / 2)
+        self.register_buffer("b", p * (p + 2))
+        self.register_buffer("c", -p * (p + 1) / 2)
+        self.register_buffer("cutoff", torch.tensor([1 / radial_cutoff]))
+
+    def forward(self, d_ij: torch.Tensor) -> torch.Tensor:
+        # Normalize distances by the cutoff: x = d_ij / radial_cutoff.
+        normalize_d_ij = self.cutoff * d_ij
+
+        # Compute powers efficiently via successive multiplication.
+        d_ij_power_p_minus1 = torch.pow(normalize_d_ij, self.p - 1)
+        d_ij_power_p = d_ij_power_p_minus1 * normalize_d_ij
+        d_ij_power_p_plus1 = d_ij_power_p * normalize_d_ij
+
+        # Envelope function; the 1/x term matches the reference code, which
+        # folds the division by the scaled distance into the envelope.
+        env_val = (
+            1.0 / normalize_d_ij
+            + self.a * d_ij_power_p_minus1
+            + self.b * d_ij_power_p
+            + self.c * d_ij_power_p_plus1
+        )
+
+        # Zero outside the cutoff (x >= 1). Replaces the previous duplicated
+        # relu/where computation and its debug assert (marked FIXME).
+        env_val1 = torch.where(
+            normalize_d_ij < 1.0, env_val, torch.zeros_like(env_val)
+        )
+
+        
return env_val1
+
+
+class BesselBasisLayer(nn.Module):
+    """
+    Bessel basis layer as used in DimeNet++.
+
+    Featurizes pairwise distances with sin(k * pi * d / c) / d radial basis
+    functions (k = 1..number_of_radial_bessel_functions, c = radial_cutoff).
+    """
+
+    def __init__(
+        self,
+        number_of_radial_bessel_functions: int,
+        radial_cutoff: float,
+        envelope_exponent: int = 6,
+    ):
+        super().__init__()
+        self.number_of_radial_bessel_functions = number_of_radial_bessel_functions
+        self.register_buffer(
+            "inv_cutoff", torch.tensor(1.0 / radial_cutoff, dtype=torch.float32)
+        )
+
+        # Initialize frequencies at canonical positions k * pi.
+        frequencies = torch.pi * torch.arange(
+            1, number_of_radial_bessel_functions + 1, dtype=torch.float32
+        )
+        # FIX: register_buffer takes (name, tensor) and returns None; the
+        # previous `self.frequencies = self.register_buffer(frequencies)`
+        # raised a TypeError at construction time.
+        self.register_buffer("frequencies", frequencies)
+
+        # Normalization prefactor sqrt(2 / c).
+        # FIX: torch.sqrt requires a tensor, not a Python float; also use
+        # register_buffer correctly (name, tensor).
+        pre_factor = torch.sqrt(torch.tensor(2.0 / radial_cutoff, dtype=torch.float32))
+        self.register_buffer("prefactor", pre_factor)
+
+    def forward(self, d_ij: torch.Tensor) -> torch.Tensor:
+        # d_ij: pairwise distances between atoms. Shape: (nr_pairs, 1)
+
+        # Compute Bessel basis sin(k * pi * d / c); broadcasting over the
+        # radial dimension yields shape (nr_pairs, num_radial).
+        basis = torch.sin(self.frequencies * d_ij * self.inv_cutoff)
+
+        # NOTE(review): dividing by d_ij produces inf/nan for zero distances;
+        # callers must guarantee d_ij > 0 — TODO confirm upstream pairlist.
+        return self.prefactor * basis / d_ij
+
+
+class DimeNetCore(torch.nn.Module):
+    """
+    Core DimeNet++ model (work in progress).
+
+    NOTE(review): the interaction and output blocks are not implemented yet:
+    `compute_properties` returns None and `forward` references
+    `self.output_layers`, which is never created in `__init__`, so a full
+    forward pass cannot succeed until those blocks are added.
+    """
+
+    def __init__(
+        self,
+        featurization: Dict[str, Dict[str, int]],
+        number_of_blocks: int,
+        dimension_of_bilinear_layer: int,
+        number_of_spherical_harmonics: int,
+        number_of_radial_bessel_functions: int,
+        maximum_interaction_radius: float,
+        envelope_exponent: int,
+        activation_function_parameter: Dict[str, str],
+        predicted_properties: List[str],
+        predicted_dim: List[int],
+        potential_seed: int = -1,
+    ) -> None:
+
+        from modelforge.utils.misc import seed_random_number
+
+        # Seed all RNGs for reproducible weight initialization when requested.
+        if potential_seed != -1:
+            seed_random_number(potential_seed)
+
+        super().__init__()
+
+        self.activation_function = activation_function_parameter["activation_function"]
+
+        log.debug("Initializing the DimeNet architecture.")
+
+        # NOTE(review): embedding_size is hard-coded to 32 here although the
+        # featurization config carries number_of_per_atom_features — TODO wire
+        # the configured value through instead of the literal.
+        self.representation_module = Representation(
+            number_of_radial_bessel_functions=number_of_radial_bessel_functions,
+            radial_cutoff=maximum_interaction_radius,
+            number_of_spherical_harmonics=number_of_spherical_harmonics,
+            envelope_exponent=envelope_exponent,
+            activation_function=self.activation_function,
+            embedding_size=32,
+        )
+
+    def compute_properties(
+        self, data: NNPInput, pairlist_output: PairlistData
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compute properties based on the input data and pair list.
+
+        Parameters
+        ----------
+        data : NNPInput
+            Input data including atomic numbers, positions, etc.
+        pairlist_output: PairlistData
+            Output from the pairlist module, containing pair indices and
+            distances.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            A dictionary containing the computed properties for each atom.
+        """
+
+        # Compute the atomic representation, which includes
+        # - radial/angular bessel function
+        # - embedding of pairwise distances
+
+        representation = self.representation_module(
+            data, pairlist_output
+        )  # includes 'm_ij', 'radial_bessel', 'angular_bessel'
+
+        # Apply interaction modules to update the atomic embedding
+
+        # NOTE(review): interaction blocks are not implemented yet, so no
+        # per-atom representation is produced and None is returned; `forward`
+        # will fail until this is completed.
+        return None
+
+    def forward(
+        self, data: NNPInput, pairlist_output: PairlistData
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass of the DimeNet model.
+
+        Parameters
+        ----------
+        data : NNPInput
+            Input data including atomic numbers, positions, and relevant fields.
+        pairlist_output : PairlistData
+            Pair indices and distances from the pairlist module.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            A dictionary of calculated properties from the forward pass.
+        """
+        # Compute properties using the core method
+        # NOTE(review): compute_properties currently returns None and
+        # self.output_layers is never defined in __init__, so this raises at
+        # runtime — the model is still a work in progress.
+        results = self.compute_properties(data, pairlist_output)
+        atomic_embedding = results["per_atom_scalar_representation"]
+
+        # Apply output layers to the atomic embedding
+        for output_name, output_layer in self.output_layers.items():
+            results[output_name] = output_layer(atomic_embedding).squeeze(-1)
+
+        return results
+
+
+class Representation(nn.Module):
+
+    def __init__(
+        self,
+        radial_cutoff: float,
+        number_of_radial_bessel_functions: int,
+        number_of_spherical_harmonics: int,
+        envelope_exponent: int,
+        activation_function: torch.nn.Module,
+        embedding_size: int,
+    ):
+        """
+        Initialize the representation module.
+        """
+        super().__init__()
+
+        # The representation part of DimeNet++ includes
+        # - radial bessel basis (featurization of distances)
+        # - angular bessel basis (featurization of angles)
+        # - embedding of pairwise distances
+        self.radial_bessel_function = BesselBasisLayer(
+            number_of_radial_bessel_functions=number_of_radial_bessel_functions,
+            radial_cutoff=radial_cutoff,
+            envelope_exponent=envelope_exponent,
+        )
+        from torch.nn import Identity
+
+        # NOTE(review): Identity is a placeholder for the angular (spherical)
+        # basis; number_of_spherical_harmonics is accepted but unused until
+        # the real angular featurization is implemented.
+        self.angular_bessel_function = Identity()
+        self.envelope = Envelope(envelope_exponent, radial_cutoff)
+
+        self.embedding = EmbeddingBlock(
+            embedding_size=embedding_size,
+            number_of_radial_bessel_functions=number_of_radial_bessel_functions,
+            activation_function=activation_function,
+        )
+
+        # to embed the messages
+
+    def forward(
+        self, data: NNPInput, pairlist_output: PairlistData
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass to generate the radial symmetry representation of pairwise
+        distances.
+
+        Parameters
+        ----------
+        data : NNPInput
+            Input data containing atomic numbers and positions.
+        pairlist_output : PairlistData
+            Output from the pairlist module, containing pair indices and distances.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            A dictionary containing radial/angular bessel basis and first message.
+        """
+
+        # Convert distances to radial bessel functions
+        radial_bessel = self.radial_bessel_function(pairlist_output.d_ij)
+        # Apply envelope.
+        # FIX: `self.r` was undefined (AttributeError); Envelope normalizes by
+        # the cutoff internally, so it takes the raw distances directly.
+        d_cutoff = self.envelope(pairlist_output.d_ij)  # Shape: (nr_pairs, 1)
+        radial_bessel = radial_bessel * d_cutoff
+
+        # convert distances to angular bessel functions
+        # FIX: nn.Identity.forward requires an input argument; pass the
+        # distances through as a placeholder until a real angular basis exists.
+        angular_bessel = self.angular_bessel_function(pairlist_output.d_ij)
+
+        # generate first message
+        m_ij = self.embedding(data, pairlist_output, radial_bessel)
+
+        return {
+            "m_ij": m_ij,
+            "radial_bessel": radial_bessel,
+            "angular_bessel": angular_bessel,
+        }
diff --git a/modelforge/potential/parameters.py b/modelforge/potential/parameters.py
index 0846033d..86aa39e7 100644
--- a/modelforge/potential/parameters.py
+++ b/modelforge/potential/parameters.py
@@ -238,6 +238,28 @@ class CoreParameter(CoreParameterBase):
     postprocessing_parameter: PostProcessingParameter
     potential_seed: int = -1
 
+class DimeNetParameters(ParametersBase):
+    class CoreParameter(CoreParameterBase):
+        number_of_radial_bessel_functions: int
+        maximum_interaction_radius: float
+        dimension_of_bilinear_layer: int
+        number_of_blocks: int
+        number_of_spherical_harmonics: int
+        envelope_exponent: int
+        activation_function_parameter: ActivationFunctionConfig
+        featurization: Featurization
+        predicted_properties: List[str]
+        predicted_dim: List[int]
+
+        converted_units = field_validator("maximum_interaction_radius", mode="before")(
+            _convert_str_or_unit_to_unit_length
+        )
+
+    potential_name: str = "DimeNet"
+    core_parameter: CoreParameter
+    postprocessing_parameter: PostProcessingParameter
+    potential_seed: int = -1
+
 
 class TensorNetParameters(ParametersBase):
     class CoreParameter(CoreParameterBase):
diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py
index 8383f09d..5d492eb6 100644
--- a/modelforge/potential/schnet.py
+++ b/modelforge/potential/schnet.py
@@ -2,7 +2,7 @@
 SchNet neural network potential for modeling quantum interactions.
 
"""
 
-from typing import Dict, List, Type
+from typing import Dict, List
 
 import torch
 import torch.nn as nn
diff --git a/modelforge/tests/data/potential_defaults/dimenet.toml b/modelforge/tests/data/potential_defaults/dimenet.toml
new file mode 100644
index 00000000..db135096
--- /dev/null
+++ b/modelforge/tests/data/potential_defaults/dimenet.toml
@@ -0,0 +1,30 @@
+# ------------------------------------------------------------ #
+[potential]
+potential_name = "DimeNet"
+# ------------------------------------------------------------ #
+[potential.core_parameter]
+number_of_radial_bessel_functions = 16
+maximum_interaction_radius = "5.0 angstrom"
+dimension_of_bilinear_layer = 16
+number_of_spherical_harmonics = 6
+envelope_exponent = 6
+number_of_blocks = 2
+predicted_properties = ["per_atom_energy", 'per_atom_charge']
+predicted_dim = [1, 1]
+[potential.core_parameter.activation_function_parameter]
+activation_function_name = "SiLU"
+[potential.core_parameter.featurization]
+properties_to_featurize = ['atomic_number']
+[potential.core_parameter.featurization.atomic_number]
+maximum_atomic_number = 101
+number_of_per_atom_features = 32
+# ------------------------------------------------------------ #
+[potential.postprocessing_parameter]
+properties_to_process = ['per_atom_energy']
+[potential.postprocessing_parameter.per_atom_energy]
+normalize = true
+from_atom_to_system_reduction = true
+keep_per_atom_property = true
+[potential.postprocessing_parameter.per_atom_charge]
+conserve = true
+conserve_strategy = "default"
diff --git a/modelforge/tests/test_dimenet.py b/modelforge/tests/test_dimenet.py
new file mode 100644
index 00000000..f3dcc6c6
--- /dev/null
+++ b/modelforge/tests/test_dimenet.py
@@ -0,0 +1,144 @@
+from typing import Optional
+import pytest
+
+
+@pytest.fixture(scope="session")
+def prep_temp_dir(tmp_path_factory):
+    fn = tmp_path_factory.mktemp("test_dimenet_temp")
+    return fn
+
+
+def setup_dimenet_model(potential_seed: Optional[int] = None):
+    from modelforge.tests.test_potentials import load_configs_into_pydantic_models
+    from modelforge.potential import NeuralNetworkPotentialFactory
+
+    # read default parameters
+    config = load_configs_into_pydantic_models("dimenet", "qm9")
+
+    model = NeuralNetworkPotentialFactory.generate_potential(
+        use="inference",
+        potential_parameter=config["potential"],
+        training_parameter=config["training"],
+        dataset_parameter=config["dataset"],
+        runtime_parameter=config["runtime"],
+        potential_seed=potential_seed,
+        use_training_mode_neighborlist=True,
+        jit=False,
+    )
+    return model
+
+
+def test_init():
+    """Test initialization of the Dimenet model."""
+    potential = setup_dimenet_model()
+    assert potential is not None, "Dimenet model should be initialized."
+
+
+def test_forward(single_batch_with_batchsize, prep_temp_dir):
+    import torch
+
+    potential = setup_dimenet_model()
+    print(potential)
+
+    batch = single_batch_with_batchsize(
+        batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir)
+    )
+
+    yhat = potential(batch.nnp_input.to_dtype(dtype=torch.float32))
+
+
+def test_envelope():
+    from modelforge.potential.dimenet import Envelope
+    import torch
+
+    # Create an instance of the Envelope class.
+    # FIX: Envelope requires radial_cutoff; with a cutoff of 1.0 the inputs
+    # below are already normalized distances.
+    envelope = Envelope(exponent=5, radial_cutoff=1.0)
+
+    # Sample input tensor
+    inputs = torch.tensor([0.5, 0.8, 1.0, 1.2], dtype=torch.float32)
+
+    # Forward pass
+    outputs = envelope(inputs)
+    assert outputs.shape == inputs.shape
+    assert torch.allclose(
+        outputs, torch.tensor([1.7109, 0.2539, 0.0000, 0.0000]), rtol=1e-3
+    )
+
+    # Script the model for optimization and deployment
+    scripted_envelope = torch.jit.script(envelope)
+
+    # Verify that the scripted model works
+    outputs_scripted = scripted_envelope(inputs)
+    print(outputs_scripted)
+    assert torch.allclose(outputs, outputs_scripted)
+
+    # ----------------------------------------- #
+    # test for correct output computation
+
+    exponent = (
+        5 + 1
+    )  # NOTE: Envelope internally uses p = exponent + 1, matching the
+    # reference DimeNet implementation.
+    # start with test for float:
+    d_ij = 0.5
+    # generate envelope function of d_ij value
+    u_05 = (
+        1
+        - (exponent + 1) * (exponent + 2) / 2 * d_ij**exponent
+        + exponent * (exponent + 2) * d_ij ** (exponent + 1)
+        - exponent * (exponent + 1) / 2 * d_ij ** (exponent + 2)
+    )
+    # The reference DimeNet implementation folds the division by the scaled
+    # distance into the envelope, hence the extra 1/d factor.
+    u_05 /= d_ij
+
+    u_05 = torch.tensor([u_05], dtype=torch.float32)
+
+    assert torch.allclose(u_05, outputs[0], rtol=1e-3)
+
+
+def test_bessel_basis():
+    import torch
+    from modelforge.potential.dimenet import BesselBasisLayer
+
+    # Create an instance of the BesselBasisLayer
+    num_radial = 6
+    radial_cutoff = 0.5
+    bessel_layer = BesselBasisLayer(
+        number_of_radial_bessel_functions=num_radial,
+        radial_cutoff=radial_cutoff,
+        envelope_exponent=5,
+    )
+
+    # Sample input tensor of distances
+    num_pairs = 100
+    d_ij = torch.linspace(0, radial_cutoff, steps=num_pairs).unsqueeze(
+        -1
+    )  # Shape: (100,1)
+
+    # Forward pass
+    outputs = bessel_layer(d_ij)  # Shape: (100, num_radial)
+    shape_tensor = torch.randn(
+        num_pairs, num_radial
+    )  # output from bessel_layer should have this size
+    assert shape_tensor.shape == outputs.shape  # Should print: torch.Size([100, 6])
+
+
+def test_representation():
+    from modelforge.potential.dimenet import Representation
+    from torch.nn import SiLU
+
+    # Create an instance of the RepresentationBlock
+    number_of_radial_bessel_functions = 5
+    radial_cutoff = 0.5
+    number_of_spherical_harmonics = 7
+    envelope_exponent = 5
+    activation_function = SiLU()
+    embedding_size = 32
+
+    rep = Representation(
+        number_of_radial_bessel_functions=number_of_radial_bessel_functions,
+        radial_cutoff=radial_cutoff,
+        number_of_spherical_harmonics=number_of_spherical_harmonics,
+        envelope_exponent=envelope_exponent,
+        activation_function=activation_function,
+        embedding_size=embedding_size,
+    )