Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions configs/drugs-sfm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# task name for logging
task_name: sfm-drugs/base-so3

# unique seed for experiment reproducibility
seed: 42

# data config
datamodule: BaseDataModule
datamodule_args:
dataset: EuclideanDataset
dataset_args:
dataset_name: geom
use_ogb_feat: true

train_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_train_0.9.npy
val_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.8.npy
test_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.1.npy

# dataloader args
dataloader_args:
batch_size: 32
num_workers: 4
pin_memory: false
persistent_workers: true

# model config
model: BaseSFM
model_args:
# network args
network_type: TorchMDDynamicsWithScore
hidden_channels: 160
num_layers: 20
num_rbf: 64
rbf_type: expnorm
trainable_rbf: true
activation: silu
neighbor_embedding: true
cutoff_lower: 0.0
cutoff_upper: 10.0
max_z: 100
node_attr_dim: 10
edge_attr_dim: 1
attn_activation: silu
num_heads: 8
distance_influence: both
reduce_op: sum
qk_norm: true
so3_equivariant: true
output_layer_norm: true
clip_during_norm: true

# flow matching specific
normalize_node_invariants: false
sigma: 0.1
prior_type: harmonic
separate_encoders: false

# optimizer args
optimizer_type: AdamW
lr: 8.e-4
weight_decay: 1.e-8

# lr scheduler args
lr_scheduler_type: CosineAnnealingWarmupRestarts
first_cycle_steps: 500_000
cycle_mult: 1.0
max_lr: 8.e-4
min_lr: 1.e-8
warmup_steps: 0
gamma: 0.05
last_epoch: -1
lr_scheduler_monitor: val/loss
lr_scheduler_interval: step
lr_scheduler_frequency: 1

# callbacks
callbacks:
- callback: ModelCheckpoint
callback_args:
dirpath: './checkpoint'
monitor: val/loss
mode: min
save_last: true
every_n_epochs: 1
save_top_k: 3

- callback: LearningRateMonitor
callback_args:
log_momentum: false
logging_interval: null


# logger
logger: WandbLogger
logger_args:
project: Energy-Aware-MCG
entity: doms-lab

# trainer
trainer: Trainer
trainer_args:
max_epochs: 200
devices: 8
limit_train_batches: 5000
strategy: ddp_find_unused_parameters_true
accelerator: auto
1 change: 1 addition & 0 deletions etflow/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .model import BaseFlow
from .sfm import BaseSFM

__all__ = [
"BaseFlow",
Expand Down
31 changes: 29 additions & 2 deletions etflow/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
rmsd_align,
unsqueeze_like,
)
from etflow.networks.torchmd_net import TensorNetDynamics, TorchMDDynamics
from etflow.networks.torchmd_net import (
TensorNetDynamics,
TorchMDDynamics,
TorchMDDynamicsWithScore,
)


class BaseFlow(BaseModel):
Expand Down Expand Up @@ -45,7 +49,7 @@ def __init__(
distance_influence: str = "both",
reduce_op: str = "sum",
qk_norm: bool = False,
output_layer_norm: bool = False,
output_layer_norm: bool = True,
clip_during_norm: bool = False,
max_num_neighbors: int = 32,
so3_equivariant: bool = False,
Expand Down Expand Up @@ -132,6 +136,29 @@ def __init__(
clip_during_norm=clip_during_norm,
so3_equivariant=so3_equivariant,
)
elif network_type == "TorchMDDynamicsWithScore":
self.network = TorchMDDynamicsWithScore(
hidden_channels=hidden_channels,
num_layers=num_layers,
num_rbf=num_rbf,
rbf_type=rbf_type,
trainable_rbf=trainable_rbf,
activation=activation,
neighbor_embedding=neighbor_embedding,
cutoff_lower=cutoff_lower,
cutoff_upper=cutoff_upper,
max_z=max_z,
node_attr_dim=node_attr_dim,
edge_attr_dim=edge_attr_dim,
attn_activation=attn_activation,
num_heads=num_heads,
distance_influence=distance_influence,
reduce_op=reduce_op,
qk_norm=qk_norm,
output_layer_norm=output_layer_norm,
clip_during_norm=clip_during_norm,
so3_equivariant=so3_equivariant,
)
elif network_type == "TensorNetDynamics":
self.network = TensorNetDynamics(
hidden_channels=hidden_channels,
Expand Down
Loading