diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index a896c36a22..164373df10 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -12,7 +12,9 @@ "source=${localEnv:HOME}/.cursor,target=/home/ubuntu/.cursor,type=bind,consistency=cached", "source=${localEnv:HOME}/.gnupg,target=/home/ubuntu/.gnupg,type=bind,consistency=cached", "source=${localEnv:HOME}/.netrc,target=/home/ubuntu/.netrc,type=bind,consistency=cached", - "source=${localEnv:HOME}/.ssh,target=/home/ubuntu/.ssh,readonly,type=bind,consistency=cached" + "source=${localEnv:HOME}/.ssh,target=/home/ubuntu/.ssh,readonly,type=bind,consistency=cached", + "source=${localEnv:HOME}/Projects/minifold,target=/workspaces/minifold,type=bind,consistency=cached" + ], "postCreateCommand": ".devcontainer/postCreateCommand.sh", "initializeCommand": ".devcontainer/initializeCommand.sh", diff --git a/.devcontainer/requirements.txt b/.devcontainer/requirements.txt index e88a8b58cf..b73a0871a9 100644 --- a/.devcontainer/requirements.txt +++ b/.devcontainer/requirements.txt @@ -4,6 +4,7 @@ deepspeed hydra-core lm-eval megatron-fsdp +ml_collections peft pytest torch diff --git a/bionemo-recipes/recipes/esm2_minifold_te/.dockerignore b/bionemo-recipes/recipes/esm2_minifold_te/.dockerignore new file mode 100644 index 0000000000..4e9043c001 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/.dockerignore @@ -0,0 +1,38 @@ +# Docker +Dockerfile +Dockerfile.* +.dockerignore + +# Docs +README.md + +# Python caches +__pycache__ +.pytest_cache +.ruff_cache +.venv/ + +# Linting +.ruff.toml + +# Profiling & debugging artifacts +memory_snapshots/ +nsight_profiling/ +*.nsys-rep +*.sqlite +logs/ +wandb/ + +# Downloaded CIF files (reproduced by data/prepare_*.py scripts) +data/cif_files/ +data/eval_cif_files/ + +# Hydra / training outputs +outputs/ +checkpoints/ + +# Checkpoint export +checkpoint_export/ + +# Temp / scratch +j/ diff --git 
a/bionemo-recipes/recipes/esm2_minifold_te/.ruff.toml b/bionemo-recipes/recipes/esm2_minifold_te/.ruff.toml new file mode 100644 index 0000000000..9c7dce4ff9 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/.ruff.toml @@ -0,0 +1,26 @@ +extend = "../../../.ruff.toml" + +[lint] +# Extend ignores for code ported from MiniFold +ignore = [ + "C901", # Complexity + "D100", # Missing module docstring + "D101", # Missing class docstring + "D102", # Missing method docstring + "D103", # Missing function docstring + "D105", # Missing magic method docstring + "D107", # Missing __init__ docstring + "D205", # 1 blank line required between summary and description + "D415", # First line should end with period + "D417", # Missing argument descriptions + "E501", # Line too long + "E721", # Type comparison + "E731", # Lambda assignment + "E741", # Ambiguous variable name + "F841", # Unused variable + "N806", # Uppercase variable in function + "N812", # Lowercase imported as non-lowercase + "PLW2901", # Loop variable overwritten + "RUF005", # Collection literal concatenation + "RUF010", # Explicit f-string conversion +] diff --git a/bionemo-recipes/recipes/esm2_minifold_te/Dockerfile b/bionemo-recipes/recipes/esm2_minifold_te/Dockerfile new file mode 100644 index 0000000000..3b35d25fd0 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/Dockerfile @@ -0,0 +1,9 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:26.02-py3 + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +WORKDIR /workspace/bionemo +COPY . . 
diff --git a/bionemo-recipes/recipes/esm2_minifold_te/checkpoint.py b/bionemo-recipes/recipes/esm2_minifold_te/checkpoint.py new file mode 100644 index 0000000000..8a551d6499 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/checkpoint.py @@ -0,0 +1,608 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import os +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import NamedTuple + +import torch +import transformers +from safetensors.torch import save_file +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_state_dict, + set_state_dict, +) +from torch.distributed.checkpoint.state_dict_loader import load as dcp_load +from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save +from torch.distributed.checkpoint.state_dict_saver import save as dcp_save +from torch.distributed.checkpoint.stateful import Stateful +from torchdata.stateful_dataloader import StatefulDataLoader + +from distributed_config import DistributedConfig + + +logger = logging.getLogger(__name__) +_ckpt_futures: dict = {} + + +class CheckpointOutput(NamedTuple): + """Output of checkpoint loading.""" + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: 
torch.optim.lr_scheduler.LRScheduler + dataloader: StatefulDataLoader | None + step: int + epoch: int + + +# ============================================================================ +# Helper functions +# ============================================================================ + + +def get_latest_checkpoint(ckpt_path: str | os.PathLike) -> tuple[Path | None, int]: + """Get the latest checkpoint path and step number. + + Returns: + Tuple of (checkpoint path, step number). + If no checkpoint files are found, returns (None, 0). + """ + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + return None, 0 + + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + + if not checkpoints: + return None, 0 + + latest = max(checkpoints, key=lambda x: int(Path(x).stem.split("_")[1])) + step = int(Path(latest).stem.split("_")[1]) + return latest, step + + +def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool: + """Determine if a checkpoint should be saved.""" + return save_every_n_steps > 0 and step % save_every_n_steps == 0 and step > 0 + + +def prune_checkpoints(ckpt_path: str | os.PathLike, max_checkpoints: int) -> None: + """Prune checkpoints to keep only the latest `max_checkpoints` checkpoints.""" + ckpt_path = Path(ckpt_path) + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + checkpoints.sort(key=lambda x: int(Path(x).stem.split("_")[1])) + if len(checkpoints) > max_checkpoints: + for checkpoint in checkpoints[:-max_checkpoints]: + logger.info(f"Pruning checkpoint {checkpoint}") + if checkpoint.is_dir(): + shutil.rmtree(checkpoint) + else: + os.remove(checkpoint) + + +# ============================================================================ +# DDP Checkpointing +# ============================================================================ + + +def load_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + 
ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, +) -> CheckpointOutput: + """Load DDP checkpoint.""" + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + + if not checkpoint_path: + logger.info("No DDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + checkpoint = torch.load( + checkpoint_path / "checkpoint.pt", + map_location=f"cuda:{dist_config.local_rank}", + weights_only=True, + ) + + model.load_state_dict(checkpoint["model"], strict=False) + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + step = checkpoint["step"] + epoch = checkpoint["epoch"] + + if dist_config.is_main_process(): + logger.info(f"Loaded DDP checkpoint from step {step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch) + + +def save_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + max_checkpoints: int | None = None, +) -> None: + """Saves the Dataloader state and the DDP checkpoint.""" + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + # Dataloader checkpointing needs to happen on all ranks, while DDP model checkpointing only needs to happen on the + # main process. 
+ save_dataloader(dataloader, checkpoint_path, dist_config) + + if not dist_config.is_main_process(): + return + + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "epoch": epoch, + }, + checkpoint_path / "checkpoint.pt", + ) + + logger.info(f"Saved DDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_ddp( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for DDP - only on main process.""" + if not dist_config.is_main_process(): + return + + # Unwrap model if wrapped + underlying_model: transformers.PreTrainedModel = model.module if hasattr(model, "module") else model # type: ignore + + os.makedirs(save_directory, exist_ok=True) + underlying_model.save_pretrained(save_directory, state_dict=underlying_model.state_dict(), safe_serialization=True) + logger.info(f"Saved final DDP model to {save_directory}") + + +# ============================================================================ +# mFSDP Checkpointing +# ============================================================================ + + +def load_checkpoint_mfsdp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, +) -> CheckpointOutput: + """Load mFSDP distributed checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The directory containing checkpoints. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + + Returns: + Tuple of (model, optimizer, scheduler, step). 
+ """ + checkpoint_path, step = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No mFSDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + ckpt_state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "metadata": { + "step": step, # Initialize with current step from filename + "epoch": 0, # Initialize with default epoch + }, + } + torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=checkpoint_path) + + model.load_state_dict(ckpt_state_dict["model"], strict=False) + optimizer.load_state_dict(ckpt_state_dict["optimizer"]) + scheduler.load_state_dict(ckpt_state_dict["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + + step = ckpt_state_dict["metadata"]["step"] + epoch = ckpt_state_dict["metadata"]["epoch"] + + # Ensure all ranks have completed loading before proceeding + torch.distributed.barrier() + + logger.info(f"Loaded mFSDP checkpoint from step {step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch) + + +def save_checkpoint_mfsdp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + epoch: int = 0, + max_checkpoints: int | None = None, +) -> None: + """Save mFSDP distributed checkpoint. + + Args: + model: The model to save. + optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + epoch: The epoch number to save the checkpoint. 
+ max_checkpoints: The maximum number of checkpoints to keep. + """ + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + # Save dataloader state, if provided. + save_dataloader(dataloader, checkpoint_path, dist_config) + + # Save model, optimizer, scheduler state, and metadata + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "metadata": { + "step": step, + "epoch": epoch, + }, + } + + torch.distributed.checkpoint.save(state_dict, checkpoint_id=checkpoint_path) + + if dist_config.is_main_process(): + logger.info(f"Saved mFSDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_mfsdp( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for mFSDP - requires parameter gathering on all ranks.""" + from megatron_fsdp.uneven_dtensor import gather_uneven_dtensor_to_full_tensor + + if dist_config.is_main_process(): + logger.info("Starting mFSDP parameter gathering...") + + # Parameter gathering must happen on ALL processes + unsharded_state_dict = { + # Gather all parameters to CPU, and remove the "module." prefix from the Megatron-FSDP class wrapper. 
+ k.removeprefix("module."): gather_uneven_dtensor_to_full_tensor( + v, target_device=torch.device("cpu") + ).to_local() + if isinstance(v, torch.distributed.tensor.DTensor) + else v + for k, v in model.state_dict().items() + } + + # Only main process saves the model + if not dist_config.is_main_process(): + return + + os.makedirs(save_directory, exist_ok=True) + model.module.save_pretrained(save_directory, state_dict=unsharded_state_dict, safe_serialization=True) + logger.info(f"Saved final mFSDP model to {save_directory}") + + +# ============================================================================ +# FSDP2 Checkpointing +# ============================================================================ + + +@dataclass +class AppState(Stateful): + """AppState for FSDP2 checkpoint. + + Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html + """ + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: torch.optim.lr_scheduler.LRScheduler + step: int = 0 + epoch: int = 0 + state_dict_options: StateDictOptions = field( + default_factory=lambda: StateDictOptions( + full_state_dict=False, + cpu_offload=True, + strict=False, + ) + ) + + def state_dict(self): + """Get the state dict for the model, optimizer, scheduler, and step.""" + model_state_dict, optimizer_state_dict = get_state_dict( + self.model, self.optimizer, options=self.state_dict_options + ) + return { + "model": model_state_dict, + "optim": optimizer_state_dict, + "scheduler": self.scheduler.state_dict(), + "step": self.step, + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict: dict): + """Load the state dict for the model, optimizer, scheduler, and step.""" + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + options=self.state_dict_options, + ) + self.scheduler.load_state_dict(state_dict["scheduler"]) + self.step = state_dict["step"] + self.epoch = 
state_dict["epoch"] + + +def load_checkpoint_fsdp2( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + process_group: torch.distributed.ProcessGroup | None = None, +) -> CheckpointOutput: + """Load FSDP2 checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The directory containing checkpoints. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + process_group: The process group to use for checkpointing. + """ + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No FSDP2 checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + app_state = AppState( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + state_dict = {"app": app_state} + dcp_load(state_dict, checkpoint_id=checkpoint_path, process_group=process_group) + + if dataloader is not None: + load_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + + logger.info(f"Loaded distributed FSDP2 checkpoint from step {app_state.step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, app_state.step + 1, app_state.epoch) + + +def save_checkpoint_fsdp2( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + process_group: torch.distributed.ProcessGroup | None = None, + max_checkpoints: int | None = None, + async_save: bool = False, +) -> None: + """Save FSDP2 checkpoint. 
+ + Args: + model: The model to save. + optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + epoch: The epoch number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + process_group: The process group to use for checkpointing. + max_checkpoints: The maximum number of checkpoints to keep. + async_save: Whether to save the checkpoint asynchronously. + """ + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + if dataloader is not None: + save_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + logger.info(f"Saved FSDP2 dataloader to {ckpt_path}") + + # If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time. + if async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: + _ckpt_futures["fsdp2"].result() + + # Clear GPU cache before checkpointing to free up fragmented memory. 
+ gc.collect() + torch.cuda.empty_cache() + torch.distributed.barrier(group=process_group) + + state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)} + ckpt_save_func = dcp_async_save if async_save else dcp_save + _ckpt_futures["fsdp2"] = ckpt_save_func(state_dict, checkpoint_id=checkpoint_path, process_group=process_group) + + if dist_config.is_main_process(): + logger.info(f"Saved distributed FSDP2 checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_fsdp2( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for FSDP2 - gather on all ranks, save on main.""" + # ALL ranks must participate in gathering + model_state_dict = get_model_state_dict( + model=model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=True, + ), + ) + + # Only main process saves + if not dist_config.is_main_process(): + return + + os.makedirs(save_directory, exist_ok=True) + + # Save just the weights using safetensors + save_file(model_state_dict, os.path.join(save_directory, "model.safetensors")) + + # Save the config + underlying_model = model.module if hasattr(model, "module") else model + if hasattr(underlying_model, "config"): + underlying_model.config.save_pretrained(save_directory) + + logger.info(f"Saved final FSDP2 model to {save_directory} (weights + config only)") + + +# ============================================================================ +# Dataloader Checkpointing +# ============================================================================ + + +def save_dataloader( + dataloader: StatefulDataLoader | None, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, +): + """Save the dataloader state to a file. 
+ + For resuming training with long epochs, we save the dataloader state as part of the checkpoint to allow for resuming + from the exact same step. Here we save the dataloader state based on global rank. Note, the total number of ranks + and dataloader num_workers should match for resuming training. + + Args: + dataloader: The dataloader to save the state of. + ckpt_path: The path to save the dataloader state to. + dist_config: The distributed configuration. + """ + if dataloader is None: + return + + ckpt_path = Path(ckpt_path) + ckpt_path.mkdir(parents=True, exist_ok=True) + dataloader_path = ckpt_path / f"dataloader_rank_{dist_config.rank}.pt" + + dataloader_state = dataloader.state_dict() + dataloader_state["num_workers"] = dataloader.num_workers + dataloader_state["num_ranks"] = dist_config.world_size + torch.save(dataloader_state, dataloader_path) + if dist_config.is_main_process(): + logger.info(f"Saved dataloader state to {dataloader_path}") + + +def load_dataloader( + dataloader: StatefulDataLoader | None, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, +) -> StatefulDataLoader | None: + """Load the dataloader state from a file. + + Here we load the dataloader state based on global rank. + + Args: + dataloader: The dataloader to load the state of. + ckpt_path: The path to load the dataloader state from. + dist_config: The distributed configuration. + """ + if dataloader is None: + return dataloader + + dataloader_path = Path(ckpt_path) / f"dataloader_rank_{dist_config.rank}.pt" + if not dataloader_path.exists(): + logger.warning( + f"No dataloader checkpoint found for rank {dist_config.rank}, starting dataloader from scratch." 
+ ) + return dataloader + + dataloader_state = torch.load(dataloader_path, weights_only=True) + + if ( + dataloader.num_workers != dataloader_state["num_workers"] + or dist_config.world_size != dataloader_state["num_ranks"] + ): + logger.warning( + f"Dataloader num_workers mismatch: {dataloader.num_workers} != {dataloader_state['num_workers']} or " + f"num_ranks mismatch: {dist_config.world_size} != {dataloader_state['num_ranks']}, " + "starting dataloader from scratch." + ) + return dataloader + + dataloader.load_state_dict(dataloader_state) + if dist_config.is_main_process(): + logger.info(f"Loaded dataloader state from {dataloader_path}") + + return dataloader diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/eval_manifest.json b/bionemo-recipes/recipes/esm2_minifold_te/data/eval_manifest.json new file mode 100644 index 0000000000..e8865626f4 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/data/eval_manifest.json @@ -0,0 +1,902 @@ +{ + "1UCS": { + "sha256": "e17817ba7168f317da3058fdad555f0d1980f134dfb0b82cc3c198ddb376bc69", + "chain_id": "A", + "num_residues": 64, + "sequence_length": 64 + }, + "2DSX": { + "sha256": "d4310ceede850d3fb69481d87a15e6d2b861158244964c332b4bef0eddaa00b5", + "chain_id": "A", + "num_residues": 52, + "sequence_length": 52 + }, + "5D8V": { + "sha256": "f75545980992e2a173e8aca54365bd6abb9fd4758d52e9c0c46c6d075f709fd7", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "1YK4": { + "sha256": "0838af6da34a78fcae2a3abe31a0b1b151a42ac6bd8e421050df1bbd9e5f9b3a", + "chain_id": "A", + "num_residues": 52, + "sequence_length": 52 + }, + "5NW3": { + "sha256": "04ca1b0d6a3726e184d0d2c1bfa9a3b4ed10c6c576aa5e557fd1be42eed63d4c", + "chain_id": "A", + "num_residues": 54, + "sequence_length": 54 + }, + "3X2M": { + "sha256": "2b2e3d96c846b697a5b214fb352d954a8b8b919b8f28e5b1a2a3183255d4f1b3", + "chain_id": "A", + "num_residues": 180, + "sequence_length": 180 + }, + "7VOS": { + "sha256": 
"08d095120630193f83e4505d02ad2687c336f6c69d8e3534922a16818e27ddd3", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "2VB1": { + "sha256": "484adb8123401a095531460c53db1a68e34fd6c3e4f229deb21fe8c090ee9ce9", + "chain_id": "A", + "num_residues": 129, + "sequence_length": 129 + }, + "3A38": { + "sha256": "65d400d02f66730f47e3275413bf3724391284fd7404c820c46b1f18dca579b4", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "4ZM7": { + "sha256": "7c70023adb684aa67bedd6f951a516fd77110ef6f32dcf4646ac18e5d2210f44", + "chain_id": "A", + "num_residues": 180, + "sequence_length": 180 + }, + "3A39": { + "sha256": "4123c4598770e811a4b7e0e29d286a573a1652a09c45e977eb06d5c4b1c00e0f", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "5OME": { + "sha256": "48f6c84fd55769d9b47921d6efe134e4692a2acc6d7c40c4be730b7ad284c756", + "chain_id": "A", + "num_residues": 54, + "sequence_length": 54 + }, + "2OV0": { + "sha256": "dfdf9f23c5146363ac46d6578371ea281a872c40e375de35a6e7643f11b1b421", + "chain_id": "A", + "num_residues": 105, + "sequence_length": 105 + }, + "1R6J": { + "sha256": "ab29c369918d49fe107d664ea94522502066541e9cec750f9eaf529f49702c60", + "chain_id": "A", + "num_residues": 82, + "sequence_length": 82 + }, + "6S2M": { + "sha256": "5048cea6b134efb5f364617f01cd5bb0903a4dcb4bafaf367918a5ccc19e6f6d", + "chain_id": "A", + "num_residues": 131, + "sequence_length": 131 + }, + "2B97": { + "sha256": "826da99b86b1a10292ad7fd37fdd651f902550e8bf5d89cb1f5ea16c9127ce46", + "chain_id": "A", + "num_residues": 70, + "sequence_length": 70 + }, + "2WFI": { + "sha256": "c598d83d9c7a042d84459adb5f4ca3fcd04e1822c1c71a09164e6823881f422d", + "chain_id": "A", + "num_residues": 171, + "sequence_length": 171 + }, + "4I8H": { + "sha256": "25cd2be76a155e745761b0db7d2ed7e44e4bca7d5927d936007a96a5f90a67d8", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "2WFJ": { + "sha256": 
"5ea60b3462e3212c069bcc904a0c1c705841cffd88668a3f946f5b1a435b7df9", + "chain_id": "A", + "num_residues": 172, + "sequence_length": 172 + }, + "3X34": { + "sha256": "2b3ca879a7fb437751c88beae47725b62b3d6dc5366ba3fececc49dfd502519c", + "chain_id": "A", + "num_residues": 87, + "sequence_length": 87 + }, + "8C5N": { + "sha256": "abda46e2185866df21b10a5cd70e4548c5cc62faf5fdb89e03706631378073bb", + "chain_id": "A", + "num_residues": 186, + "sequence_length": 186 + }, + "5YCE": { + "sha256": "a5b1da703f206df9818ccbb8632118bcdc3ac35c792414c8d348a432cb9a955e", + "chain_id": "A", + "num_residues": 151, + "sequence_length": 151 + }, + "6L27": { + "sha256": "764d24fd8398017043532166ab2cc4f87886c1ff7fd02be34c80086767862d6e", + "chain_id": "A", + "num_residues": 226, + "sequence_length": 226 + }, + "7FEZ": { + "sha256": "ec5db59281cf8acf04e7f06703bee996871221c9d018e42ea663805b8cb30951", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "1X6Z": { + "sha256": "4428237a7a4d30523d7f386210b52b5d1363abbc57f3f795a655bbd3c2d5df9b", + "chain_id": "A", + "num_residues": 120, + "sequence_length": 120 + }, + "5KWM": { + "sha256": "0e6a9bda52eed54705b3d07db9cd3895c574854112bc7b86b14fbaab9a1e55d2", + "chain_id": "A", + "num_residues": 222, + "sequence_length": 222 + }, + "7KR0": { + "sha256": "009f875a14b8f498ae33a1d867a666dbeb1cfbe561bad49f132c1a8f561d0629", + "chain_id": "A", + "num_residues": 169, + "sequence_length": 169 + }, + "5GV8": { + "sha256": "9d32571105b2ea5cb1dac8a83d52e6f044b7490a58d137ae0559e052dae91ecf", + "chain_id": "A", + "num_residues": 272, + "sequence_length": 272 + }, + "6JGJ": { + "sha256": "305fb2448e6f199a2fc38cfed9905845eb2be7c57dd9b81af34e91be5314af30", + "chain_id": "A", + "num_residues": 227, + "sequence_length": 227 + }, + "3W5H": { + "sha256": "92d17738f582c7abf59da27748cf7491a42d4db2f8f98755a3bf06cd3d2de9e3", + "chain_id": "A", + "num_residues": 272, + "sequence_length": 272 + }, + "6ZM8": { + "sha256": 
"96ce6fb6fa00c53a2f1dce073c46f1eaaeb98da3862dca4fce728e5453b95b45", + "chain_id": "A", + "num_residues": 208, + "sequence_length": 208 + }, + "7A5M": { + "sha256": "41bb008db674b64ad1d71cced991ea3c3ec083cd10edf0ecc10b31a17ffe6f95", + "chain_id": "A", + "num_residues": 112, + "sequence_length": 112 + }, + "2PVE": { + "sha256": "80976be3c729a1cfb852418e5dabf4cd01a6dc8c1d5e51a213a5a4f33e85467e", + "chain_id": "A", + "num_residues": 52, + "sequence_length": 52 + }, + "5TDA": { + "sha256": "7b04a3ad3dee45646979aae83dd3561f1a3288e9bd7ea3fd28b42773f7455369", + "chain_id": "A", + "num_residues": 71, + "sequence_length": 71 + }, + "4UA6": { + "sha256": "453da4269e25a025ba15efe7851f3fab85d61aefa5b3852be41e91c902099a76", + "chain_id": "A", + "num_residues": 262, + "sequence_length": 262 + }, + "5MN1": { + "sha256": "bf72718324a409498c40e5088333d57cdd951eff6669a49e9fe71fd1d8950763", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "7R2H": { + "sha256": "254ed93f8fe79009f4066f633ab78e0962b5469ad79b736db73c4f68464aac40", + "chain_id": "A", + "num_residues": 164, + "sequence_length": 164 + }, + "1IUA": { + "sha256": "02dbcb9d02315d27f75427ca571421529aef96a8a91f04ac8320566e06a90584", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "5MNK": { + "sha256": "89cb57551522b3ec723e58d2a96af2ac620d20ef34473862c72dfd711a490bd9", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "1PQ7": { + "sha256": "ffb97320ba49fb2669e18bb0e6c9ddfb3e02234f29bda03887ee1bbb08e0318c", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "6KL0": { + "sha256": "5ec00a1b1963f6d0fe0e72b2890689117b484bb927e62871fd87198b3a8ab0be", + "chain_id": "A", + "num_residues": 227, + "sequence_length": 227 + }, + "3MFJ": { + "sha256": "cb42c3f3aaf1d1b2d726fed3541aaaac3b04ba22ce4145e63d312f136e822950", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "1W0N": { + "sha256": 
"03b57edfcbecee940937a24e2497e46de05de79b666aa2f975772e131d3b6538", + "chain_id": "A", + "num_residues": 120, + "sequence_length": 120 + }, + "3MI4": { + "sha256": "28f6505de9b06e3905001e76e5e972bd6a1b2b26a38b9f5c2977f8a1e000d6b4", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "4I8G": { + "sha256": "9ec2de5673c86b74abd5c49ca775486bccbabafaf1818558d17523c33c12fbae", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "3UI4": { + "sha256": "5c028ff339bd03f9813c1da8951ae24fecb37a2a33db08fcd8cf5e375519079a", + "chain_id": "A", + "num_residues": 101, + "sequence_length": 101 + }, + "5NFM": { + "sha256": "17d92e7ed3a4c1965409bb0eb1b3f901ca21a604ae72a5cd5b6981e755a04c15", + "chain_id": "A", + "num_residues": 75, + "sequence_length": 75 + }, + "5GV7": { + "sha256": "308dfa93feeb123c498ab5635211304fa745c521b72df080e3052c7f4eb0c6d5", + "chain_id": "A", + "num_residues": 272, + "sequence_length": 272 + }, + "5WQQ": { + "sha256": "509fe26b058f9e3e7b8d56711e8d8ece98456d6251a6a7bc96b5bfe01fcc8326", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "5WQR": { + "sha256": "40407c8ecc2c3a063ecce93d70530bc57458d1e5eff271bb8d5a1032aec74200", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "9M1X": { + "sha256": "38778ba9b75c3b44fb1cc41e6722918420ce81bf18f12cfd02891a057f1347ce", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "8ZST": { + "sha256": "ea1577cd5477253edf8e109e8ded56f95f61d8d82c74470747a73c7922211525", + "chain_id": "A", + "num_residues": 129, + "sequence_length": 129 + }, + "9CY0": { + "sha256": "3c2f377318289d5c612a0c51a3327f386481ba11d0cee5568c6abb40a9969354", + "chain_id": "A", + "num_residues": 168, + "sequence_length": 168 + }, + "1FN8": { + "sha256": "2c38fa9e555f6da9048f03613b8e8732d6ecb54d8c109e43777a2600a88f39a2", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "9M22": { + "sha256": 
"3d9d6865cc129b1b7586e77334f65b1adae2d20ce3b9e9c7d65f9be249938539", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "1FY5": { + "sha256": "35c44374a7810b97c907e1be513688715342522d1dccf99d1afbe761bfca6cb8", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "1FY4": { + "sha256": "5074491ca28c74c03529ce00bba413ec288f97ae28b14891a2865ea6ab458b7e", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "1GDN": { + "sha256": "e7c286503092c5d71111d8013ef3a35b53a58fb43e4aa4d3acc229158adf8e67", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "4M7G": { + "sha256": "ee8f8322f86b8fc95edd51265bf1acbbfcd5d79a9dcf27e1b99a70dcec7a4b0b", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "7WKB": { + "sha256": "3d2a8e0e41baaf91d26a3348eb8ddd0552ceba0c4acb9fb46e2302fcb5d5e38c", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "5HB7": { + "sha256": "dc1840ea201dc74326725c259ebc4c578fc1a6f96f7e75aeb735f7e8c40af303", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "2VXN": { + "sha256": "ebc8092611358b8602e9736be86a5807be5e4b261399ebf9eb1f12fbd3100cdc", + "chain_id": "A", + "num_residues": 249, + "sequence_length": 249 + }, + "2H5C": { + "sha256": "020b1a82883c3759097e73b7d97aca79b2d0f6160e848045e31866576bca8d47", + "chain_id": "A", + "num_residues": 198, + "sequence_length": 198 + }, + "1NWZ": { + "sha256": "0d35eaf5bcfcc294720fb6248f438d82bb2ce9f6c703d9b097f587a9585f6051", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "7AVK": { + "sha256": "8d2fffc95ecb4df772fb2c6d6ca3c3cff009e51d117d1229b9f98319414bb462", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "1N55": { + "sha256": "98a3234420a6eb8b716f8dfeadd084bdd872e35ede08ea8be3c49f05decea2b6", + "chain_id": "A", + "num_residues": 249, + "sequence_length": 249 + }, + "8C3X": { + "sha256": 
"c6ec890926e4ff5d0a8e9258290dd5e0b1d7db21218af2fdc494a7ae43c52aba", + "chain_id": "A", + "num_residues": 230, + "sequence_length": 230 + }, + "1SSX": { + "sha256": "ef95cce053dcd158bc56e6b4007778c6a418870f4238b60a0e340d12e8f6072d", + "chain_id": "A", + "num_residues": 198, + "sequence_length": 198 + }, + "2JFR": { + "sha256": "dcd5b3c0cad9f6bb51a497b0dc06cb678c51ab05acaa85727ba4d6b86bd76588", + "chain_id": "A", + "num_residues": 234, + "sequence_length": 234 + }, + "2PWA": { + "sha256": "fb70cd8bd91e9bd6cac2283f3b87bcafa6c1d9edae44f228b93602105f698983", + "chain_id": "A", + "num_residues": 279, + "sequence_length": 279 + }, + "2O9S": { + "sha256": "503fce3a26dc387135dd775e5788e8ff7810a2c215232178d1c366227a79892e", + "chain_id": "A", + "num_residues": 67, + "sequence_length": 67 + }, + "3X32": { + "sha256": "58bd130c90f15493ca6193b5a48909f3263dd9d1706d323425e650316998b60c", + "chain_id": "A", + "num_residues": 87, + "sequence_length": 87 + }, + "3X2L": { + "sha256": "f47cfb125414a730108e380918b061e2bb8f7ec634c8c233b9a4cd5d9147b9b9", + "chain_id": "A", + "num_residues": 180, + "sequence_length": 180 + }, + "7FF6": { + "sha256": "52edd40323f7b2ae2c7227c94344f5e9f9ad2c1403947a34ddd841292d2e1f0b", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "7WF0": { + "sha256": "6ec1c4398b58826b346161b30a50f0b33d7c59c6dddb9147b369c753b9d6c386", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "8RMA": { + "sha256": "598663ac6f3445e4ce4c6a1e5c02d966ade56cf8cba6bcfca194decf069023e8", + "chain_id": "A", + "num_residues": 263, + "sequence_length": 263 + }, + "9HDW": { + "sha256": "162dff9ae4db890f6d008a107bbffad3017938731ee802dcdb54311dd887bb4c", + "chain_id": "A", + "num_residues": 256, + "sequence_length": 256 + }, + "6MZ2": { + "sha256": "60a162cefb94a40c9a87fc24ddb4ddbe4cbbbb05c568a1b2cab53a7af82ae157", + "chain_id": "A", + "num_residues": 262, + "sequence_length": 262 + }, + "2HS1": { + "sha256": 
"6766eb72cbce8bcb05504e5ccc0c3115cffde7d297fb40f1608b2d01e2c659fb", + "chain_id": "A", + "num_residues": 99, + "sequence_length": 99 + }, + "2O7A": { + "sha256": "03df16d3c852fe375a2f84fdd7fefcfa042319025dc90a3b873d43717c76160a", + "chain_id": "A", + "num_residues": 115, + "sequence_length": 115 + }, + "1XVO": { + "sha256": "bcf882109c73174979901447d3bfb18dd2408be7670df3edfad34fb0d5f239a4", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "2YKZ": { + "sha256": "2bf3b349af3fb14233ae67aa1c15b77fa2d9c8ef818d9400041049e885b97865", + "chain_id": "A", + "num_residues": 126, + "sequence_length": 126 + }, + "4EIC": { + "sha256": "0092f125b9ebecef3ae72ef302b53cc7eb072bcbf9aa2a087fb8c854a37a68e7", + "chain_id": "A", + "num_residues": 93, + "sequence_length": 93 + }, + "3ZQV": { + "sha256": "8854536d165d7add76ebe5043dd3f671ee731c25c79f928baf83f3aaa94ec633", + "chain_id": "A", + "num_residues": 126, + "sequence_length": 126 + }, + "7FFK": { + "sha256": "73ff10b7fb784f47001950882c2fbd691c9c8f158989ad2016fef4af8fb042fd", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "7BNH": { + "sha256": "807a31d8c3e7e9a38f62d11ad7a745bf85563664c7c16f954b61b163ab78bba8", + "chain_id": "A", + "num_residues": 96, + "sequence_length": 96 + }, + "4UA9": { + "sha256": "2adb3e4fcec5884f66a68876d91b1f4b7f6b02e5e3a8fbe9ffa93c8245abd8a9", + "chain_id": "A", + "num_residues": 262, + "sequence_length": 262 + }, + "6EIO": { + "sha256": "8da9baa2c55f78e3bec35aff8fe004128036ada44ea0669e2844c2d566ffef62", + "chain_id": "A", + "num_residues": 222, + "sequence_length": 222 + }, + "7G0Z": { + "sha256": "43389144d736325903b950fbd9f473d13844fc80a66bcbe9795a814b2ad3b7ac", + "chain_id": "A", + "num_residues": 134, + "sequence_length": 134 + }, + "7WKG": { + "sha256": "0a24705e7800e98dd5d842c0f44d4db3cda1ba1f9e5256bc73d0715101b5de55", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "4PSS": { + "sha256": 
"991758ec6004ac6a1175c83a4379d258570fea25afcad8d7dc24cbd438d6528e", + "chain_id": "A", + "num_residues": 158, + "sequence_length": 158 + }, + "1MC2": { + "sha256": "73cbc130e42a9043d795f895a3d6f5264e48491b73854f3528cb6021af2e7784", + "chain_id": "A", + "num_residues": 122, + "sequence_length": 122 + }, + "7TX0": { + "sha256": "4bebbe275906cd8da9b9c6d0e0b5232944addc05a3b9dd69919d5739d6efdf38", + "chain_id": "A", + "num_residues": 165, + "sequence_length": 165 + }, + "1PQ5": { + "sha256": "5de807865151288969cf77b978e6960da9a8dfbb378b4752d0904306a26ce84e", + "chain_id": "A", + "num_residues": 224, + "sequence_length": 224 + }, + "1M40": { + "sha256": "b448c0e12f3b79718a3ed977b147d2b25dde54c6f56bc4c5f2abb62fcb0f9d33", + "chain_id": "A", + "num_residues": 263, + "sequence_length": 263 + }, + "1X8P": { + "sha256": "ba89b27d40c6cb58c35509cbbbe8a64d42684b24ec859758e8236d07aec9d5a8", + "chain_id": "A", + "num_residues": 184, + "sequence_length": 184 + }, + "2FMA": { + "sha256": "7536ce9a6c5ac56b727c4e337a7ae54871fca693d4e6fa17a086555a78aaf8e6", + "chain_id": "A", + "num_residues": 59, + "sequence_length": 59 + }, + "1X8Q": { + "sha256": "7870143b7722407d0e27f83ecd4edf7b13890d5e3faa38df9b4aa375ed9ab922", + "chain_id": "A", + "num_residues": 184, + "sequence_length": 184 + }, + "2F01": { + "sha256": "4efc2b65348e376907e338e2daab2edf7cbbfabc9156109af0e65a9455dba9c2", + "chain_id": "A", + "num_residues": 121, + "sequence_length": 121 + }, + "3PYP": { + "sha256": "42c8138c946e01839871583aa9269048f6a3ddb22de41803bd418cff9d3ed4d7", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "5EMB": { + "sha256": "2d51ed2e9b7b93307c3c9da105fb5a851f2776177921d79ad163e08d5c80bc32", + "chain_id": "A", + "num_residues": 96, + "sequence_length": 96 + }, + "3QPA": { + "sha256": "f5cbded1e4e61171e94de02fda4f776ea89ee04243830df56a7abfabc7192e39", + "chain_id": "A", + "num_residues": 196, + "sequence_length": 196 + }, + "4O8H": { + "sha256": 
"b7b5cef07a90b0a403df6f141182302666a341cdeb02e93b1aff7c3f7e369999", + "chain_id": "A", + "num_residues": 164, + "sequence_length": 164 + }, + "4PSY": { + "sha256": "3714ee64fab75683b1085cbd035dbf397ceaa3566fbf0922c50b3aa348e00bde", + "chain_id": "A", + "num_residues": 158, + "sequence_length": 158 + }, + "4I8K": { + "sha256": "1eb9064a29ab91a3a9ea8cbf3001829c1148aaa227d8e0bd9e88464a36364c33", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "5VLE": { + "sha256": "7e5c958d6704a2598a68b78ed9e44f1d555bf0961942fc6fff26e118dff56dab", + "chain_id": "A", + "num_residues": 262, + "sequence_length": 262 + }, + "5YOK": { + "sha256": "59479a26eb637059b0b882fcb1a952ce17580dfbf01a1bc1d1bf44733e33e08f", + "chain_id": "A", + "num_residues": 99, + "sequence_length": 99 + }, + "6AIR": { + "sha256": "c1ebcffab3a3ba6e88ecd314938366ff642e5d63e25dfd6273bef5585ce18182", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "6AIQ": { + "sha256": "4c023447ff4b2d5822f6deb572b49211d7d86f7e3bd971fe4c537801a03687e0", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "6ETK": { + "sha256": "694e03d4ced22a66c47f8bec1b9babaff50fd3b89cd4716b83d4aa42edcb7fca", + "chain_id": "A", + "num_residues": 124, + "sequence_length": 124 + }, + "6ETL": { + "sha256": "f10e2662c024d888fcda6ec2bb07edef7bcab3ec7de53e5259a4a08f61592bc7", + "chain_id": "A", + "num_residues": 124, + "sequence_length": 124 + }, + "6Q00": { + "sha256": "14b9d334cb8b7277bf452c38ac61f31618d338aa93b5feff3aba8f869b9a0d87", + "chain_id": "A", + "num_residues": 76, + "sequence_length": 76 + }, + "6UN0": { + "sha256": "b6102b7b012143679ab42132154d44a82104418e97b3a2b5a54601df12c3cc8e", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "6JGI": { + "sha256": "10a8484222f2ef3afba1539624fe866fbf68050936ff2ad057633dedcdb04a12", + "chain_id": "A", + "num_residues": 227, + "sequence_length": 227 + }, + "7AOT": { + "sha256": 
"e17c38e6e9d3cb035e98a7c3fd3061f7e165e960924e0d9882aeea2e8255e7ac", + "chain_id": "A", + "num_residues": 127, + "sequence_length": 127 + }, + "7KQO": { + "sha256": "f6953091379b7880df77ca3a969d8f3d7cc0acd112b68336a38821a248e0542a", + "chain_id": "A", + "num_residues": 169, + "sequence_length": 169 + }, + "7P4R": { + "sha256": "b1014f1d6be0a3fe1765768fc37ff086a34da07f10b292b901ce8fecd5e75ced", + "chain_id": "AAA", + "num_residues": 124, + "sequence_length": 124 + }, + "6KL1": { + "sha256": "a6ded65f0f68c2491a7119a1576fd96bed1b092d61e9d7a78f23e4f0d6cbf441", + "chain_id": "A", + "num_residues": 227, + "sequence_length": 227 + }, + "1G6X": { + "sha256": "3fb28aa601068173d4cfdbcae927ec8502ddbcacf349aa4c2282779b2c631615", + "chain_id": "A", + "num_residues": 58, + "sequence_length": 58 + }, + "8R30": { + "sha256": "ce68acdb21983942e3ae2438bb81dd3543bdfbbef1941cdbf37b5b571f50da88", + "chain_id": "A", + "num_residues": 262, + "sequence_length": 262 + }, + "7WCI": { + "sha256": "5b12eac34e83acc9ab700aac3a0b4ef50938addad69788518c8eb637602013e9", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "6Q01": { + "sha256": "e1f0fe07df4a4bfefc5f4557f3ed8daf97b07fb53e43a8b077430fd8d50a26cd", + "chain_id": "A", + "num_residues": 76, + "sequence_length": 76 + }, + "5MNN": { + "sha256": "2b99297854e7c65e42f2c5a292dc6f42e04b49ad2903d4ea06163f42367276e4", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "2PYA": { + "sha256": "fb10eafab8670f00737bf8159436da20a1ded70ab25b472032ffe10c7cfc5a36", + "chain_id": "A", + "num_residues": 52, + "sequence_length": 52 + }, + "4PTH": { + "sha256": "be3cfdae81b72710d71af8524356b749cf84d369ba92a541d3c37e90400262df", + "chain_id": "A", + "num_residues": 158, + "sequence_length": 158 + }, + "3WDN": { + "sha256": "c66bb7c17fc672c9e2eb55b147ba268485d397a7e606ca3cb660bf7b93c65a0a", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "4TKB": { + "sha256": 
"b3bbf85a4893b3cb2bd11fbf80302af1ecac76776cc035387689ba4612f28fb8", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "5AVD": { + "sha256": "1fa1764d9476fc2b5e888bd10394e884139af0a5072978c72de26fc065abe8aa", + "chain_id": "A", + "num_residues": 240, + "sequence_length": 240 + }, + "3ZSJ": { + "sha256": "f2bf739ffd9d01bd9efd632ccd360d21f7cea9bc86f00bc7f67e2b88181c5878", + "chain_id": "A", + "num_residues": 138, + "sequence_length": 138 + }, + "5MNG": { + "sha256": "899ee87f6637e3f0ad11463406aa21a5252626f6085f53313501981f8dd90301", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "7FDT": { + "sha256": "19ecd435df66bb7c77d3f6b7e596b9602e3f1f5fff79b82be040d79018da275c", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "6S2S": { + "sha256": "404a01f5f2c747951dac6ab2c2ef358566ed060c4da9167ea0c16f9160abc785", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "7FDU": { + "sha256": "5a9abfb62e824b09ccd6a1ad4bd1523e5c66d8ebd799fc5ca97f456d5fbb6e0a", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "7FBF": { + "sha256": "f847afc5107ef3c122276686d1ddfdc0427cda7cc5842d4169067f9713421916", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "7X48": { + "sha256": "818fcffd3bbd47d83324b267386cefd771b688960329344ca2f0c873b760106a", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "4UAA": { + "sha256": "25e414f55c64139e59b7d74a9d5270f927c32c75ec2b86e220fa6826dae1bb5e", + "chain_id": "A", + "num_residues": 262, + "sequence_length": 262 + }, + "7WJ1": { + "sha256": "6483b80e3a56541c44153b2097b681d699a61d0cd0bd69ea0841942d537ea4b7", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "1DY5": { + "sha256": "cd356539ac8512220e570e5fa3ae88bcfc6244256b408edd29ccb3763b8f3b95", + "chain_id": "A", + "num_residues": 123, + "sequence_length": 123 + }, + "1V6P": { + "sha256": 
"8db8665fed5d8fcbf25c078b7bbb7ffa4db237eaef935ef43f34a0ab579cd080", + "chain_id": "A", + "num_residues": 62, + "sequence_length": 62 + }, + "6ZPA": { + "sha256": "6e7274f0d1c90007e1cc91284d5bb01b5c9fe8f6d6c4fc3da772758af261137a", + "chain_id": "A", + "num_residues": 173, + "sequence_length": 173 + }, + "4HS1": { + "sha256": "c93cd043efcdf3aaff74a904b490c5fa207aca3422919318a1fa3727c9da3e6e", + "chain_id": "A", + "num_residues": 84, + "sequence_length": 84 + }, + "8K9N": { + "sha256": "c46cea9d73d6e2088fd1c06223f3d9f91415849b924fa40b9a65ec59f10f040d", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "9HDS": { + "sha256": "787768116f31eac634bc0d7b3b0ab6ea5f279837c4cc43ed006711e92128a795", + "chain_id": "A", + "num_residues": 256, + "sequence_length": 256 + }, + "4I8L": { + "sha256": "fa42b5064c447ee8d30115cb37331a9f7e3e81c8f50d9c4bd1927127b416d3db", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "4I8J": { + "sha256": "1f0d5edf2f2fa55390923ff061f0dbc47c972a8178e2055c76d0a26599f3a5b7", + "chain_id": "A", + "num_residues": 223, + "sequence_length": 223 + }, + "4TJZ": { + "sha256": "682d48e4c065cd99b355e9a32783f488cfc38da0b16cf9bc8bdb3ae1babfcaa2", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + }, + "4K8M": { + "sha256": "080966ccc26fc03d0f168bd1deb283dfa85f1430024510e9607277c9897b6ea3", + "chain_id": "A", + "num_residues": 84, + "sequence_length": 84 + }, + "5L87": { + "sha256": "225cfceedfc92eb7fa52a585c38f6456bdac4fafb453a167143ec47106dd4d31", + "chain_id": "A", + "num_residues": 62, + "sequence_length": 62 + }, + "5DJ7": { + "sha256": "74d170ab93a50647d48e2cdc2954026705c5f4ee93e7cb29514418df3dd8996b", + "chain_id": "A", + "num_residues": 230, + "sequence_length": 230 + }, + "4TKJ": { + "sha256": "dc7e419881dae8847e95872a730f6f8b84e550d558d367204f064c0333bbfc16", + "chain_id": "A", + "num_residues": 133, + "sequence_length": 133 + } +} diff --git 
a/bionemo-recipes/recipes/esm2_minifold_te/data/eval_pdb_ids.txt b/bionemo-recipes/recipes/esm2_minifold_te/data/eval_pdb_ids.txt new file mode 100644 index 0000000000..83fcc6ae4b --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/data/eval_pdb_ids.txt @@ -0,0 +1,153 @@ +# Eval PDB IDs for ESM2-MiniFold TE +# resolution<=1.5A, Ca completeness>=0.95 +# Total: 150 structures (excluded 207 training IDs) +1UCS +2DSX +5D8V +1YK4 +5NW3 +3X2M +7VOS +2VB1 +3A38 +4ZM7 +3A39 +5OME +2OV0 +1R6J +6S2M +2B97 +2WFI +4I8H +2WFJ +3X34 +8C5N +5YCE +6L27 +7FEZ +1X6Z +5KWM +7KR0 +5GV8 +6JGJ +3W5H +6ZM8 +7A5M +2PVE +5TDA +4UA6 +5MN1 +7R2H +1IUA +5MNK +1PQ7 +6KL0 +3MFJ +1W0N +3MI4 +4I8G +3UI4 +5NFM +5GV7 +5WQQ +5WQR +9M1X +8ZST +9CY0 +1FN8 +9M22 +1FY5 +1FY4 +1GDN +4M7G +7WKB +5HB7 +2VXN +2H5C +1NWZ +7AVK +1N55 +8C3X +1SSX +2JFR +2PWA +2O9S +3X32 +3X2L +7FF6 +7WF0 +8RMA +9HDW +6MZ2 +2HS1 +2O7A +1XVO +2YKZ +4EIC +3ZQV +7FFK +7BNH +4UA9 +6EIO +7G0Z +7WKG +4PSS +1MC2 +7TX0 +1PQ5 +1M40 +1X8P +2FMA +1X8Q +2F01 +3PYP +5EMB +3QPA +4O8H +4PSY +4I8K +5VLE +5YOK +6AIR +6AIQ +6ETK +6ETL +6Q00 +6UN0 +6JGI +7AOT +7KQO +7P4R +6KL1 +1G6X +8R30 +7WCI +6Q01 +5MNN +2PYA +4PTH +3WDN +4TKB +5AVD +3ZSJ +5MNG +7FDT +6S2S +7FDU +7FBF +7X48 +4UAA +7WJ1 +1DY5 +1V6P +6ZPA +4HS1 +8K9N +9HDS +4I8L +4I8J +4TJZ +4K8M +5L87 +5DJ7 +4TKJ diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/eval_structures.parquet b/bionemo-recipes/recipes/esm2_minifold_te/data/eval_structures.parquet new file mode 100644 index 0000000000..0c63fdc7b3 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_minifold_te/data/eval_structures.parquet differ diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_ids.txt b/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_ids.txt new file mode 100644 index 0000000000..74e24f692d --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_ids.txt @@ -0,0 +1,210 @@ +# PDB ID list for ESM2-MiniFold TE training +# Selection: X-ray, <2.5A resolution, single chain, 
50-300 residues, diverse folds +# ~200 well-characterized structures from PDB +1CRN +1L2Y +1UBQ +2GB1 +1VII +1PGA +1IGD +1SHG +1FME +2JOF +1ENH +1PIN +1BDD +2PTL +1K8M +1MJC +1RIS +2CI2 +1CSP +1HRC +1LZ1 +2LZM +1RN1 +3LZT +1AKE +4PTI +1BPI +1CTF +2ABD +1SN3 +1PHT +3ICB +1CDT +1COA +1AHO +1MBN +1YCC +2CPL +1LYZ +1RGE +3RUB +1OVA +1LKI +1PLC +1TEN +1FNF +1TIT +1WIT +1THX +1AGT +1BKR +1CEX +1CHD +1CSE +1CYO +1DTP +1ECA +1FAS +1FKB +1FLV +1GCI +1GOF +1HMR +1IOB +1JBC +1KPT +1LAM +1LEC +1MLA +1NLS +1OMP +1PAZ +1PBE +1PHP +1POH +1PPL +1REC +1REP +1RTP +1SAU +1SBP +1SIM +1SMR +1STN +1THB +1TON +1TPK +1TRB +1VCC +1WHI +1XNB +2ACT +2AIT +2AZA +2CCY +2CDV +2CMD +2CTC +2END +2ERL +2FDN +2FCR +2GBP +2HMB +2IFO +2LIV +2MCM +2MHR +2NLR +2OHX +2PAB +2PHY +2PLV +2POR +2RBI +2RHE +2RSP +2SAR +2SAS +2SNS +2SOD +2SPC +2TGI +2TMA +2TRX +3BLM +3CHY +3CLA +3COX +3DFR +3EBX +3EST +3GAP +3GRS +3HTS +3LYN +3MDS +3PGK +3PGM +3PRO +3SEB +3SGB +3TGL +3TMS +4BCL +4BLM +4CAM +4CPA +4ENL +4FGF +4GCR +4HHB +4I1B +4MDH +4PEP +4PFK +4TNA +5CPA +5CYT +5LYZ +5PTI +5RUB +5TIM +6ACN +6LDH +6TAA +7AAT +7RSA +8ABP +8ADH +8ATC +8CAT +8DFR +8TLN +9PAP +1A2P +1A6M +1A8D +1ADS +1AMP +1AOZ +1ARB +1AVH +1B6C +1B8O +1BCF +1BGF +1BJ7 +1BKF +1BM8 +1BN6 +1BRS +1BTL +1BUO +1BXO +1C1K +1C52 +1C75 +1C9O +1CC8 +1CEM +1CHN +1CIU +1CNV +1CQY +1CTJ +1CX1 diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_manifest.json b/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_manifest.json new file mode 100644 index 0000000000..63292682ad --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_manifest.json @@ -0,0 +1,1172 @@ +{ + "1UBQ": { + "sha256": "056f98710cb2b36f633c45e41902a02eb446e82871da21ff2dd44f74a56ca0f6", + "chain_id": "A", + "num_residues": 76, + "sequence_length": 76 + }, + "2GB1": { + "sha256": "83c8456bad4688e4be5b9057bf634cd5265a89313abd927290b8d66c4e0a2d5c", + "chain_id": "A", + "num_residues": 56, + "sequence_length": 56 + }, + "1PGA": { + "sha256": 
"e8d9e171a9ca2b7e66b10568bf6d0122da8450e57ad632a7b4c13c6a71f7edf2", + "chain_id": "A", + "num_residues": 56, + "sequence_length": 56 + }, + "1IGD": { + "sha256": "9180b720ccbae9095f558c4945e5de3642129b9b1be68d5c75f22a94bfe79ddb", + "chain_id": "A", + "num_residues": 61, + "sequence_length": 61 + }, + "1SHG": { + "sha256": "c278e2e544b120b060ee893094f4083add38db1dd421952b9ab1a98891a99dc1", + "chain_id": "A", + "num_residues": 57, + "sequence_length": 57 + }, + "1ENH": { + "sha256": "80c861fd8559f1a8842e0ae1de7d742aca803aabaf0b3107f93dab91d70e8ca5", + "chain_id": "A", + "num_residues": 54, + "sequence_length": 54 + }, + "1PIN": { + "sha256": "3060c3ca9dfff3af572d754a50c819eb400dfa0e9bf9065e72e1175f3814e375", + "chain_id": "A", + "num_residues": 153, + "sequence_length": 153 + }, + "1BDD": { + "sha256": "de2327fa385c168abae275b06e6f9d91228cb03b3e95c37d3646bc454ee9b91c", + "chain_id": "A", + "num_residues": 60, + "sequence_length": 60 + }, + "2PTL": { + "sha256": "fd758ef87543f7c29a5b451fd849b5a4693ed45acb0cb24fda60c5501538d43d", + "chain_id": "A", + "num_residues": 78, + "sequence_length": 78 + }, + "1K8M": { + "sha256": "37f623fb464c27e1b87623b25d037f8c520a087f151b94ab3bd61579860525d2", + "chain_id": "A", + "num_residues": 87, + "sequence_length": 87 + }, + "1MJC": { + "sha256": "2ce3a3b1694eb4a07930bcd4350892066ce6b064dffb495435427e645013d9f4", + "chain_id": "A", + "num_residues": 69, + "sequence_length": 69 + }, + "1RIS": { + "sha256": "a77dd016ec124f95442bca65cf11e30d3a884d1b3bb22157a647745dd78d2af9", + "chain_id": "A", + "num_residues": 97, + "sequence_length": 97 + }, + "2CI2": { + "sha256": "62c6ec30a5986404a30417f585b77c59d95b4f6029587c810b4b5d213d9727cc", + "chain_id": "I", + "num_residues": 65, + "sequence_length": 65 + }, + "1CSP": { + "sha256": "c23f6e4bf192bbd0c181da9a90f1942c51cb5660e23af1eb846356d5711c018f", + "chain_id": "A", + "num_residues": 67, + "sequence_length": 67 + }, + "1HRC": { + "sha256": 
"2ee756c54a501c49e1e63371548b4de3051b704f2153f8cfc15fa66c3585fecd", + "chain_id": "A", + "num_residues": 104, + "sequence_length": 104 + }, + "1LZ1": { + "sha256": "b7c3da2c7a54184fa768a2f51b56c4d7f4516854a4e77dbaa2528803a673cea0", + "chain_id": "A", + "num_residues": 130, + "sequence_length": 130 + }, + "2LZM": { + "sha256": "9fa81cd0454d37c53cd5046b96ef7d993a99c249e77a2da4dc3adf7cdc263d90", + "chain_id": "A", + "num_residues": 164, + "sequence_length": 164 + }, + "1RN1": { + "sha256": "c9afd184070c41ce71a3bf617a816d866dd22c877a2de9eb55c458b617f5bdbe", + "chain_id": "A", + "num_residues": 103, + "sequence_length": 103 + }, + "3LZT": { + "sha256": "e925f918c98540d51bc5b8a7aa6217d9b9726ebab8a48a6833221c95bdd24454", + "chain_id": "A", + "num_residues": 129, + "sequence_length": 129 + }, + "1AKE": { + "sha256": "01f41b1b42318a1a5df7f650dbab881677aa0e8d825f7c42dd26ae16a94c0948", + "chain_id": "A", + "num_residues": 214, + "sequence_length": 214 + }, + "4PTI": { + "sha256": "fc81f5dacdf5618e65b8ee7b02b72cd276c1e83f6082094c90a08041f0200b13", + "chain_id": "A", + "num_residues": 58, + "sequence_length": 58 + }, + "1BPI": { + "sha256": "bca455ebc267a5a756bdc998d257a2e3d04d6843c11fe1f3cfdcde88fd05a59c", + "chain_id": "A", + "num_residues": 58, + "sequence_length": 58 + }, + "1CTF": { + "sha256": "d4ac4412659fe599e14816d21232eb94c6a57a24a9f3965c773ccc260e57ec01", + "chain_id": "A", + "num_residues": 68, + "sequence_length": 68 + }, + "2ABD": { + "sha256": "1937b67a9cbeb42ddb2e74015cecbe230be5bb687bc8e3f549e8802a2d4334c5", + "chain_id": "A", + "num_residues": 86, + "sequence_length": 86 + }, + "1PHT": { + "sha256": "10fd4997efe2a9da3e474d696f9cd172cd00ca68c6b15ec87408386670fc885b", + "chain_id": "A", + "num_residues": 83, + "sequence_length": 83 + }, + "3ICB": { + "sha256": "6389449d31ef5bc5b1c2f25180b520f333544d71ff2f34ac20b5369ff746f261", + "chain_id": "A", + "num_residues": 75, + "sequence_length": 75 + }, + "1CDT": { + "sha256": 
"6649b9938358f786c78b72c42600fb354fc37ee416ed0ba58144f8cadb8b2bf5", + "chain_id": "A", + "num_residues": 60, + "sequence_length": 60 + }, + "1COA": { + "sha256": "4c3a0e08646b8e1f9611d84b5ada4069a6898a989733ec25d4a01867546c3443", + "chain_id": "I", + "num_residues": 64, + "sequence_length": 64 + }, + "1AHO": { + "sha256": "0169a6ec0c2f8883542b05a1df031ee56c608c925cd479deac8d307191953478", + "chain_id": "A", + "num_residues": 64, + "sequence_length": 64 + }, + "1MBN": { + "sha256": "835e78e22edcd1a23cb2f34625adf8818af654fd5494562038cc9fd5dcba8945", + "chain_id": "A", + "num_residues": 153, + "sequence_length": 153 + }, + "1YCC": { + "sha256": "d31ccf5c5429b83632dc0f2a63e470746d4bb167ec0cb50cc5e7fb43b86d34e8", + "chain_id": "A", + "num_residues": 107, + "sequence_length": 107 + }, + "2CPL": { + "sha256": "12495c509148a61158e66edf2a8394df3ab0f21f5ebe8a01cb7ee96cdc39950f", + "chain_id": "A", + "num_residues": 164, + "sequence_length": 164 + }, + "1LYZ": { + "sha256": "7d4eeceaeee0adf6946dd07903907066c3923e103797caee2d40318a1af3cad6", + "chain_id": "A", + "num_residues": 129, + "sequence_length": 129 + }, + "1RGE": { + "sha256": "ef32c3f2e7e10309afe098e416b0d0c83c76401c8c8bd2876c11c4f04b11a43e", + "chain_id": "A", + "num_residues": 96, + "sequence_length": 96 + }, + "3RUB": { + "sha256": "80db9d3ffa2c0a3613940061806b23c95baf386b29f503c1e26d4af286724b52", + "chain_id": "L", + "num_residues": 300, + "sequence_length": 300 + }, + "1OVA": { + "sha256": "b10377f8705036247f62af11800e272312771abdc37e03d3ef41b35154db585d", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1LKI": { + "sha256": "3545830536d72bb491dc7e08c4850197e7406520a377177fff5e8bcb208e8214", + "chain_id": "A", + "num_residues": 172, + "sequence_length": 172 + }, + "1PLC": { + "sha256": "ff6dea63f20ee33eddffa2c7eaf599ab65562b92918feec0da49f7a282d48778", + "chain_id": "A", + "num_residues": 99, + "sequence_length": 99 + }, + "1TEN": { + "sha256": 
"8244778c8f2e85688274cc78da4e30a37505ffb860268864bb4ca325e02eaab2", + "chain_id": "A", + "num_residues": 90, + "sequence_length": 90 + }, + "1FNF": { + "sha256": "a49c2efaed3cb0fb774f3bdeb1a6473c211eb173939200976049e8d129cce9ad", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1TIT": { + "sha256": "9b53fbf0825333e210b4211c4e980ef498e7527407aaf87f80ea1960c95e2001", + "chain_id": "A", + "num_residues": 89, + "sequence_length": 89 + }, + "1WIT": { + "sha256": "33e4228a1027b6b3013ecf9281c08e3d893e54ec3fc779a2cef361b483f1022a", + "chain_id": "A", + "num_residues": 93, + "sequence_length": 93 + }, + "1THX": { + "sha256": "958dca9649ce1b5b53d16f296d147bf2c8559d6f2b61cb8d931f23b44ed041a2", + "chain_id": "A", + "num_residues": 108, + "sequence_length": 108 + }, + "1BKR": { + "sha256": "1f31d1c63c4a5c7bcd9793cf206828c463a4ea2ceb58a0782988140c6c0fb9c9", + "chain_id": "A", + "num_residues": 108, + "sequence_length": 108 + }, + "1CEX": { + "sha256": "09e73e24a2fa01fae6598bd673d049146f5fc8d5156dda14ed03cb040eccffec", + "chain_id": "A", + "num_residues": 197, + "sequence_length": 197 + }, + "1CHD": { + "sha256": "0bacad1fac9f61239fec8ce3181a9a491374bc81748afda25e298437154c7dd2", + "chain_id": "A", + "num_residues": 198, + "sequence_length": 198 + }, + "1CSE": { + "sha256": "b5d7e7db068b277ddc72eb753e4a7b0ec424b2194a2814c169e832345e601ade", + "chain_id": "E", + "num_residues": 274, + "sequence_length": 274 + }, + "1CYO": { + "sha256": "956ffa7dcda6a3b27c82146f1faaebbafbbd5b23f51497b280281761993a198c", + "chain_id": "A", + "num_residues": 88, + "sequence_length": 88 + }, + "1DTP": { + "sha256": "be00ea9dc45576074eb2a22855afca794c0cf070e33f8418ee67d6d751238a2a", + "chain_id": "A", + "num_residues": 190, + "sequence_length": 190 + }, + "1ECA": { + "sha256": "72ded9fcb3b89a9f6b36e9d5b1ad91986e2f5bae42d62edd0cbf71864b174a5c", + "chain_id": "A", + "num_residues": 136, + "sequence_length": 136 + }, + "1FAS": { + "sha256": 
"d36495d2ca3474f1bab1d37c5f7191e46e625d50515afbe5bb63ae6ac2e963fe", + "chain_id": "A", + "num_residues": 61, + "sequence_length": 61 + }, + "1FKB": { + "sha256": "a483de9e6db2c1f6119258fdfc68765bf308eee3aa5101b0e07199802bbda44d", + "chain_id": "A", + "num_residues": 107, + "sequence_length": 107 + }, + "1FLV": { + "sha256": "bd1ce960155ccfcad20579f84c6eb21c0fb561ce98f2346b34253f4018daed36", + "chain_id": "A", + "num_residues": 168, + "sequence_length": 168 + }, + "1GCI": { + "sha256": "c69c4b51df71b65f7ad4a49458af228a0261893c00ab6bab0200570a7d20c5b5", + "chain_id": "A", + "num_residues": 269, + "sequence_length": 269 + }, + "1GOF": { + "sha256": "2c324c80ee423dc0f507fa665aa892cd001bfa96d5509b9f3d6d0fac37dfaa39", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1HMR": { + "sha256": "fdc8e492a4e1d6b3baf64bb98b9b377c69ff268f8e67b36c1807b454f38baaf1", + "chain_id": "A", + "num_residues": 131, + "sequence_length": 131 + }, + "1IOB": { + "sha256": "241f339669a911f3c0c5c60a24d455c81e54827da90336b749a12ee09ad75c2e", + "chain_id": "A", + "num_residues": 153, + "sequence_length": 153 + }, + "1JBC": { + "sha256": "418c1e271bc77936bc2f74b0f1ad75b6f84b5accbe13082fe9b1cb30b427cdc2", + "chain_id": "A", + "num_residues": 237, + "sequence_length": 237 + }, + "1KPT": { + "sha256": "38dcb1d110760d87de6e5a0ae66b2cc7c38bcb7cdcf64eba5c0984fbc35bd001", + "chain_id": "A", + "num_residues": 105, + "sequence_length": 105 + }, + "1LAM": { + "sha256": "46e824c80fcd4640f8507ce5576a6c4d3a04b89c536dd43fc660b791aab4df0e", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1LEC": { + "sha256": "dce01b94d945c90104560dcc214501c7ffd76e4b3c20170c438602d3024b0456", + "chain_id": "A", + "num_residues": 242, + "sequence_length": 242 + }, + "1MLA": { + "sha256": "cfb1a6f6ff3775dc2174c8789168d8e44b7a91fe2063cb862feb0b74e74a4af3", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1NLS": { + "sha256": 
"d79c6bdf4e5f2abb8e21a3ca76a2dd8830144c7f69bc9c10b678b032bead979b", + "chain_id": "A", + "num_residues": 237, + "sequence_length": 237 + }, + "1OMP": { + "sha256": "8fc0436ac9314dce1c504e6016e1fc05d95f5b71faa245036efb6e3efe670b8f", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1PAZ": { + "sha256": "1df4918e56a9bde8705ad4ec394aea223fcb1739194ab72144f9dbc5008ddedc", + "chain_id": "A", + "num_residues": 120, + "sequence_length": 120 + }, + "1PBE": { + "sha256": "4fd1b3863f2317671b0cf68840b3b4235ad68e8182e226637c7b477cf59da865", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1PHP": { + "sha256": "642a85c990fe7c041bf8f7952c0d1fc97be6c14dde32f0c4bb61d6a77bd958bb", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1POH": { + "sha256": "279f526e0e1d75dc25ecac0e8d151d646eaeb92b985852499254ea9c42af6454", + "chain_id": "A", + "num_residues": 85, + "sequence_length": 85 + }, + "1PPL": { + "sha256": "dc97f5605b05b3f958e437a15b85ad297856c6df9f2dc7dc489b879759d70295", + "chain_id": "E", + "num_residues": 300, + "sequence_length": 300 + }, + "1REC": { + "sha256": "c90edd6d537ba299f8d01e1d8dcb9e95da69a87ae99ecf5015dc7274a57df069", + "chain_id": "A", + "num_residues": 185, + "sequence_length": 185 + }, + "1REP": { + "sha256": "11c2ceefdca82572a3df10bc58efe1c123a36c7245708595282707e7282fdb66", + "chain_id": "C", + "num_residues": 214, + "sequence_length": 214 + }, + "1RTP": { + "sha256": "2d2b65f4e06b17465493be514883502267665aba9e73f1802e445389ed5b7bc8", + "chain_id": "1", + "num_residues": 109, + "sequence_length": 109 + }, + "1SAU": { + "sha256": "024803af2e2d8629d25d5b777e9358a8f16501966a1de7363ce7eba490d3e05f", + "chain_id": "A", + "num_residues": 114, + "sequence_length": 114 + }, + "1SBP": { + "sha256": "61889856ca53b6f7d7169df10ef4d9a91c92ef2969637df04e288240c02a7c92", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1SMR": { + "sha256": 
"6c5e25ff997e3683d6cf49fc5f3b3f27dfc7003d5d3a75df7133eeb29fe4a31a", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1STN": { + "sha256": "ebd8b4aa5b6d630a723e752013a6794f8636882db0d5d6064461be3c4be64ef9", + "chain_id": "A", + "num_residues": 136, + "sequence_length": 136 + }, + "1THB": { + "sha256": "db9535f33bec750d5e7525d7a999e4812ea799f3917595e0b92d3270307bdeee", + "chain_id": "A", + "num_residues": 141, + "sequence_length": 141 + }, + "1TON": { + "sha256": "ac8e747238603a572e2385483e9602d6bb04d384f1b54d5258fbc3643c700507", + "chain_id": "A", + "num_residues": 228, + "sequence_length": 228 + }, + "1TPK": { + "sha256": "e9cc5c5654ca78195bb70998166f89cdd2d4b72de2cda2414104c49b19fb730a", + "chain_id": "A", + "num_residues": 88, + "sequence_length": 88 + }, + "1TRB": { + "sha256": "580684c633c75f3b9ce3fb153823102e2abf27873af23310beded3bbec849235", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1VCC": { + "sha256": "7c23da5c418646b6b72aedec26506dd59198cb12140344df7548daf6f040455f", + "chain_id": "A", + "num_residues": 77, + "sequence_length": 77 + }, + "1WHI": { + "sha256": "73049e77254a1e4afc64032d75eb9da6db248eb83343e4102dc6837a7f882a7d", + "chain_id": "A", + "num_residues": 122, + "sequence_length": 122 + }, + "1XNB": { + "sha256": "cb1e11f6640c4aed0cdd5e5ef7724e1c4e3a5b3d6ee7ff0c084e7bcf8a0beb5e", + "chain_id": "A", + "num_residues": 185, + "sequence_length": 185 + }, + "2ACT": { + "sha256": "650139fdaa141b5b805f48c69b0431e3a07dd7f99779246ff7ce8accade45b35", + "chain_id": "A", + "num_residues": 217, + "sequence_length": 217 + }, + "2AIT": { + "sha256": "80d87df39d255c874e034d254237a49df39fdb5633855feff32c3406d35276cc", + "chain_id": "A", + "num_residues": 74, + "sequence_length": 74 + }, + "2AZA": { + "sha256": "284b1405b0f9032fa1bc0efe61ecf6842e35d24160d4eeb750c000aa7a243018", + "chain_id": "A", + "num_residues": 129, + "sequence_length": 129 + }, + "2CCY": { + "sha256": 
"d4d9f481a4ba4b2d5167b1c725b2a94673f3c1d54ea7f240340bc4c95e37e01a", + "chain_id": "A", + "num_residues": 127, + "sequence_length": 127 + }, + "2CDV": { + "sha256": "f4e14129ce13722542f4e53d79dd6ccb44c2c6a01a5fe8364c4c1f6e6aacc19f", + "chain_id": "A", + "num_residues": 107, + "sequence_length": 107 + }, + "2CMD": { + "sha256": "7f986ee1610d5b9a191af06699601e1e7110133d16f44b1d435fda9d59738ff5", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "2CTC": { + "sha256": "3beeef2681a45bff701a2841bf0f11e38b2944231aa08dd7c9c165f96df20d17", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "2END": { + "sha256": "d7437cb48150eca138e05e3c31e13a98220de9a197cca7e63601e16d6aecc54e", + "chain_id": "A", + "num_residues": 137, + "sequence_length": 137 + }, + "2FDN": { + "sha256": "f5c9d86e0a66f2579b572d1d4dd9de2cda65f57190850533d7e0a3814aa01875", + "chain_id": "A", + "num_residues": 55, + "sequence_length": 55 + }, + "2FCR": { + "sha256": "8268871fadab8dc5109ee75c90dd53c466c5d1de21e7e1b4a7c3d77f6a056cdc", + "chain_id": "A", + "num_residues": 173, + "sequence_length": 173 + }, + "2GBP": { + "sha256": "c9cd68018aa4480701bfd1f6240e65310d64b48d35f7cfea6aaf4ce3e7b030f2", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "2HMB": { + "sha256": "549c23e8c0ec6d830a865a25d5fc801fdea70fd071c8686ec400715020a07622", + "chain_id": "A", + "num_residues": 131, + "sequence_length": 131 + }, + "2LIV": { + "sha256": "d87320739a3f9439ebd625dffcf01ef816ffadd5d4f8de3f3d333bcb3ee48722", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "2MCM": { + "sha256": "e239c95e376ad6948e965bb4579ef9cc142ee2ed598075911a0ff3923aae2b78", + "chain_id": "A", + "num_residues": 112, + "sequence_length": 112 + }, + "2MHR": { + "sha256": "20b21394451b5d5145974b7eef17ef70ccdd5328e0583cb694155ea6d9706af1", + "chain_id": "A", + "num_residues": 118, + "sequence_length": 118 + }, + "2NLR": { + "sha256": 
"0563cea6fa76b4436d69936222447e6517f58b0201c0031e0490c8ff9d4902d4", + "chain_id": "A", + "num_residues": 222, + "sequence_length": 222 + }, + "2OHX": { + "sha256": "982373459270ff74905c9093e6e0110ee2ddc6b63f862b378b357c078f2ba37c", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "2PAB": { + "sha256": "be37980387a155ce1b94c84c103ea97cfddac1c56d2a18c36dea9a6d4ef56a7c", + "chain_id": "A", + "num_residues": 114, + "sequence_length": 114 + }, + "2PHY": { + "sha256": "54efd97c399de1531e42e63d69173a50538923028e05d02f9534575d9df2b3cb", + "chain_id": "A", + "num_residues": 125, + "sequence_length": 125 + }, + "2PLV": { + "sha256": "68beef5ef758bffa72e5183870954d87c1778dbc3122b8a70a15df85f47fcd59", + "chain_id": "1", + "num_residues": 288, + "sequence_length": 288 + }, + "2POR": { + "sha256": "dd6e8a757dbbe7865486e2d29f034ee7a6745cbbd043486bb3b4655115313484", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "2RBI": { + "sha256": "1ce9eef3dee319d9d6f171d97df1c96c1549513d2a7555f55abf80ccc89b3630", + "chain_id": "A", + "num_residues": 108, + "sequence_length": 108 + }, + "2RHE": { + "sha256": "8fc2ff50bdbdee76a8e77f96e673fd0fa7e707678b19dac99b1f616e1fbc2d19", + "chain_id": "A", + "num_residues": 114, + "sequence_length": 114 + }, + "2RSP": { + "sha256": "8cb202a43fbbfad8e663d801e11b87d53ee3ece93c607080b6338d2cd52ba6b9", + "chain_id": "A", + "num_residues": 115, + "sequence_length": 115 + }, + "2SAR": { + "sha256": "e5ce0c93e50596e04f057928b94c5946619364b50efd01c2e6f93ee502298d4c", + "chain_id": "A", + "num_residues": 96, + "sequence_length": 96 + }, + "2SAS": { + "sha256": "80ad5027d095876dc2ca07333f794a72548fcac9663c4689aae8f0180e64977c", + "chain_id": "A", + "num_residues": 185, + "sequence_length": 185 + }, + "2SNS": { + "sha256": "b962d4a4697f7efc79dfa5cb766c60c00bea03934dd6bb89117126d859cb791b", + "chain_id": "A", + "num_residues": 141, + "sequence_length": 141 + }, + "2SOD": { + "sha256": 
"6522ae7ad98a89e8530372b5e73063d0f60832fa32ddbf3d856c8ac1c76819ad", + "chain_id": "O", + "num_residues": 151, + "sequence_length": 151 + }, + "2SPC": { + "sha256": "9afb37cfc7ac9d958416425689c3ae634116e886bd9bf1689efbab94de5adad4", + "chain_id": "A", + "num_residues": 107, + "sequence_length": 107 + }, + "2TGI": { + "sha256": "1bf4c0308be73d594c9da7315dbdfbdd311d21aeeffc3d9f00c873db1ef70bfd", + "chain_id": "A", + "num_residues": 112, + "sequence_length": 112 + }, + "2TMA": { + "sha256": "bd6e988a1394de06a5a28487f039f3bf6728f17b34cf07fe444ccce55150e938", + "chain_id": "A", + "num_residues": 284, + "sequence_length": 284 + }, + "2TRX": { + "sha256": "c068dd829e81b10396cdf4eb13757aa0147f3645d4fea15556dacc19d6115485", + "chain_id": "A", + "num_residues": 108, + "sequence_length": 108 + }, + "3BLM": { + "sha256": "1afa52208746c3964756b6dbf37d1dbf6fbf4950077f0b705c1da94ff6919be3", + "chain_id": "A", + "num_residues": 257, + "sequence_length": 257 + }, + "3CHY": { + "sha256": "cc45affe443523897fa38068e31d88e21b37f13c493b0d0227e2d75a16e721b9", + "chain_id": "A", + "num_residues": 128, + "sequence_length": 128 + }, + "3CLA": { + "sha256": "6f7e191b0098a92c26da3ea26d9419d09ceafe15c8206663f7bcfeadceabd52c", + "chain_id": "A", + "num_residues": 213, + "sequence_length": 213 + }, + "3COX": { + "sha256": "c193f3452f870af407f89dbbdc8f49c71f1286fc5c413fbd402ea67830ba0be9", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "3DFR": { + "sha256": "c7bbf3fc5c6a95b8359232d4f7b14551520fac15affa779f9231b7ba8944e35e", + "chain_id": "A", + "num_residues": 162, + "sequence_length": 162 + }, + "3EBX": { + "sha256": "44568949895af9c2be2e0d1a2f9c7c6be26dd2218bf9b01c23180e822d09170e", + "chain_id": "A", + "num_residues": 62, + "sequence_length": 62 + }, + "3EST": { + "sha256": "24125292af7072189e3d56cfaf0ed3e4c62c9de8b8ff13f15474539d05be7d72", + "chain_id": "A", + "num_residues": 240, + "sequence_length": 240 + }, + "3GRS": { + "sha256": 
"58e59f31e7e6f800c78cefede5ccd990698541c11272171255b1712ccddbc2e7", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "3HTS": { + "sha256": "9e4bfadeff7919bfd55944b50830b2fc5f36e10b5c199b92fe0ca05928f4dcbd", + "chain_id": "B", + "num_residues": 82, + "sequence_length": 82 + }, + "3LYN": { + "sha256": "83be53a53f891eb9a881e5c92383117c15be03287045a7e1a1db496839a3e087", + "chain_id": "A", + "num_residues": 122, + "sequence_length": 122 + }, + "3MDS": { + "sha256": "3e5450e5cdeb553212caedf62aa231f499da51b3bdc35600efbebbd2f1cb259f", + "chain_id": "A", + "num_residues": 203, + "sequence_length": 203 + }, + "3PGK": { + "sha256": "8ef7ea7b5531afe7c59067bb0b3829b30872f399fcbb687970703336951c50f4", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "3PGM": { + "sha256": "f96bf5db04a2a9290d87572abd652c870b7826710849aa760a680b2d0a1e9144", + "chain_id": "A", + "num_residues": 230, + "sequence_length": 230 + }, + "3PRO": { + "sha256": "272d8cb627faec718dde3a6862e7b624ac046dfc5c9b5a1ec38948c162aef3bf", + "chain_id": "A", + "num_residues": 198, + "sequence_length": 198 + }, + "3SEB": { + "sha256": "01d620828663e2eb3bbb8e14ad14d1cf17503b123a9b16d3467950f52616a646", + "chain_id": "A", + "num_residues": 238, + "sequence_length": 238 + }, + "3SGB": { + "sha256": "61579665a2ba5eb732218bf33fd5176260b05132ccd75a762ca02f3bd60c25ae", + "chain_id": "E", + "num_residues": 185, + "sequence_length": 185 + }, + "3TGL": { + "sha256": "334cc39bd5f82e3f4c898a122ce6d673d0fc8d600246039c270581d17f6302e9", + "chain_id": "A", + "num_residues": 265, + "sequence_length": 265 + }, + "3TMS": { + "sha256": "f178bd9cb6b0266503b217f695cb2e55b07925d3daaf51c21a6ef6f61c047472", + "chain_id": "A", + "num_residues": 264, + "sequence_length": 264 + }, + "4BCL": { + "sha256": "8e2f02358340426ab0a019a5f416846a77a0e9293e24f414e72b50c3b6b64c13", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "4BLM": { + "sha256": 
"c1f0745811444e976b5b76bd30a4a049ff5632c83f936cf36c87a0c8c40b54e5", + "chain_id": "A", + "num_residues": 256, + "sequence_length": 256 + }, + "4CAM": { + "sha256": "2fc1dda18e01556c902e0ddefb8be47f8c1b9ab42fe19ead8cd334e59e11bb43", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "4CPA": { + "sha256": "f132396a9b7eba01011432c4bbc9b19cf8cefd9a7d0d4ed3967486d356c55208", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "4ENL": { + "sha256": "dcc642dedbd834bd8978c0a51a3088e685f2d596d15dd4d12a7256c044b4359f", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "4FGF": { + "sha256": "6f351d58dbf99fef876e6e2d6f663481510079461c7b0dab2aaa4204b5b4ba85", + "chain_id": "A", + "num_residues": 124, + "sequence_length": 124 + }, + "4GCR": { + "sha256": "64ed578776c9d837eef8716e72e09a1b8d613ce2e6d92feb0945ebaea465fe95", + "chain_id": "A", + "num_residues": 174, + "sequence_length": 174 + }, + "4HHB": { + "sha256": "6c977e3c48fcae60ef116dffff90ce2ae2dbc987f39c3f930d325361a1916812", + "chain_id": "A", + "num_residues": 141, + "sequence_length": 141 + }, + "4I1B": { + "sha256": "1ce778fd8d5bf97314d5c9db9f30a8b4bd836357309b03e2ee2e33a6ff57c635", + "chain_id": "A", + "num_residues": 151, + "sequence_length": 151 + }, + "4MDH": { + "sha256": "c1c86daefea7cfbda5ac5c2f563b8301b48e1a1f45c01d38e0bc44729975fab8", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "4PEP": { + "sha256": "78d09ed6269f6c59a9881f8af8c7e403f4506f0ca8e5a87120650cc9b20c9108", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "4PFK": { + "sha256": "02c89d9fad874855a9a501c5d20b71147be9969be745cf0d72c9cc27f688071f", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "5CPA": { + "sha256": "aae9d8720b41205a6aa8fa62ac15f65f6aeb95ffbeeb5298f6411d9705172891", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "5CYT": { + "sha256": 
"5f06660121b5792c8ca84722e4e697dc59459daef13315ccbcadb4dd2ac416fe", + "chain_id": "R", + "num_residues": 103, + "sequence_length": 103 + }, + "5LYZ": { + "sha256": "21e918fbc704f728de5fcc010b9ef82282c46d5967ee2f5dd194e71af9fdd245", + "chain_id": "A", + "num_residues": 129, + "sequence_length": 129 + }, + "5PTI": { + "sha256": "d39881b81af60445d9cb7ca59c6b4858a1efd989d5c404e99730c475b26fb55b", + "chain_id": "A", + "num_residues": 58, + "sequence_length": 58 + }, + "5RUB": { + "sha256": "8d3f3dbe1d9d7fd5406938383ec3b0a36c4cb799469485a1c7939c34683b97e4", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "5TIM": { + "sha256": "9cd03cc1a50c0903635c7182eeec0c2f83a711757d8e1b72b38414d1adcf586c", + "chain_id": "A", + "num_residues": 249, + "sequence_length": 249 + }, + "6ACN": { + "sha256": "d32459ba2616cf0b075cc97d40f9b716b77126882b61eb44aafc40e760336c93", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "6LDH": { + "sha256": "867ebfce0708705bf2aa2f33ec163e54a72eadb00b630bc6c9715c71f3213317", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "6TAA": { + "sha256": "e7e44313d3fdd0ed0b54f1d092bc3828123218ed2de97bb05f8cd0c04ccb97ee", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "7AAT": { + "sha256": "1b6d7b4916ad92ad19745c44f12a99f78e4d06a4778764f38d8345dfcd9b5574", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "7RSA": { + "sha256": "715cece1a426fbcf0d6b4e93a6810f700dd8b9455a2eb666be71aaddd21f2846", + "chain_id": "A", + "num_residues": 124, + "sequence_length": 124 + }, + "8ABP": { + "sha256": "b01a3ee74ccc9d37aa49541557828a5abe241277bbc94ea3add4b4daff241d21", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "8ADH": { + "sha256": "4c19677fd3dafe9d022560fc9341e86fe1b8f740dab97474c9ad8a30fcb1fb04", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "8ATC": { + "sha256": 
"99a62b081803171585b8495ebca53078a7947fff0d22dbf356246d78323cf1d6", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "8CAT": { + "sha256": "2cad7defe9aca4f63bcb2bbb43ee4197e7432e498f4b1493a3453fb19d2fa523", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "8DFR": { + "sha256": "68f429a3993d2533bff1c39af569ab91a28fa80334cacb18c1c0d8330ec32051", + "chain_id": "A", + "num_residues": 186, + "sequence_length": 186 + }, + "8TLN": { + "sha256": "d171b00a7aa7e5a2ddef1f8d2e59cf3da9418abf5ce04ca6ffe72f2c4395d1c1", + "chain_id": "E", + "num_residues": 300, + "sequence_length": 300 + }, + "9PAP": { + "sha256": "559172c930b0af8ba099733299905634f807d903a30bbfd758fe30aba1119850", + "chain_id": "A", + "num_residues": 211, + "sequence_length": 211 + }, + "1A2P": { + "sha256": "50b78c9c7e95c699e81610db9b140855762bbfb70a7dd821a0acef179d9e54d7", + "chain_id": "A", + "num_residues": 108, + "sequence_length": 108 + }, + "1A6M": { + "sha256": "ce6574d325b046f46803df49894524537d098dfeee1033e76e62133e941fd948", + "chain_id": "A", + "num_residues": 151, + "sequence_length": 151 + }, + "1A8D": { + "sha256": "76e862145f7c5b12905ca59e5883c5071368e835518506ec44d019a109ac0989", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1ADS": { + "sha256": "3c1e4bc84b8aab08749f5bc1d92d194d19f43ee3935f17fa471d81c828e1f2bb", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1AMP": { + "sha256": "72895144a61e07cca37953dfa16cfbfdef79f7dce44ea9b0f0b5472fc0da17ad", + "chain_id": "A", + "num_residues": 291, + "sequence_length": 291 + }, + "1AOZ": { + "sha256": "98827257d353a5e4a424c222d4f576225185e670014c1460a7cb3343d2f31761", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1ARB": { + "sha256": "68982c8f5542514d24079a7ac5c86de4c923b5c34796682a8a373b6d1e73e5d3", + "chain_id": "A", + "num_residues": 263, + "sequence_length": 263 + }, + "1AVH": { + "sha256": 
"b615b8450c413a18351e7fb8e1da296d3dd833412e6b2c0cd63ba4b436dd622b", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1B6C": { + "sha256": "179831a59b5f64876c98cae70d7d9db34bfe30d02d452d21502f568e2ae0f16e", + "chain_id": "A", + "num_residues": 107, + "sequence_length": 107 + }, + "1B8O": { + "sha256": "324c050ee767e77a7e71228ecccd352431b5787cd8db2efa788b1f5669e39204", + "chain_id": "A", + "num_residues": 280, + "sequence_length": 280 + }, + "1BCF": { + "sha256": "3bdb2a74848712617e912e80da92fa854bd61b1a09ba9c7b52964d99319c46ac", + "chain_id": "A", + "num_residues": 158, + "sequence_length": 158 + }, + "1BGF": { + "sha256": "f8366cd3103ddd924943019cf21708f024187852c7a76fad78e723f7fbfe4ce0", + "chain_id": "A", + "num_residues": 124, + "sequence_length": 124 + }, + "1BJ7": { + "sha256": "b7f42707c1f2e462120ba948b58baf740683520b03dcf47c5615ee8113ca8bea", + "chain_id": "A", + "num_residues": 150, + "sequence_length": 150 + }, + "1BKF": { + "sha256": "d6d0b8d27d16b012e9097cbb476b8f556309adb90e4a354e76278956056d72f7", + "chain_id": "A", + "num_residues": 107, + "sequence_length": 107 + }, + "1BM8": { + "sha256": "4ce229554c7ea27ca371cc8b2a6d4bf7a8fdf00f2a2a333274e4bba0f6c99e75", + "chain_id": "A", + "num_residues": 99, + "sequence_length": 99 + }, + "1BN6": { + "sha256": "f1843d1268d7da01319bad632ca1de330be2e3136bca1a79088ba6ff44592c9b", + "chain_id": "A", + "num_residues": 291, + "sequence_length": 291 + }, + "1BRS": { + "sha256": "1c98dcc3eac7ccb88657e3ab0f745121d7ddbd2fe11357c65dbaffdd5704f153", + "chain_id": "A", + "num_residues": 108, + "sequence_length": 108 + }, + "1BTL": { + "sha256": "932981534e5dd6abcdd02271935e055c74dff16486c991cd610c16c95b5a2b82", + "chain_id": "A", + "num_residues": 263, + "sequence_length": 263 + }, + "1BUO": { + "sha256": "edc39eb81e02bb9c6f34b27d274feb762b11c67cf0d3b369e9d07197574a4859", + "chain_id": "A", + "num_residues": 121, + "sequence_length": 121 + }, + "1BXO": { + "sha256": 
"213e2fa58d1db04fb1913623cbdd95e1b605e083986db9bc50b5ecd76ccf98d9", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1C1K": { + "sha256": "ff0bbde9638682f57dfa65feb1b28edfa42ec1fafda1a43de54e005cdfc16c93", + "chain_id": "A", + "num_residues": 217, + "sequence_length": 217 + }, + "1C52": { + "sha256": "0bd7845a19e139b8a4789c6eb29894970c381d1d8a81c9b8f8156816a796f238", + "chain_id": "A", + "num_residues": 131, + "sequence_length": 131 + }, + "1C75": { + "sha256": "7d24db3209948e3e4cb402d764320c4e551d4eeef5f5b6d05a641f07821a92b8", + "chain_id": "A", + "num_residues": 71, + "sequence_length": 71 + }, + "1C9O": { + "sha256": "960a42e2482b2081243646b260c391d0cfe4015da065dd214cec0e77e17e8a30", + "chain_id": "A", + "num_residues": 66, + "sequence_length": 66 + }, + "1CC8": { + "sha256": "cbcdf55e0accf725a1169b91567223c192d6212f2416b99ae185f9ead087f55e", + "chain_id": "A", + "num_residues": 72, + "sequence_length": 72 + }, + "1CEM": { + "sha256": "5aa0de834281b7f34ed4f01a4f6d7608852ced3b985937186de6cba298ff5e2a", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1CHN": { + "sha256": "cfd30d33db5bdfd05ce83bdb5859bb84a098d8c5204e8717ed2c0f8e2064f561", + "chain_id": "A", + "num_residues": 126, + "sequence_length": 126 + }, + "1CIU": { + "sha256": "0df4d846d01f0e135f1376b040d807a8fed486bb261a6f5ff132db2f08af8919", + "chain_id": "A", + "num_residues": 300, + "sequence_length": 300 + }, + "1CNV": { + "sha256": "3c29154c365ba25de2393888ada1267315b534d43041561fa75ca30689e004ab", + "chain_id": "A", + "num_residues": 283, + "sequence_length": 283 + }, + "1CQY": { + "sha256": "24081a1305ddd18db6c3ab64a9e4e8dd1e0b24bf67a86c64a6192164b6c9ec94", + "chain_id": "A", + "num_residues": 99, + "sequence_length": 99 + }, + "1CTJ": { + "sha256": "1613bc11939ef7495c4d750a63d39741fba01b7ff2ea6be930271117f5b57401", + "chain_id": "A", + "num_residues": 89, + "sequence_length": 89 + }, + "1CX1": { + "sha256": 
"6dc0858f9c94f01313741ef1d9415da6973336670b38b1c2180b0ddd9718d569", + "chain_id": "A", + "num_residues": 153, + "sequence_length": 153 + } +} diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_structures.parquet b/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_structures.parquet new file mode 100644 index 0000000000..a3dc2d14b6 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_minifold_te/data/pdb_structures.parquet differ diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_eval_dataset.py b/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_eval_dataset.py new file mode 100644 index 0000000000..b7d08c942c --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_eval_dataset.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download a high-quality evaluation dataset for ESM2-MiniFold TE. + +Queries RCSB PDB for very high resolution structures (≤1.5Å), excludes any +PDB IDs in the training set, downloads mmCIF files, and exports to parquet. 
+
+Usage:
+    python data/prepare_eval_dataset.py
+    python data/prepare_eval_dataset.py --max-structures 50  # quick test
+"""
+
+import argparse
+import hashlib
+import json
+import logging
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from urllib.error import HTTPError, URLError
+from urllib.request import Request, urlopen, urlretrieve
+
+import numpy as np
+import pandas as pd
+from Bio.PDB.MMCIFParser import MMCIFParser
+
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+RCSB_SEARCH_URL = "https://search.rcsb.org/rcsbsearch/v2/query"
+RCSB_DOWNLOAD_URL = "https://files.rcsb.org/download/{pdb_id}.cif"
+
+SCRIPT_DIR = Path(__file__).resolve().parent
+
+AA_3TO1 = {
+    "ALA": "A",
+    "CYS": "C",
+    "ASP": "D",
+    "GLU": "E",
+    "PHE": "F",
+    "GLY": "G",
+    "HIS": "H",
+    "ILE": "I",
+    "LYS": "K",
+    "LEU": "L",
+    "MET": "M",
+    "ASN": "N",
+    "PRO": "P",
+    "GLN": "Q",
+    "ARG": "R",
+    "SER": "S",
+    "THR": "T",
+    "VAL": "V",
+    "TRP": "W",
+    "TYR": "Y",
+    "MSE": "M",
+}
+
+MAX_RETRIES = 3
+MIN_CA_COMPLETENESS = 0.95  # Stricter than training (0.9) for eval quality
+
+
+def read_training_pdb_ids(path=None):
+    """Read training PDB IDs to exclude from eval set."""
+    if path is None:
+        path = SCRIPT_DIR / "pdb_ids.txt"
+    if not path.exists():
+        logger.warning("Training PDB IDs file not found: %s", path)
+        return set()
+
+    ids = set()
+    with open(path) as f:
+        for line in f:
+            line = line.strip()
+            if line and not line.startswith("#"):
+                ids.add(line.upper())  # PDB IDs are case-insensitive; normalize for set lookups
+    return ids
+
+
+def query_rcsb(max_resolution=1.5, min_length=50, max_length=300, max_results=500):
+    """Query RCSB Search API for high-resolution protein structures."""
+    query = {
+        "query": {
+            "type": "group",
+            "logical_operator": "and",
+            "nodes": [
+                {
+                    "type": "terminal",
+                    "service": "text",
+                    "parameters": {
+                        "attribute": "exptl.method",
+                        "operator": "exact_match",
+                        "value": "X-RAY DIFFRACTION",
+                    },
+                },
+                {
+                    "type": "terminal",
+                    "service": "text",
+                    "parameters": {
+                        "attribute": "rcsb_entry_info.resolution_combined",
+                        "operator": "less_or_equal",
+                        "value": max_resolution,
+                    },
+                },
+                {
+                    "type": "terminal",
+                    "service": "text",
+                    "parameters": {
+                        "attribute": "entity_poly.rcsb_entity_polymer_type",
+                        "operator": "exact_match",
+                        "value": "Protein",
+                    },
+                },
+                {
+                    "type": "terminal",
+                    "service": "text",
+                    "parameters": {
+                        "attribute": "entity_poly.rcsb_sample_sequence_length",
+                        "operator": "range",
+                        "value": {"from": min_length, "to": max_length},
+                    },
+                },
+            ],
+        },
+        "return_type": "entry",
+        "request_options": {
+            "paginate": {"start": 0, "rows": max_results},
+            "results_content_type": ["experimental"],
+            "sort": [{"sort_by": "rcsb_entry_info.resolution_combined", "direction": "asc"}],
+        },
+    }
+
+    logger.info(
+        "Querying RCSB for eval set: resolution<=%.1fA, length %d-%d, max %d",
+        max_resolution,
+        min_length,
+        max_length,
+        max_results,
+    )
+
+    req = Request(
+        RCSB_SEARCH_URL,
+        data=json.dumps(query).encode("utf-8"),
+        headers={"Content-Type": "application/json"},
+    )
+
+    with urlopen(req, timeout=60) as resp:
+        data = json.loads(resp.read().decode("utf-8"))
+
+    total_count = data.get("total_count", 0)
+    pdb_ids = [r["identifier"] for r in data.get("result_set", [])]
+    logger.info("RCSB returned %d results (total available: %d)", len(pdb_ids), total_count)
+    return pdb_ids
+
+
+def download_cif(pdb_id, output_dir):
+    """Download a single mmCIF file with retry. Returns (pdb_id, path, success)."""
+    url = RCSB_DOWNLOAD_URL.format(pdb_id=pdb_id)
+    output_path = output_dir / f"{pdb_id}.cif"
+
+    if output_path.exists() and output_path.stat().st_size > 0:
+        return pdb_id, output_path, True
+
+    for attempt in range(MAX_RETRIES):
+        try:
+            urlretrieve(url, output_path)
+            return pdb_id, output_path, True
+        except (HTTPError, URLError, TimeoutError, OSError):
+            output_path.unlink(missing_ok=True)  # drop partial file so the size check above can't treat it as cached
+            if attempt < MAX_RETRIES - 1:
+                time.sleep(2**attempt)  # exponential backoff: 1s, 2s
+
+    return pdb_id, None, False
+
+
+def compute_sha256(file_path):
+    """Compute SHA256 checksum for a file."""
+    digest = hashlib.sha256()
+    with open(file_path, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            digest.update(chunk)
+    return digest.hexdigest()
+
+
+def parse_mmcif(cif_path, pdb_id, min_residues=50, max_residues=300):
+    """Parse a mmCIF file and extract sequence + Ca coordinates."""
+    parser = MMCIFParser(QUIET=True)
+    try:
+        structure = parser.get_structure(pdb_id, str(cif_path))
+    except Exception as e:
+        logger.debug("Failed to parse %s: %s", pdb_id, e)
+        return None
+
+    model = structure[0]
+
+    for chain in model:
+        residues = []
+        for res in chain.get_residues():
+            if res.id[0] != " ":  # skip heteroatoms and water
+                continue
+            resname = res.get_resname().strip()
+            if resname not in AA_3TO1:
+                continue
+            residues.append(res)
+
+        if len(residues) < min_residues:
+            continue
+        if len(residues) > max_residues:
+            residues = residues[:max_residues]
+
+        sequence = []
+        coords = []
+        ca_mask = []
+
+        for res in residues:
+            resname = res.get_resname().strip()
+            sequence.append(AA_3TO1[resname])
+
+            if "CA" in res:
+                ca = res["CA"].get_vector()
+                coords.append([float(ca[0]), float(ca[1]), float(ca[2])])
+                ca_mask.append(1)
+            else:
+                coords.append([0.0, 0.0, 0.0])  # placeholder; masked out via ca_mask
+                ca_mask.append(0)
+
+        completeness = sum(ca_mask) / len(ca_mask)
+        if completeness < MIN_CA_COMPLETENESS:
+            continue
+
+        coords_arr = np.array(coords)
+        if not np.all(np.isfinite(coords_arr)):  # reject NaN/inf coordinates
+            continue
+
+        return {
+            "pdb_id": pdb_id,
+            "chain_id": chain.id,
+            "sequence": "".join(sequence),
+            "coords": coords,
+            "ca_mask": ca_mask,
+            "num_residues": len(residues),
+        }
+
+    return None
+
+
+def main():
+    ap = argparse.ArgumentParser(description="Prepare eval dataset for ESM2-MiniFold TE")
+    ap.add_argument("--max-structures", type=int, default=150, help="Target number of eval structures")
+    ap.add_argument("--max-resolution", type=float, default=1.5, help="Max X-ray resolution (Angstroms)")
+    ap.add_argument("--min-length", type=int, default=50, help="Min polymer entity length")
+    ap.add_argument("--max-length", type=int, default=300, help="Max polymer entity length")
+    ap.add_argument("--output-dir", type=str, default=None, help="Output directory (default: data/)")
+    ap.add_argument("--download-workers", type=int, default=8, help="Parallel download threads")
+    ap.add_argument("--training-ids-file", type=str, default=None, help="Training PDB IDs to exclude")
+    args = ap.parse_args()
+
+    output_dir = Path(args.output_dir) if args.output_dir else SCRIPT_DIR
+    cif_dir = output_dir / "eval_cif_files"
+    cif_dir.mkdir(parents=True, exist_ok=True)
+
+    # Step 1: Load training PDB IDs to exclude
+    train_ids_path = Path(args.training_ids_file) if args.training_ids_file else None
+    train_ids = read_training_pdb_ids(train_ids_path)
+    logger.info("Excluding %d training PDB IDs from eval set", len(train_ids))
+
+    # Step 2: Query RCSB — request extra to account for dedup + failures
+    query_count = args.max_structures + len(train_ids) + 200
+    pdb_ids = query_rcsb(
+        max_resolution=args.max_resolution,
+        min_length=args.min_length,
+        max_length=args.max_length,
+        max_results=query_count,
+    )
+
+    # Deduplicate against training set (measure before/after; RCSB may return fewer than query_count)
+    n_candidates = len(pdb_ids)
+    pdb_ids = [pid for pid in pdb_ids if pid.upper() not in train_ids]
+    removed = n_candidates - len(pdb_ids)
+    logger.info("After dedup: %d candidates (removed %d training overlaps)", len(pdb_ids), removed)
+    pdb_ids = pdb_ids[: args.max_structures + 100]  # Keep some buffer for parse failures
+
+    # Step 3: Download in parallel
+    logger.info("Downloading %d CIF files with %d workers...", len(pdb_ids), args.download_workers)
+    downloaded = {}
+    failed_download = []
+
+    with ThreadPoolExecutor(max_workers=args.download_workers) as pool:
+        futures = {pool.submit(download_cif, pid, cif_dir): pid for pid in pdb_ids}
+        done_count = 0
+        for future in as_completed(futures):
+            pdb_id, path, success = future.result()
+            done_count += 1
+            if success:
+                downloaded[pdb_id] = path
+            else:
+                failed_download.append(pdb_id)
+            if done_count % 100 == 0:
+                logger.info("Download progress: %d/%d done, %d succeeded", done_count, len(pdb_ids), len(downloaded))
+
+    logger.info("Download complete: %d succeeded, %d failed", len(downloaded), len(failed_download))
+
+    # Step 4: Parse structures
+    logger.info("Parsing %d CIF files...", len(downloaded))
+    records = []
+    manifest = {}
+    failed_parse = []
+
+    for i, (pdb_id, cif_path) in enumerate(downloaded.items()):
+        if len(records) >= args.max_structures:
+            break
+
+        record = parse_mmcif(cif_path, pdb_id, min_residues=args.min_length, max_residues=args.max_length)
+        if record is None:
+            failed_parse.append(pdb_id)
+        else:
+            records.append(record)
+            manifest[pdb_id] = {
+                "sha256": compute_sha256(cif_path),
+                "chain_id": record["chain_id"],
+                "num_residues": record["num_residues"],
+                "sequence_length": len(record["sequence"]),
+            }
+
+        if (i + 1) % 100 == 0:
+            logger.info("Parse progress: %d/%d processed, %d valid", i + 1, len(downloaded), len(records))
+
+    # Step 5: Write outputs
+    output_parquet = output_dir / "eval_structures.parquet"
+    output_manifest = output_dir / "eval_manifest.json"
+    output_ids = output_dir / "eval_pdb_ids.txt"
+
+    df = pd.DataFrame(records)
+    df.to_parquet(str(output_parquet), index=False)
+    logger.info("Wrote %d eval structures to %s", len(df), output_parquet)
+
+    with open(output_manifest, "w") as f:
+        json.dump(manifest, f, indent=2)
+
+    with open(output_ids, "w") as f:
+        f.write("# Eval PDB IDs for ESM2-MiniFold TE\n")
+        f.write(f"# resolution<={args.max_resolution}A, Ca completeness>={MIN_CA_COMPLETENESS}\n")
+        f.write(f"# Total: {len(records)} structures (excluded {len(train_ids)} training IDs)\n")
+        for r in records:
+            f.write(r["pdb_id"] + "\n")
+
+    # Summary
+    if records:
+        lengths = [r["num_residues"] for r in records]
+        logger.info("=== Eval Dataset Summary ===")
+        logger.info("Valid structures: %d", len(records))
+        logger.info("Failed download: %d, failed parse: %d", len(failed_download), len(failed_parse))
+        logger.info(
+            "Residue lengths: min=%d, max=%d, mean=%.0f, median=%.0f",
+            min(lengths),
+            max(lengths),
+            np.mean(lengths),
+            np.median(lengths),
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_pdb_dataset.py b/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_pdb_dataset.py
new file mode 100644
index 0000000000..5b7c875283
--- /dev/null
+++ b/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_pdb_dataset.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Download PDB structures and prepare parquet dataset for ESM2-MiniFold TE training.
+
+Downloads mmCIF files from RCSB PDB, extracts Ca coordinates and sequences,
+validates structures, and exports to parquet format.
+ +Usage: + python data/prepare_pdb_dataset.py + python data/prepare_pdb_dataset.py --max-structures 50 # limit for testing +""" + +import hashlib +import json +import logging +import time +from pathlib import Path +from urllib.error import HTTPError, URLError +from urllib.request import urlretrieve + +import numpy as np +import pandas as pd +from Bio.PDB.MMCIFParser import MMCIFParser + + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +logger = logging.getLogger(__name__) + +RCSB_URL = "https://files.rcsb.org/download/{pdb_id}.cif" +SCRIPT_DIR = Path(__file__).resolve().parent +PDB_IDS_FILE = SCRIPT_DIR / "pdb_ids.txt" +OUTPUT_PARQUET = SCRIPT_DIR / "pdb_structures.parquet" +OUTPUT_CIF_DIR = SCRIPT_DIR / "cif_files" +OUTPUT_MANIFEST = SCRIPT_DIR / "pdb_manifest.json" + +# 3-letter to 1-letter amino acid mapping +AA_3TO1 = { + "ALA": "A", + "CYS": "C", + "ASP": "D", + "GLU": "E", + "PHE": "F", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LYS": "K", + "LEU": "L", + "MET": "M", + "ASN": "N", + "PRO": "P", + "GLN": "Q", + "ARG": "R", + "SER": "S", + "THR": "T", + "VAL": "V", + "TRP": "W", + "TYR": "Y", + "MSE": "M", # Selenomethionine -> Methionine +} + +MIN_RESIDUES = 50 +MAX_RESIDUES = 300 +MIN_CA_COMPLETENESS = 0.9 +MAX_RETRIES = 3 +DOWNLOAD_DELAY = 0.5 # seconds between RCSB requests + + +def compute_sha256(file_path): + """Compute SHA256 checksum for a file.""" + digest = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + digest.update(chunk) + return digest.hexdigest() + + +def read_pdb_ids(path): + """Read PDB IDs from a text file (one per line, # comments allowed).""" + ids = [] + with open(path) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + ids.append(line.upper()) + return ids + + +def download_cif(pdb_id, output_dir, timeout=30): + """Download a mmCIF file from RCSB PDB with retry logic. 
+ + Returns path to downloaded file, or None on failure. + """ + url = RCSB_URL.format(pdb_id=pdb_id) + output_path = output_dir / f"{pdb_id}.cif" + + if output_path.exists(): + return output_path + + for attempt in range(MAX_RETRIES): + try: + urlretrieve(url, output_path) + return output_path + except (HTTPError, URLError, TimeoutError) as e: + wait = 2**attempt + logger.warning("Download %s attempt %d failed: %s (retry in %ds)", pdb_id, attempt + 1, e, wait) + time.sleep(wait) + + logger.error("Failed to download %s after %d attempts", pdb_id, MAX_RETRIES) + return None + + +def parse_mmcif(cif_path, pdb_id): + """Parse a mmCIF file and extract sequence + Ca coordinates. + + Returns dict with keys: pdb_id, chain_id, sequence, coords, ca_mask, num_residues. + Returns None if parsing fails or structure is invalid. + """ + parser = MMCIFParser(QUIET=True) + try: + structure = parser.get_structure(pdb_id, str(cif_path)) + except Exception as e: + logger.warning("Failed to parse %s: %s", pdb_id, e) + return None + + model = structure[0] + + # Find first protein chain with enough residues + for chain in model: + residues = [] + for res in chain.get_residues(): + # Skip heteroatoms and water + if res.id[0] != " ": + continue + resname = res.get_resname().strip() + if resname not in AA_3TO1: + continue + residues.append(res) + + if len(residues) < MIN_RESIDUES: + continue + if len(residues) > MAX_RESIDUES: + residues = residues[:MAX_RESIDUES] + + # Extract sequence and Ca coordinates + sequence = [] + coords = [] + ca_mask = [] + + for res in residues: + resname = res.get_resname().strip() + sequence.append(AA_3TO1[resname]) + + if "CA" in res: + ca = res["CA"].get_vector() + coords.append([float(ca[0]), float(ca[1]), float(ca[2])]) + ca_mask.append(1) + else: + coords.append([0.0, 0.0, 0.0]) + ca_mask.append(0) + + # Validate Ca completeness + completeness = sum(ca_mask) / len(ca_mask) + if completeness < MIN_CA_COMPLETENESS: + logger.warning( + "Skipping %s chain %s: 
Ca completeness %.1f%% < %.0f%%", + pdb_id, + chain.id, + completeness * 100, + MIN_CA_COMPLETENESS * 100, + ) + continue + + # Validate coordinates are finite + coords_arr = np.array(coords) + if not np.all(np.isfinite(coords_arr)): + logger.warning("Skipping %s chain %s: non-finite coordinates", pdb_id, chain.id) + continue + + return { + "pdb_id": pdb_id, + "chain_id": chain.id, + "sequence": "".join(sequence), + "coords": coords, + "ca_mask": ca_mask, + "num_residues": len(residues), + } + + logger.warning("No valid chain found in %s", pdb_id) + return None + + +def main(): + """Download PDB structures and create parquet dataset.""" + import argparse + + parser = argparse.ArgumentParser(description="Prepare PDB dataset for ESM2-MiniFold TE") + parser.add_argument("--max-structures", type=int, default=None, help="Limit number of structures (for testing)") + args = parser.parse_args() + + pdb_ids = read_pdb_ids(PDB_IDS_FILE) + if args.max_structures: + pdb_ids = pdb_ids[: args.max_structures] + + logger.info("Processing %d PDB IDs", len(pdb_ids)) + + OUTPUT_CIF_DIR.mkdir(parents=True, exist_ok=True) + + records = [] + manifest = {} + failed = [] + + for i, pdb_id in enumerate(pdb_ids): + if i > 0: + time.sleep(DOWNLOAD_DELAY) + + # Download + cif_path = download_cif(pdb_id, OUTPUT_CIF_DIR) + if cif_path is None: + failed.append(pdb_id) + continue + + sha256 = compute_sha256(cif_path) + + # Parse + record = parse_mmcif(cif_path, pdb_id) + if record is None: + failed.append(pdb_id) + continue + + records.append(record) + manifest[pdb_id] = { + "sha256": sha256, + "chain_id": record["chain_id"], + "num_residues": record["num_residues"], + "sequence_length": len(record["sequence"]), + } + + if (i + 1) % 20 == 0: + logger.info( + "Progress: %d/%d downloaded, %d valid, %d failed", i + 1, len(pdb_ids), len(records), len(failed) + ) + + # Write parquet + df = pd.DataFrame(records) + df.to_parquet(OUTPUT_PARQUET, index=False) + logger.info("Wrote %d structures to %s", 
len(df), OUTPUT_PARQUET) + + # Write manifest + with open(OUTPUT_MANIFEST, "w") as f: + json.dump(manifest, f, indent=2) + logger.info("Wrote manifest to %s", OUTPUT_MANIFEST) + + # Summary + if records: + lengths = [r["num_residues"] for r in records] + logger.info("Summary: %d valid structures, %d failed", len(records), len(failed)) + logger.info("Residue lengths: min=%d, max=%d, mean=%.0f", min(lengths), max(lengths), np.mean(lengths)) + + if failed: + logger.warning("Failed PDB IDs: %s", ", ".join(failed)) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_pdb_dataset_large.py b/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_pdb_dataset_large.py new file mode 100644 index 0000000000..c44bf73470 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/data/prepare_pdb_dataset_large.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download a large PDB dataset via RCSB Search API for ESM2-MiniFold TE training. + +Queries RCSB PDB for high-quality X-ray structures, downloads mmCIF files in +parallel, parses Ca coordinates, and exports to parquet. Designed for cluster use. 
+ +Usage: + # Default: ~10k structures, 8 download workers + python data/prepare_pdb_dataset_large.py + + # Smaller test run + python data/prepare_pdb_dataset_large.py --max-structures 500 + + # Custom output directory (e.g., on a fast scratch disk) + python data/prepare_pdb_dataset_large.py --output-dir /scratch/$USER/pdb_data + + # Custom resolution and length filters + python data/prepare_pdb_dataset_large.py --max-resolution 2.0 --min-length 80 --max-length 250 +""" + +import argparse +import hashlib +import json +import logging +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen, urlretrieve + +import numpy as np +import pandas as pd +from Bio.PDB.MMCIFParser import MMCIFParser + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s") +logger = logging.getLogger(__name__) + +RCSB_SEARCH_URL = "https://search.rcsb.org/rcsbsearch/v2/query" +RCSB_DOWNLOAD_URL = "https://files.rcsb.org/download/{pdb_id}.cif" + +AA_3TO1 = { + "ALA": "A", + "CYS": "C", + "ASP": "D", + "GLU": "E", + "PHE": "F", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LYS": "K", + "LEU": "L", + "MET": "M", + "ASN": "N", + "PRO": "P", + "GLN": "Q", + "ARG": "R", + "SER": "S", + "THR": "T", + "VAL": "V", + "TRP": "W", + "TYR": "Y", + "MSE": "M", +} + +MAX_RETRIES = 3 +MIN_CA_COMPLETENESS = 0.9 + + +def query_rcsb(max_resolution=2.5, min_length=50, max_length=300, max_results=15000): + """Query RCSB Search API for high-quality protein structures. + + Filters: + - X-ray diffraction only + - Resolution better than max_resolution + - Polymer entity length between min_length and max_length + - Protein entity type + + Paginates automatically (RCSB caps at 10,000 rows per request). + + Returns list of PDB IDs (4-letter codes, uppercase). 
+ """ + PAGE_SIZE = 10000 # RCSB maximum rows per request + + base_query = { + "query": { + "type": "group", + "logical_operator": "and", + "nodes": [ + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "exptl.method", + "operator": "exact_match", + "value": "X-RAY DIFFRACTION", + }, + }, + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "rcsb_entry_info.resolution_combined", + "operator": "less_or_equal", + "value": max_resolution, + }, + }, + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "entity_poly.rcsb_entity_polymer_type", + "operator": "exact_match", + "value": "Protein", + }, + }, + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "entity_poly.rcsb_sample_sequence_length", + "operator": "range", + "value": {"from": min_length, "to": max_length}, + }, + }, + ], + }, + "return_type": "entry", + "request_options": { + "results_content_type": ["experimental"], + "sort": [{"sort_by": "rcsb_entry_info.resolution_combined", "direction": "asc"}], + }, + } + + logger.info( + "Querying RCSB: resolution<=%.1fA, length %d-%d, max %d results", + max_resolution, + min_length, + max_length, + max_results, + ) + + pdb_ids = [] + start = 0 + total_count = None + + while len(pdb_ids) < max_results: + rows = min(PAGE_SIZE, max_results - len(pdb_ids)) + query = { + **base_query, + "request_options": { + **base_query["request_options"], + "paginate": {"start": start, "rows": rows}, + }, + } + + req = Request( + RCSB_SEARCH_URL, + data=json.dumps(query).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + + with urlopen(req, timeout=60) as resp: + data = json.loads(resp.read().decode("utf-8")) + + if total_count is None: + total_count = data.get("total_count", 0) + + page_ids = [r["identifier"] for r in data.get("result_set", [])] + if not page_ids: + break + + pdb_ids.extend(page_ids) + start += len(page_ids) + logger.info( + "RCSB page: 
fetched %d (total so far: %d, available: %d)", len(page_ids), len(pdb_ids), total_count + ) + + if start >= total_count: + break + + logger.info("RCSB query complete: %d results (total available: %d)", len(pdb_ids), total_count) + return pdb_ids + + +def download_cif(pdb_id, output_dir, timeout=30): + """Download a single mmCIF file with retry. Returns (pdb_id, path, success).""" + url = RCSB_DOWNLOAD_URL.format(pdb_id=pdb_id) + output_path = output_dir / f"{pdb_id}.cif" + + if output_path.exists() and output_path.stat().st_size > 0: + return pdb_id, output_path, True + + for attempt in range(MAX_RETRIES): + try: + urlretrieve(url, output_path) + return pdb_id, output_path, True + except (HTTPError, URLError, TimeoutError, OSError) as e: + wait = 2**attempt + if attempt < MAX_RETRIES - 1: + time.sleep(wait) + + return pdb_id, None, False + + +def compute_sha256(file_path): + """Compute SHA256 checksum for a file.""" + digest = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + digest.update(chunk) + return digest.hexdigest() + + +def parse_mmcif(cif_path, pdb_id, min_residues=50, max_residues=300): + """Parse a mmCIF file and extract sequence + Ca coordinates. + + Returns dict or None if parsing fails or structure is invalid. 
+ """ + parser = MMCIFParser(QUIET=True) + try: + structure = parser.get_structure(pdb_id, str(cif_path)) + except Exception as e: + logger.debug("Failed to parse %s: %s", pdb_id, e) + return None + + model = structure[0] + + for chain in model: + residues = [] + for res in chain.get_residues(): + if res.id[0] != " ": + continue + resname = res.get_resname().strip() + if resname not in AA_3TO1: + continue + residues.append(res) + + if len(residues) < min_residues: + continue + if len(residues) > max_residues: + residues = residues[:max_residues] + + sequence = [] + coords = [] + ca_mask = [] + + for res in residues: + resname = res.get_resname().strip() + sequence.append(AA_3TO1[resname]) + + if "CA" in res: + ca = res["CA"].get_vector() + coords.append([float(ca[0]), float(ca[1]), float(ca[2])]) + ca_mask.append(1) + else: + coords.append([0.0, 0.0, 0.0]) + ca_mask.append(0) + + completeness = sum(ca_mask) / len(ca_mask) + if completeness < MIN_CA_COMPLETENESS: + continue + + coords_arr = np.array(coords) + if not np.all(np.isfinite(coords_arr)): + continue + + return { + "pdb_id": pdb_id, + "chain_id": chain.id, + "sequence": "".join(sequence), + "coords": coords, + "ca_mask": ca_mask, + "num_residues": len(residues), + } + + return None + + +def main(): + parser = argparse.ArgumentParser(description="Prepare large PDB dataset for ESM2-MiniFold TE") + parser.add_argument("--max-structures", type=int, default=10000, help="Maximum structures to download") + parser.add_argument("--max-resolution", type=float, default=2.5, help="Max X-ray resolution in Angstroms") + parser.add_argument("--min-length", type=int, default=50, help="Min polymer entity length") + parser.add_argument("--max-length", type=int, default=300, help="Max polymer entity length") + parser.add_argument("--output-dir", type=str, default=None, help="Output directory (default: data/)") + parser.add_argument("--download-workers", type=int, default=8, help="Parallel download threads") + 
parser.add_argument("--parse-workers", type=int, default=4, help="Parallel parse threads") + args = parser.parse_args() + + script_dir = Path(__file__).resolve().parent + output_dir = Path(args.output_dir) if args.output_dir else script_dir + cif_dir = output_dir / "cif_files" + cif_dir.mkdir(parents=True, exist_ok=True) + + # Step 1: Query RCSB for PDB IDs + pdb_ids = query_rcsb( + max_resolution=args.max_resolution, + min_length=args.min_length, + max_length=args.max_length, + max_results=args.max_structures + 2000, # Query extra to account for failures + ) + pdb_ids = pdb_ids[: args.max_structures] + logger.info("Will process %d PDB IDs", len(pdb_ids)) + + # Step 2: Download in parallel + logger.info("Downloading CIF files with %d workers...", args.download_workers) + downloaded = {} + failed_download = [] + + with ThreadPoolExecutor(max_workers=args.download_workers) as pool: + futures = {pool.submit(download_cif, pid, cif_dir): pid for pid in pdb_ids} + done_count = 0 + for future in as_completed(futures): + pdb_id, path, success = future.result() + done_count += 1 + if success: + downloaded[pdb_id] = path + else: + failed_download.append(pdb_id) + if done_count % 500 == 0: + logger.info( + "Download progress: %d/%d done, %d succeeded, %d failed", + done_count, + len(pdb_ids), + len(downloaded), + len(failed_download), + ) + + logger.info("Download complete: %d succeeded, %d failed", len(downloaded), len(failed_download)) + + # Step 3: Parse structures + logger.info("Parsing %d CIF files...", len(downloaded)) + records = [] + manifest = {} + failed_parse = [] + + # Parse sequentially (BioPython's C extensions may not be thread-safe) + for i, (pdb_id, cif_path) in enumerate(downloaded.items()): + record = parse_mmcif(cif_path, pdb_id, min_residues=args.min_length, max_residues=args.max_length) + if record is None: + failed_parse.append(pdb_id) + else: + records.append(record) + manifest[pdb_id] = { + "sha256": compute_sha256(cif_path), + "chain_id": 
record["chain_id"], + "num_residues": record["num_residues"], + "sequence_length": len(record["sequence"]), + } + + if (i + 1) % 500 == 0: + logger.info( + "Parse progress: %d/%d processed, %d valid, %d failed", + i + 1, + len(downloaded), + len(records), + len(failed_parse), + ) + + # Step 4: Write outputs + output_parquet = output_dir / "pdb_structures.parquet" + output_manifest = output_dir / "pdb_manifest.json" + output_ids = output_dir / "pdb_ids.txt" + + df = pd.DataFrame(records) + df.to_parquet(str(output_parquet), index=False) + logger.info("Wrote %d structures to %s", len(df), output_parquet) + + with open(output_manifest, "w") as f: + json.dump(manifest, f, indent=2) + logger.info("Wrote manifest to %s", output_manifest) + + # Write PDB IDs file for reproducibility + with open(output_ids, "w") as f: + f.write("# PDB IDs for ESM2-MiniFold TE training\n") + f.write(f"# Generated: resolution<={args.max_resolution}A, length {args.min_length}-{args.max_length}\n") + f.write(f"# Total: {len(records)} valid structures\n") + for r in records: + f.write(r["pdb_id"] + "\n") + logger.info("Wrote PDB IDs to %s", output_ids) + + # Summary + if records: + lengths = [r["num_residues"] for r in records] + logger.info("=== Summary ===") + logger.info("Valid structures: %d", len(records)) + logger.info("Failed download: %d", len(failed_download)) + logger.info("Failed parse: %d", len(failed_parse)) + logger.info( + "Residue lengths: min=%d, max=%d, mean=%.0f, median=%.0f", + min(lengths), + max(lengths), + np.mean(lengths), + np.median(lengths), + ) + logger.info("Output: %s (%.1f MB)", output_parquet, output_parquet.stat().st_size / 1e6) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_minifold_te/dataset.py b/bionemo-recipes/recipes/esm2_minifold_te/dataset.py new file mode 100644 index 0000000000..b72305bb8e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/dataset.py @@ -0,0 +1,386 @@ +# SPDX-FileCopyrightText: Copyright 
# (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset for protein structure prediction training.

Provides:
- SyntheticStructureDataset: generates random data for testing
- ParquetStructureDataset: loads from parquet (pre-processed Ca coords)
- MmcifStructureDataset: loads from mmCIF files on-the-fly via BioPython
- create_dataloader: factory function for any dataset type
"""

import logging
from pathlib import Path
from typing import ClassVar

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from distributed_config import DistributedConfig


logger = logging.getLogger(__name__)


class SyntheticStructureDataset(Dataset):
    """Generates synthetic protein structure data for testing.

    Each sample contains:
        input_ids: ESM-2 token IDs (L,): [CLS] + random AA tokens + [EOS] + [PAD]...
        attention_mask: 1 for CLS/AA/EOS positions, 0 for padding (L,)
        mask: float copy of attention_mask (L,); note it covers the CLS/EOS
            positions too, not residues only
        coords: random Ca coordinates in Angstroms (L, 3)

    All randomness comes from a dedicated generator seeded in __init__, so the
    dataset is reproducible for a given seed.
    """

    def __init__(self, num_samples: int = 1000, max_seq_length: int = 128, seed: int = 42):
        self.num_samples = num_samples
        self.max_seq_length = max_seq_length
        self.rng = torch.Generator().manual_seed(seed)

        # ESM-2 special tokens: 0=cls, 1=pad, 2=eos, 3=unk, then 4-23 are the
        # 20 amino acid tokens
        self.vocab_start = 4
        self.vocab_end = 24

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random sequence length (at least 16 residues, leaving room for CLS/EOS)
        seq_len = torch.randint(16, self.max_seq_length - 2, (1,), generator=self.rng).item()

        # Generate random tokens: [CLS] + AA tokens + [EOS] + padding.
        # Fix: padding positions previously kept value 0, which is the CLS id
        # per the vocab comment above; use the true pad id (1) so synthetic
        # samples match real tokenizer output. Padding is masked out either way.
        tokens = torch.randint(self.vocab_start, self.vocab_end, (seq_len,), generator=self.rng)
        input_ids = torch.full((self.max_seq_length,), 1, dtype=torch.long)  # 1 = <pad>
        input_ids[0] = 0  # <cls>
        input_ids[1 : seq_len + 1] = tokens
        input_ids[seq_len + 1] = 2  # <eos>

        # Attention mask: 1 for real tokens (CLS + AAs + EOS), 0 for padding
        attention_mask = torch.zeros(self.max_seq_length, dtype=torch.long)
        attention_mask[: seq_len + 2] = 1

        # Residue mask: same span as attention_mask, as float (includes CLS/EOS)
        mask = torch.zeros(self.max_seq_length, dtype=torch.float)
        mask[: seq_len + 2] = 1.0

        # Random Ca coordinates (Angstroms). Fix: draw from the seeded
        # generator — the original used the global RNG here, which silently
        # broke the per-seed determinism promised by the constructor.
        coords = torch.randn(self.max_seq_length, 3, generator=self.rng) * 10.0

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "mask": mask,
            "coords": coords,
        }


class ParquetStructureDataset(Dataset):
    """Loads protein structures from a parquet file.

    Expected columns:
        sequence: str - amino acid sequence
        coords: list[list[float]] - Ca coordinates (N, 3)
        ca_mask: list[int] - optional, 1=valid Ca, 0=missing (defaults to all-1s)
    """

    def __init__(self, parquet_path: str, tokenizer, max_seq_length: int = 256):
        # Lazy import: pandas is only required for this dataset type.
        import pandas as pd

        self.df = pd.read_parquet(parquet_path)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.has_ca_mask = "ca_mask" in self.df.columns

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sequence = row["sequence"]

        # Tokenize with padding/truncation to a fixed length
        encoded = self.tokenizer(
            sequence,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Mask is same as attention_mask but as float
        mask = attention_mask.float()

        # Coordinates: pad to max_seq_length
        import numpy as np  # lazy, mirrors the pandas import in __init__

        # Parquet stores list-of-lists as numpy object array of arrays; np.stack handles this
        coords_raw = torch.from_numpy(np.stack(row["coords"]).astype(np.float32))
        coords = torch.zeros(self.max_seq_length, 3)
        seq_len = min(len(coords_raw), self.max_seq_length)
        coords[:seq_len] = coords_raw[:seq_len]

        # Zero out coords for residues with missing Ca atoms
        if self.has_ca_mask:
            ca_mask_list = row["ca_mask"]
            for i in range(min(len(ca_mask_list), self.max_seq_length)):
                if ca_mask_list[i] == 0:
                    coords[i] = 0.0

        # NOTE(review): coords[0] lines up with input_ids[0], i.e. the CLS
        # token — coordinates are not shifted by one to skip CLS. Confirm the
        # folding head / loss applies the same convention (MmcifStructureDataset
        # uses this convention as well, so the datasets are at least consistent).
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "mask": mask,
            "coords": coords,
        }


class MmcifStructureDataset(Dataset):
    """Loads protein structures directly from mmCIF files via BioPython.

    Parses each .cif file on-the-fly, extracts the amino acid sequence and Ca
    coordinates, tokenizes with ESM-2, and returns the standard batch format.
    """

    # 3-letter to 1-letter amino acid mapping (MSE -> M, selenomethionine)
    AA_3TO1: ClassVar[dict[str, str]] = {
        "ALA": "A",
        "CYS": "C",
        "ASP": "D",
        "GLU": "E",
        "PHE": "F",
        "GLY": "G",
        "HIS": "H",
        "ILE": "I",
        "LYS": "K",
        "LEU": "L",
        "MET": "M",
        "ASN": "N",
        "PRO": "P",
        "GLN": "Q",
        "ARG": "R",
        "SER": "S",
        "THR": "T",
        "VAL": "V",
        "TRP": "W",
        "TYR": "Y",
        "MSE": "M",
    }

    def __init__(
        self,
        cif_dir: str,
        tokenizer,
        max_seq_length: int = 256,
        pdb_ids: list[str] | None = None,
        min_residues: int = 50,
        max_residues: int = 300,
        min_ca_completeness: float = 0.9,
    ):
        # Lazy import: BioPython is only required for this dataset type.
        from Bio.PDB.MMCIFParser import MMCIFParser

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.min_residues = min_residues
        self.max_residues = max_residues
        self.min_ca_completeness = min_ca_completeness
        self.parser = MMCIFParser(QUIET=True)

        cif_path = Path(cif_dir)
        all_files = sorted(cif_path.glob("*.cif"))

        if pdb_ids is not None:
            # Preserve caller's ordering (e.g., to match parquet row order)
            file_by_id = {f.stem.upper(): f for f in all_files}
            self.files = [file_by_id[pid.upper()] for pid in pdb_ids if pid.upper() in file_by_id]
        else:
            self.files = all_files

        if not self.files:
            raise FileNotFoundError(f"No .cif files found in {cif_dir}")

        logger.info("MmcifStructureDataset: %d CIF files from %s", len(self.files), cif_dir)

    def __len__(self):
        return len(self.files)

    def _parse_cif(self, cif_path):
        """Parse mmCIF file and extract sequence + Ca coordinates.

        Uses the same filtering as prepare_pdb_dataset.py: min/max residues,
        Ca completeness threshold, and truncation to max_residues.

        Returns (sequence, ca_coords, ca_mask) for the first qualifying chain,
        or raises ValueError if no chain passes the filters.
        """
        pdb_id = cif_path.stem
        structure = self.parser.get_structure(pdb_id, str(cif_path))
        model = structure[0]

        for chain in model:
            residues = []
            for res in chain.get_residues():
                # Skip heteroatoms/waters and unknown residue types
                if res.id[0] != " ":
                    continue
                resname = res.get_resname().strip()
                if resname not in self.AA_3TO1:
                    continue
                residues.append(res)

            if len(residues) < self.min_residues:
                continue
            if len(residues) > self.max_residues:
                residues = residues[: self.max_residues]

            sequence = []
            coords = []
            ca_mask = []
            for res in residues:
                resname = res.get_resname().strip()
                sequence.append(self.AA_3TO1[resname])
                if "CA" in res:
                    ca = res["CA"].get_vector()
                    coords.append([float(ca[0]), float(ca[1]), float(ca[2])])
                    ca_mask.append(1)
                else:
                    # Placeholder for missing Ca; flagged 0 in ca_mask
                    coords.append([0.0, 0.0, 0.0])
                    ca_mask.append(0)

            completeness = sum(ca_mask) / len(ca_mask)
            if completeness < self.min_ca_completeness:
                continue

            return "".join(sequence), coords, ca_mask

        raise ValueError(f"No valid protein chain in {pdb_id}")

    def __getitem__(self, idx):
        try:
            sequence, ca_coords, ca_mask = self._parse_cif(self.files[idx])
        except Exception as e:
            # Keep training alive on a bad file by substituting a known-good
            # sample; if index 0 itself is unparseable this re-raises.
            logger.warning("Failed to parse %s: %s, falling back to index 0", self.files[idx].name, e)
            sequence, ca_coords, ca_mask = self._parse_cif(self.files[0])

        # Tokenize
        encoded = self.tokenizer(
            sequence,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        mask = attention_mask.float()

        # Coordinates: pad to max_seq_length
        coords_raw = torch.tensor(ca_coords, dtype=torch.float32)
        coords = torch.zeros(self.max_seq_length, 3)
        seq_len = min(len(coords_raw), self.max_seq_length)
        coords[:seq_len] = coords_raw[:seq_len]

        # Zero out missing Ca positions
        for i in range(seq_len):
            if ca_mask[i] == 0:
                coords[i] = 0.0

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "mask": mask,
            "coords": coords,
        }


def create_dataloader(
    dist_config: DistributedConfig,
    micro_batch_size: int = 2,
    max_seq_length: int = 128,
    num_workers: int = 0,
    dataset_type: str = "synthetic",
    parquet_path: str | None = None,
    tokenizer_name: str | None = None,
    num_samples: int = 1000,
    cif_dir: str | None = None,
    pdb_ids: list[str] | None = None,
    shuffle: bool = True,
    drop_last: bool = True,
    **kwargs,
):
    """Create a DataLoader for structure prediction training or evaluation.

    Args:
        dist_config: Distributed training configuration.
        micro_batch_size: Batch size per GPU.
        max_seq_length: Maximum sequence length.
        num_workers: Number of DataLoader workers.
        dataset_type: "synthetic", "parquet", or "mmcif".
        parquet_path: Path to parquet file (required if dataset_type="parquet").
        tokenizer_name: HuggingFace tokenizer name (required if dataset_type="parquet" or "mmcif").
        num_samples: Number of synthetic samples.
        cif_dir: Directory with .cif files (required if dataset_type="mmcif").
        pdb_ids: Optional list of PDB IDs to filter (for dataset_type="mmcif").
        shuffle: Whether to shuffle the data (False for eval).
        drop_last: Whether to drop the last incomplete batch (False for eval).
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Tuple of (DataLoader, DistributedSampler).
    """
    if dataset_type == "synthetic":
        dataset = SyntheticStructureDataset(
            num_samples=num_samples,
            max_seq_length=max_seq_length,
        )
    elif dataset_type == "parquet":
        from transformers import EsmTokenizer

        tokenizer = EsmTokenizer.from_pretrained(tokenizer_name)
        dataset = ParquetStructureDataset(
            parquet_path=parquet_path,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
        )
    elif dataset_type == "mmcif":
        from transformers import EsmTokenizer

        tokenizer = EsmTokenizer.from_pretrained(tokenizer_name)
        dataset = MmcifStructureDataset(
            cif_dir=cif_dir,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            pdb_ids=pdb_ids,
        )
    else:
        raise ValueError(f"Unknown dataset_type: {dataset_type}")

    # Shard across ranks; callers must call sampler.set_epoch() per epoch when
    # shuffling so each epoch draws a different permutation.
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist_config.world_size,
        rank=dist_config.rank,
        shuffle=shuffle,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
    )

    return dataloader, sampler
# diff --git a/bionemo-recipes/recipes/esm2_minifold_te/distributed_config.py b/bionemo-recipes/recipes/esm2_minifold_te/distributed_config.py
# new file mode 100644
# index 0000000000..271a5ffcfc
# --- /dev/null
# +++ b/bionemo-recipes/recipes/esm2_minifold_te/distributed_config.py
# @@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from dataclasses import dataclass, field


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class DistributedConfig:
    """Tracks distributed ranks and performs basic distributed-training setup.

    Each field default reads its torch-distributed environment variable via
    ``os.environ.setdefault``: under torchrun the existing value is returned
    untouched, while in a plain single-process run the sensible single-process
    default (rank 0, world size 1, localhost master) is *written into* the
    environment as a side effect, so later ``torch.distributed`` initialization
    finds a complete configuration.

    Attributes:
        rank: Global rank of this process.
        local_rank: Rank of this process on its node.
        world_size: Total number of processes.
    """

    rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0")))
    local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0")))
    world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1")))
    # Master address/port are kept only for the environment side effect above;
    # the leading underscore marks them as not part of the public rank API.
    _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost"))
    _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355"))

    def is_main_process(self) -> bool:
        """Return True on global rank 0 — use for wandb logging, checkpoint writes, etc."""
        return self.rank == 0
# diff --git a/bionemo-recipes/recipes/esm2_minifold_te/esm_backbone.py b/bionemo-recipes/recipes/esm2_minifold_te/esm_backbone.py
# new file mode 100644
# index 0000000000..b674551a24
# --- /dev/null
# +++ b/bionemo-recipes/recipes/esm2_minifold_te/esm_backbone.py
# @@ -0,0 +1,119 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ESM-2 backbone wrapper using HuggingFace transformers. + +Loads a pretrained ESM-2 model and extracts: +- Per-residue embeddings from the last hidden layer +- Pairwise attention maps from all transformer layers + +The backbone is intended to be frozen during folding head training. +""" + +import torch +import torch.nn as nn +from transformers import EsmModel, EsmTokenizer + + +def load_esm2_backbone(model_name: str = "facebook/esm2_t33_650M_UR50D", device: str = "cuda"): + """Load a pretrained ESM-2 model from HuggingFace. + + Args: + model_name: HuggingFace model name. Common options: + - "facebook/esm2_t6_8M_UR50D" (8M params, 6 layers, 20 heads) + - "facebook/esm2_t12_35M_UR50D" (35M params, 12 layers, 20 heads) + - "facebook/esm2_t30_150M_UR50D" (150M params, 30 layers, 20 heads) + - "facebook/esm2_t33_650M_UR50D" (650M params, 33 layers, 20 heads) + - "facebook/esm2_t36_3B_UR50D" (3B params, 36 layers, 40 heads) + - "facebook/esm2_t48_15B_UR50D" (15B params, 48 layers, 40 heads) + device: Device to load the model onto. + + Returns: + Tuple of (model, tokenizer). + """ + tokenizer = EsmTokenizer.from_pretrained(model_name) + model = EsmModel.from_pretrained(model_name, attn_implementation="eager").to(device) + return model, tokenizer + + +class ESM2Backbone(nn.Module): + """Frozen ESM-2 backbone that extracts embeddings and attention maps. + + This wraps a HuggingFace EsmModel and provides a forward pass that returns + both per-residue embeddings and pairwise attention maps in the format + expected by MiniFold's folding head. 
+ """ + + def __init__(self, model_name: str = "facebook/esm2_t33_650M_UR50D"): + super().__init__() + self.model = EsmModel.from_pretrained(model_name, attn_implementation="eager") + config = self.model.config + + self.embed_dim = config.hidden_size + self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads + self.attn_dim = self.num_layers * self.num_heads + + # Freeze all parameters + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None): + """Extract embeddings and attention maps from ESM-2. + + Args: + input_ids: Token IDs (B, L). Use ESM-2 tokenizer to encode sequences. + attention_mask: Optional attention mask (B, L). 1 = valid, 0 = padding. + + Returns: + Dict with: + "representations": Per-residue embeddings (B, L, embed_dim). + "attentions": Pairwise attention maps (B, L, L, num_layers * num_heads). + """ + with torch.no_grad(): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=True, + return_dict=True, + ) + + # Per-residue embeddings from last hidden state + representations = outputs.last_hidden_state # (B, L, embed_dim) + + # Stack attention maps from all layers + # Each layer returns (B, num_heads, L, L) + # Stack to (num_layers, B, num_heads, L, L) then rearrange to (B, L, L, num_layers * num_heads) + attn_stack = torch.stack(outputs.attentions, dim=0) # (num_layers, B, H, L, L) + B, H, L = attn_stack.shape[1], attn_stack.shape[2], attn_stack.shape[3] + # Rearrange: (num_layers, B, H, L, L) -> (B, L, L, num_layers, H) -> (B, L, L, num_layers * H) + attn_stack = attn_stack.permute(1, 3, 4, 0, 2) # (B, L, L, num_layers, H) + attentions = attn_stack.reshape(B, L, L, -1) # (B, L, L, num_layers * num_heads) + + return { + "representations": representations, + "attentions": attentions, + } + + +# ESM-2 model specs for reference +ESM2_MODELS = { + 
"facebook/esm2_t6_8M_UR50D": {"layers": 6, "embed_dim": 320, "heads": 20}, + "facebook/esm2_t12_35M_UR50D": {"layers": 12, "embed_dim": 480, "heads": 20}, + "facebook/esm2_t30_150M_UR50D": {"layers": 30, "embed_dim": 640, "heads": 20}, + "facebook/esm2_t33_650M_UR50D": {"layers": 33, "embed_dim": 1280, "heads": 20}, + "facebook/esm2_t36_3B_UR50D": {"layers": 36, "embed_dim": 2560, "heads": 40}, + "facebook/esm2_t48_15B_UR50D": {"layers": 48, "embed_dim": 5120, "heads": 40}, +} diff --git a/bionemo-recipes/recipes/esm2_minifold_te/eval_fsdp2.py b/bionemo-recipes/recipes/esm2_minifold_te/eval_fsdp2.py new file mode 100644 index 0000000000..326b325829 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/eval_fsdp2.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FSDP2 evaluation script for ESM2-MiniFold TE structure prediction. + +Loads a trained checkpoint and evaluates on a held-out dataset, reporting +structure quality metrics (lDDT, distogram accuracy, contact prediction) +to WandB and stdout. 
+ +Usage: + # With FSDP2 distributed checkpoint + torchrun --nproc_per_node=2 eval_fsdp2.py checkpoint.ckpt_dir=/path/to/checkpoints + + # With exported safetensors model + torchrun --nproc_per_node=2 eval_fsdp2.py \ + checkpoint.ckpt_dir=/path/to/final_model \ + checkpoint.checkpoint_type=safetensors +""" + +import logging +import os +from pathlib import Path + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from tqdm import tqdm + +import wandb +from checkpoint import load_checkpoint_fsdp2 +from dataset import create_dataloader +from distributed_config import DistributedConfig +from modeling_esm2_minifold_te import ESM2MiniFoldTE +from quantization import ComponentPrecisionConfig, resolve_layer_precision +from scheduler import get_linear_schedule_with_warmup +from train_fsdp2 import compute_distogram_loss, compute_distogram_metrics + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="eval", version_base="1.2") +def main(args: DictConfig) -> None: + """Evaluate ESM2-MiniFold TE on a held-out dataset.""" + os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "1" + logging.getLogger("httpx").setLevel(logging.WARNING) + + # Initialize distributed + dist_config = DistributedConfig() + logger.info("Initializing eval: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist_config.world_size,), + mesh_dim_names=("dp",), + ) + + # Resolve per-block quantization precision + block_precision = resolve_layer_precision( + num_layers=args.model.num_blocks, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + 
fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + + fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + from transformer_engine.common.recipe import Format + + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + from transformer_engine.common.recipe import Format + + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + # Component-level precision overrides + component_precision = ComponentPrecisionConfig(**OmegaConf.to_container(args.component_precision, resolve=True)) + + # Create model (same architecture as training) + model = ESM2MiniFoldTE( + esm_model_name=args.esm_model_name, + c_s=args.model.c_s, + c_z=args.model.c_z, + num_blocks=args.model.num_blocks, + no_bins=args.model.no_bins, + use_structure_module=args.model.use_structure_module, + block_precision=block_precision, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + component_precision=component_precision, + ).to(device) + + # Load checkpoint + ckpt_dir = Path(args.checkpoint.ckpt_dir) + checkpoint_type = args.checkpoint.get("checkpoint_type", "fsdp2") + + if checkpoint_type == "safetensors": + # Load safetensors BEFORE FSDP2 sharding (plain tensors -> plain params) + from safetensors.torch import load_file + + state_dict = load_file(str(ckpt_dir / "model.safetensors")) + model.load_state_dict(state_dict, strict=False) + logger.info("Loaded safetensors model from %s", ckpt_dir) + + # FSDP2 sharding (must match training for FSDP2 checkpoint loading; + # also needed for multi-GPU eval even with safetensors) + mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + for block in 
model.fold.miniformer.blocks: + fully_shard(block, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + + if checkpoint_type == "fsdp2": + # Need dummy optimizer/scheduler for the checkpoint loader + dummy_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + dummy_scheduler = get_linear_schedule_with_warmup(dummy_optimizer, num_warmup_steps=0, num_training_steps=1) + ckpt_path = ckpt_dir / "train_fsdp2" + model, _, _, _, loaded_step, _ = load_checkpoint_fsdp2( + model=model, + optimizer=dummy_optimizer, + scheduler=dummy_scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, + ) + logger.info("Loaded FSDP2 checkpoint from step %d", loaded_step) + elif checkpoint_type != "safetensors": + raise ValueError(f"Unknown checkpoint_type: {checkpoint_type}") + + if dist_config.is_main_process(): + logger.info("Block precision: %s", block_precision) + + # Create eval dataloader (shuffle=False, drop_last=False from config) + eval_dataloader, _ = create_dataloader(dist_config, **args.eval_dataset) + logger.info("Eval dataset: %d batches", len(eval_dataloader)) + + # Initialize WandB + run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) + if dist_config.is_main_process(): + wandb.init(**args.wandb_init_args, config=run_config) + + # Eval loop + model.eval() + all_metrics = { + "loss": [], + "disto_loss": [], + "distogram_acc": [], + "contact_precision_8A": [], + "contact_recall_8A": [], + "lddt_from_distogram": [], + "mean_distance_error": [], + } + + progress = tqdm(eval_dataloader, desc="Evaluating", disable=not dist_config.is_main_process()) + + with torch.no_grad(): + for batch in progress: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0)) + + # Distogram loss + disto_loss = compute_distogram_loss( + preds=r_dict["preds"], + coords=batch["coords"], + 
mask=batch["mask"], + no_bins=args.model.no_bins, + ) + + # Structure quality metrics + metrics = compute_distogram_metrics( + preds=r_dict["preds"].float(), + coords=batch["coords"], + mask=batch["mask"], + no_bins=args.model.no_bins, + ) + + all_metrics["loss"].append(disto_loss.item()) + all_metrics["disto_loss"].append(disto_loss.item()) + for key, value in metrics.items(): + all_metrics[key].append(value.item()) + + progress.set_postfix( + { + "loss": f"{disto_loss.item():.3f}", + "lddt": f"{metrics['lddt_from_distogram'].item():.3f}", + } + ) + + # Aggregate metrics + summary = {} + for key, values in all_metrics.items(): + if values: + summary[f"eval/{key}"] = sum(values) / len(values) + + # Log to WandB and stdout + if dist_config.is_main_process(): + wandb.log(summary) + wandb.finish() + + if dist_config.local_rank == 0: + logger.info("=== Evaluation Results ===") + logger.info("Batches evaluated: %d", len(all_metrics["loss"])) + for key, value in summary.items(): + logger.info(" %s: %.4f", key, value) + + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_minifold_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_minifold_te/fp4_debugging_stats.yaml new file mode 100644 index 0000000000..67425affb1 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/fp4_debugging_stats.yaml @@ -0,0 +1,31 @@ +example_fp4_tensor_stat_collection: + enabled: True + layers: + # Match MiniFold te.Linear sublayers in FP4 blocks + layer_types: [pi, gi, po, go, fc1, fc2] + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 100 + - tensor: gradient + stats: [underflows%, mse] + freq: 100 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Match MiniFold te.Linear sublayers in FP8 blocks + layer_types: [pi, gi, po, go, fc1, fc2] + transformer_engine: + LogFp8TensorStats: + enabled: True + 
tensors_struct: + - tensor: activation + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 + - tensor: gradient + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 diff --git a/bionemo-recipes/recipes/esm2_minifold_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_minifold_te/fp8_debugging_stats.yaml new file mode 100644 index 0000000000..a3339c0804 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/fp8_debugging_stats.yaml @@ -0,0 +1,23 @@ +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Match the te.Linear sublayers within MiniFormer blocks + layer_types: [pi, gi, po, go, fc1, fc2] + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 + - tensor: gradient + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 + - tensor: weight + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad] + freq: 1 diff --git a/bionemo-recipes/recipes/esm2_minifold_te/heads_te.py b/bionemo-recipes/recipes/esm2_minifold_te/heads_te.py new file mode 100644 index 0000000000..27deb8233e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/heads_te.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import transformer_engine.pytorch as te + +from loss import compute_plddt +from minifold_utils import init +from te_utils import te_layernorm_nd, te_linear_nd + + +class PerResidueLDDTCaPredictorTE(nn.Module): + """TE version of PerResidueLDDTCaPredictor.""" + + def __init__(self, no_bins, c_in, c_hidden, params_dtype=torch.float32): + super().__init__() + + self.no_bins = no_bins + self.c_in = c_in + self.c_hidden = c_hidden + + self.layer_norm = te.LayerNorm(self.c_in, eps=1e-5, params_dtype=params_dtype) + self.linear_1 = te.Linear(self.c_in, self.c_hidden, params_dtype=params_dtype) + self.linear_2 = te.Linear(self.c_hidden, self.c_hidden, params_dtype=params_dtype) + self.linear_3 = te.Linear(self.c_hidden, self.no_bins, params_dtype=params_dtype) + + init.he_normal_init_(self.linear_1.weight) + init.he_normal_init_(self.linear_2.weight) + init.final_init_(self.linear_3.weight) + + init.bias_init_zero_(self.linear_1.bias) + init.bias_init_zero_(self.linear_2.bias) + init.bias_init_zero_(self.linear_3.bias) + + self.relu = nn.ReLU() + + def forward(self, s): + s = te_layernorm_nd(self.layer_norm, s) + s = te_linear_nd(self.linear_1, s) + s = self.relu(s) + s = te_linear_nd(self.linear_2, s) + s = self.relu(s) + s = te_linear_nd(self.linear_3, s) + return s + + +class AuxiliaryHeadsTE(nn.Module): + """TE version of AuxiliaryHeads.""" + + def __init__(self, config): + super().__init__() + self.plddt = PerResidueLDDTCaPredictorTE( + **config["lddt"], + ) + self.config = config + + def forward(self, outputs): + 
aux_out = {} + lddt_logits = self.plddt(outputs["sm"]["single"]) + aux_out["lddt_logits"] = lddt_logits + aux_out["plddt"] = compute_plddt(lddt_logits) + return aux_out diff --git a/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/L0_sanity.yaml b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/L0_sanity.yaml new file mode 100644 index 0000000000..0cdbfe5bb7 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/L0_sanity.yaml @@ -0,0 +1,92 @@ +# ESM2-MiniFold TE: L0 sanity check config +# Quick test with small model and synthetic data +# Usage: python train_fsdp2.py --config-name L0_sanity + +# Small ESM-2 backbone (8M params) +esm_model_name: facebook/esm2_t6_8M_UR50D + +# Short training +num_train_steps: 50 + +# Small model +model: + c_s: 512 + c_z: 64 + num_blocks: 2 + no_bins: 64 + use_structure_module: false + num_recycling: 0 + +# Synthetic dataset +dataset: + dataset_type: synthetic + micro_batch_size: 2 + max_seq_length: 64 + num_workers: 0 + num_samples: 200 + +# Optimizer +optimizer: + folding_lr: 1.0e-4 + struct_lr: 1.0e-4 + backbone_lr: 3.0e-5 + betas: [0.9, 0.98] + eps: 1.0e-8 + weight_decay: 0.01 + +# LR scheduler +lr_scheduler_kwargs: + num_warmup_steps: 10 + num_training_steps: ${num_train_steps} + +# Checkpointing +checkpoint: + ckpt_dir: ./checkpoints + save_final_model: false + resume_from_checkpoint: false + save_every_n_steps: 0 + max_checkpoints: 1 + +use_fp32_master_weights: false + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +fp8_layers: null +fp4_layers: null + +component_precision: + tri_proj: false # TriangularUpdateTE input/output projections (pi, po te.Linear layers) + tri_gate: false # TriangularUpdateTE sigmoid gates (gi, go te.Linear layers) + tri_einsum: "off" # "off" 
= FP32 (default), "bf16" = ambient dtype (recommended) + ffn: true # TransitionUpdateTE FFN (fc1, fc2 te.Linear layers) + struct_attn: false # StructureModuleTE attention projections (proj, o_proj, g_proj te.Linear layers) + struct_ffn: false # StructureModuleTE transition MLP (fc1, fc2 te.Linear layers) + seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers) + dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper + +quant_stats_config: + enabled: false + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + log_to_wandb: false + log_heatmap: false + +# Log every step for sanity check +logger: + frequency: 1 + +# Wandb offline for CI +wandb_init_args: + project: esm2_minifold_te + name: L0_sanity + mode: offline diff --git a/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/defaults.yaml new file mode 100644 index 0000000000..123404aac7 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/defaults.yaml @@ -0,0 +1,101 @@ +# ESM2-MiniFold TE: Structure prediction training config +# Usage: torchrun --nproc_per_node=8 train_fsdp2.py --config-name defaults + +# ESM-2 backbone +esm_model_name: facebook/esm2_t33_650M_UR50D + +# Training +num_train_steps: 500_000 + +# Model architecture +model: + c_s: 1024 # Sequence feature dimension + c_z: 128 # Pair feature dimension + num_blocks: 48 # MiniFormer blocks + no_bins: 64 # Distogram bins + use_structure_module: false # Stage 1: distogram only + num_recycling: 0 # No recycling (Stage 1) + +# Dataset +dataset: + dataset_type: parquet + parquet_path: ??? 
+ tokenizer_name: ${esm_model_name} + micro_batch_size: 4 + max_seq_length: 256 + num_workers: 4 + num_samples: 10000 # For synthetic dataset + +# Optimizer +optimizer: + folding_lr: 1.0e-4 + struct_lr: 1.0e-4 + backbone_lr: 3.0e-5 # Only used if backbone unfrozen + betas: [0.9, 0.98] + eps: 1.0e-8 + weight_decay: 0.01 + +# LR scheduler +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_training_steps: ${num_train_steps} + +# Checkpointing +checkpoint: + ckpt_dir: ./checkpoints + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 1_000 + max_checkpoints: 5 + +# FP32 master weights (optimizer keeps FP32 copies, forward/backward in BF16) +use_fp32_master_weights: false + +# Per-block quantization precision +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +# 1-indexed lists of which MiniFormer blocks use FP8/FP4. +# null means "all blocks" when the corresponding config is enabled. +fp8_layers: null +fp4_layers: null + +# Per-component precision overrides within FP8/FP4 blocks. +# Components set to true run in the block's precision; false keeps them in BF16. +# Only meaningful when block-level FP8/FP4 is enabled. 
+component_precision: + tri_proj: false # TriangularUpdateTE input/output projections (pi, po te.Linear layers) + tri_gate: false # TriangularUpdateTE sigmoid gates (gi, go te.Linear layers) + tri_einsum: "off" # "off" = FP32 (default), "bf16" = ambient dtype (recommended) + ffn: true # TransitionUpdateTE FFN (fc1, fc2 te.Linear layers) + struct_attn: false # StructureModuleTE attention projections (proj, o_proj, g_proj te.Linear layers) + struct_ffn: false # StructureModuleTE transition MLP (fc1, fc2 te.Linear layers) + seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers) + dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper + +# Quantization stats logging (requires nvdlfw_inspect) +quant_stats_config: + enabled: false + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + log_to_wandb: false + log_heatmap: false + +# Logging +logger: + frequency: 100 + +# Wandb +wandb_init_args: + project: esm2_minifold_te + name: null + mode: online diff --git a/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/eval.yaml b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/eval.yaml new file mode 100644 index 0000000000..e8fd148c69 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/eval.yaml @@ -0,0 +1,63 @@ +# ESM2-MiniFold TE: Post-training evaluation on held-out structures +# Usage: torchrun --nproc_per_node=2 eval_fsdp2.py checkpoint.ckpt_dir=/path/to/checkpoint + +esm_model_name: facebook/esm2_t33_650M_UR50D + +model: + c_s: 1024 + c_z: 128 + num_blocks: 8 + no_bins: 64 + use_structure_module: false + num_recycling: 0 + +eval_dataset: + dataset_type: parquet + parquet_path: data/eval_structures.parquet + tokenizer_name: ${esm_model_name} + micro_batch_size: 4 + max_seq_length: 256 + num_workers: 2 + shuffle: false + drop_last: false + +checkpoint: + ckpt_dir: ??? 
# required: path to trained checkpoint or final model + checkpoint_type: fsdp2 # "fsdp2" for distributed checkpoints, "safetensors" for exported model + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +fp8_layers: null +fp4_layers: null + +component_precision: + tri_proj: false # TriangularUpdateTE input/output projections (pi, po te.Linear layers) + tri_gate: false # TriangularUpdateTE sigmoid gates (gi, go te.Linear layers) + tri_einsum: "off" # "off" = FP32 (default), "bf16" = ambient dtype (recommended) + ffn: true # TransitionUpdateTE FFN (fc1, fc2 te.Linear layers) + struct_attn: false # StructureModuleTE attention projections (proj, o_proj, g_proj te.Linear layers) + struct_ffn: false # StructureModuleTE transition MLP (fc1, fc2 te.Linear layers) + seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers) + dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper + +quant_stats_config: + enabled: false + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + log_to_wandb: false + log_heatmap: false + +wandb_init_args: + project: esm2_minifold_te + name: eval_${now:%Y%m%d_%H%M%S} + mode: online diff --git a/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100.yaml b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100.yaml new file mode 100644 index 0000000000..4cdbb2b178 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100.yaml @@ -0,0 +1,81 @@ +# ESM2-MiniFold TE: 100-step exploratory run +# 2x RTX 5090, frozen ESM-2 650M, synthetic data +# Usage: torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100 + 
+esm_model_name: facebook/esm2_t33_650M_UR50D + +num_train_steps: 100 + +model: + c_s: 1024 + c_z: 128 + num_blocks: 8 + no_bins: 64 + use_structure_module: false + num_recycling: 0 + +dataset: + dataset_type: synthetic + micro_batch_size: 2 + max_seq_length: 128 + num_workers: 2 + num_samples: 1000 + +optimizer: + folding_lr: 1.0e-4 + struct_lr: 1.0e-4 + backbone_lr: 3.0e-5 + betas: [0.9, 0.98] + eps: 1.0e-8 + weight_decay: 0.01 + +lr_scheduler_kwargs: + num_warmup_steps: 10 + num_training_steps: ${num_train_steps} + +checkpoint: + ckpt_dir: null + save_final_model: false + resume_from_checkpoint: false + save_every_n_steps: 0 + max_checkpoints: 1 + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +fp8_layers: null +fp4_layers: null + +component_precision: + tri_proj: false # TriangularUpdateTE input/output projections (pi, po te.Linear layers) + tri_gate: false # TriangularUpdateTE sigmoid gates (gi, go te.Linear layers) + tri_einsum: "off" # "off" = FP32 (default), "bf16" = ambient dtype (recommended) + ffn: true # TransitionUpdateTE FFN (fc1, fc2 te.Linear layers) + struct_attn: false # StructureModuleTE attention projections (proj, o_proj, g_proj te.Linear layers) + struct_ffn: false # StructureModuleTE transition MLP (fc1, fc2 te.Linear layers) + seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers) + dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper + +quant_stats_config: + enabled: false + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + log_to_wandb: false + log_heatmap: false + +logger: + frequency: 5 + +wandb_init_args: + project: esm2_minifold_te + name: 
run_100_650M_synthetic + mode: offline diff --git a/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real.yaml b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real.yaml new file mode 100644 index 0000000000..137109515b --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real.yaml @@ -0,0 +1,88 @@ +# ESM2-MiniFold TE: 100-step run with REAL PDB data +# 2x RTX 5090, frozen ESM-2 650M +# Usage: +# Parquet (default, faster): torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100_real +# MmCIF (on-the-fly parsing): torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100_real dataset.dataset_type=mmcif + +esm_model_name: facebook/esm2_t33_650M_UR50D + +num_train_steps: 100 + +model: + c_s: 1024 + c_z: 128 + num_blocks: 8 + no_bins: 64 + use_structure_module: false + num_recycling: 0 + +dataset: + dataset_type: parquet # "parquet" (fast, pre-processed) or "mmcif" (on-the-fly BioPython parsing) + parquet_path: data/pdb_structures.parquet + cif_dir: data/cif_files + tokenizer_name: ${esm_model_name} + micro_batch_size: 2 + max_seq_length: 256 + num_workers: 2 + num_samples: 1000 + +optimizer: + folding_lr: 1.0e-4 + struct_lr: 1.0e-4 + backbone_lr: 3.0e-5 + betas: [0.9, 0.98] + eps: 1.0e-8 + weight_decay: 0.01 + +lr_scheduler_kwargs: + num_warmup_steps: 10 + num_training_steps: ${num_train_steps} + +checkpoint: + ckpt_dir: null + save_final_model: false + resume_from_checkpoint: false + save_every_n_steps: 0 + max_checkpoints: 1 + +use_fp32_master_weights: false + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +fp8_layers: null +fp4_layers: null + +component_precision: + tri_proj: false # TriangularUpdateTE input/output projections (pi, po te.Linear layers) + 
tri_gate: false # TriangularUpdateTE sigmoid gates (gi, go te.Linear layers) + tri_einsum: "off" # "off" = FP32 (default), "bf16" = ambient dtype (recommended) + ffn: true # TransitionUpdateTE FFN (fc1, fc2 te.Linear layers) + struct_attn: false # StructureModuleTE attention projections (proj, o_proj, g_proj te.Linear layers) + struct_ffn: false # StructureModuleTE transition MLP (fc1, fc2 te.Linear layers) + seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers) + dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper + +quant_stats_config: + enabled: false + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + log_to_wandb: false + log_heatmap: false + +logger: + frequency: 5 + +wandb_init_args: + project: esm2_minifold_te + name: run_100_650M_real_pdb + mode: online diff --git a/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real_3B.yaml b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real_3B.yaml new file mode 100644 index 0000000000..d5c57c5a73 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real_3B.yaml @@ -0,0 +1,87 @@ +# ESM2-MiniFold TE: 100-step run with REAL PDB data +# Matches original MiniFold: ESM-2 3B backbone, 48 MiniFormer blocks +# Usage: +# torchrun --nproc_per_node=8 train_fsdp2.py --config-name run_100_real_3B + +esm_model_name: facebook/esm2_t36_3B_UR50D + +num_train_steps: 100 + +model: + c_s: 1024 + c_z: 128 + num_blocks: 48 + no_bins: 64 + use_structure_module: false + num_recycling: 0 + +dataset: + dataset_type: parquet + parquet_path: data/pdb_structures.parquet + cif_dir: data/cif_files + tokenizer_name: ${esm_model_name} + micro_batch_size: 2 + max_seq_length: 256 + num_workers: 2 + num_samples: 1000 + +optimizer: + folding_lr: 1.0e-4 + struct_lr: 1.0e-4 + backbone_lr: 3.0e-5 + betas: [0.9, 0.98] + eps: 1.0e-8 + 
weight_decay: 0.0 + +lr_scheduler_kwargs: + num_warmup_steps: 10 + num_training_steps: ${num_train_steps} + +checkpoint: + ckpt_dir: null + save_final_model: false + resume_from_checkpoint: false + save_every_n_steps: 0 + max_checkpoints: 1 + +use_fp32_master_weights: false + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +fp8_layers: null +fp4_layers: null + +component_precision: + tri_proj: false # TriangularUpdateTE input/output projections (pi, po te.Linear layers) + tri_gate: false # TriangularUpdateTE sigmoid gates (gi, go te.Linear layers) + tri_einsum: "off" # "off" = FP32 (default), "bf16" = ambient dtype (recommended) + ffn: true # TransitionUpdateTE FFN (fc1, fc2 te.Linear layers) + struct_attn: false # StructureModuleTE attention projections (proj, o_proj, g_proj te.Linear layers) + struct_ffn: false # StructureModuleTE transition MLP (fc1, fc2 te.Linear layers) + seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers) + dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper + +quant_stats_config: + enabled: false + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + log_to_wandb: false + log_heatmap: false + +logger: + frequency: 5 + +wandb_init_args: + project: esm2_minifold_te + name: run_100_3B_real_pdb + mode: online diff --git a/bionemo-recipes/recipes/esm2_minifold_te/loss.py b/bionemo-recipes/recipes/esm2_minifold_te/loss.py new file mode 100644 index 0000000000..898ed183f6 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/loss.py @@ -0,0 +1,1424 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Dict, Optional, Tuple + +import ml_collections +import numpy as np +import torch +import torch.nn as nn + +from minifold_utils import residue_constants +from minifold_utils.rigid_utils import Rigid +from minifold_utils.tensor_utils import ( + masked_mean, + permute_final_dims, + tensor_tree_map, + tree_map, +) + + +def softmax_cross_entropy(logits, labels): + loss = -1 * torch.sum( + labels * torch.nn.functional.log_softmax(logits, dim=-1), + dim=-1, + ) + return loss + + +def sigmoid_cross_entropy(logits, labels): + logits_dtype = logits.dtype + logits = logits.double() + labels = labels.double() + log_p = torch.nn.functional.logsigmoid(logits) + # log_p = torch.log(torch.sigmoid(logits)) + log_not_p = torch.nn.functional.logsigmoid(-1 * logits) + # log_not_p = torch.log(torch.sigmoid(-logits)) + loss = (-1.0 * labels) * log_p - (1.0 - labels) * log_not_p + loss = loss.to(dtype=logits_dtype) + return loss + + +def torsion_angle_loss( + a, # [*, N, 7, 2] + a_gt, # [*, N, 7, 2] + a_alt_gt, # [*, N, 7, 2] +): + # [*, N, 7] + norm = torch.norm(a, dim=-1) + + # [*, N, 7, 2] + a = a / norm.unsqueeze(-1) + + # [*, N, 7] + diff_norm_gt = torch.norm(a - a_gt, dim=-1) + diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1) + min_diff = torch.minimum(diff_norm_gt**2, diff_norm_alt_gt**2) + + # [*] + l_torsion = torch.mean(min_diff, dim=(-1, -2)) + l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2)) + + an_weight = 0.02 + return l_torsion + an_weight * l_angle_norm + + +def compute_fape( + pred_frames: Rigid, + target_frames: Rigid, + frames_mask: torch.Tensor, + pred_positions: torch.Tensor, + target_positions: torch.Tensor, + positions_mask: torch.Tensor, + length_scale: float, + l1_clamp_distance: Optional[float] = None, + eps=1e-8, +) -> torch.Tensor: + """Computes FAPE loss. 
+ + Args: + pred_frames: + [*, N_frames] Rigid object of predicted frames + target_frames: + [*, N_frames] Rigid object of ground truth frames + frames_mask: + [*, N_frames] binary mask for the frames + pred_positions: + [*, N_pts, 3] predicted atom positions + target_positions: + [*, N_pts, 3] ground truth positions + positions_mask: + [*, N_pts] positions mask + length_scale: + Length scale by which the loss is divided + l1_clamp_distance: + Cutoff above which distance errors are disregarded + eps: + Small value used to regularize denominators + Returns: + [*] loss tensor + """ + # [*, N_frames, N_pts, 3] + local_pred_pos = pred_frames.invert()[..., None].apply( + pred_positions[..., None, :, :], + ) + local_target_pos = target_frames.invert()[..., None].apply( + target_positions[..., None, :, :], + ) + + error_dist = torch.sqrt(torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps) + + if l1_clamp_distance is not None: + error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error = normed_error * frames_mask[..., None] + normed_error = normed_error * positions_mask[..., None, :] + + # FP16-friendly averaging. 
Roughly equivalent to: + # + # norm_factor = ( + # torch.sum(frames_mask, dim=-1) * + # torch.sum(positions_mask, dim=-1) + # ) + # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor) + # + # ("roughly" because eps is necessarily duplicated in the latter) + normed_error = torch.sum(normed_error, dim=-1) + normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None] + normed_error = torch.sum(normed_error, dim=-1) + normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1)) + + return normed_error + + +def backbone_loss( + backbone_rigid_tensor: torch.Tensor, + backbone_rigid_mask: torch.Tensor, + traj: torch.Tensor, + use_clamped_fape: Optional[torch.Tensor] = None, + clamp_distance: float = 10.0, + loss_unit_distance: float = 10.0, + eps: float = 1e-4, + **kwargs, +) -> torch.Tensor: + pred_aff = Rigid.from_tensor_4x4(traj) + # pred_aff = Rigid.from_tensor_7(traj) + # pred_aff = Rigid( + # Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None), + # pred_aff.get_trans(), + # ) + + # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of + # backbone tensor, normalizes it, and then turns it back to a rotation + # matrix. To avoid a potentially numerically unstable rotation matrix + # to quaternion conversion, we just use the original rotation matrix + # outright. This one hasn't been composed a bunch of times, though, so + # it might be fine. 
+ gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor) + + fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + l1_clamp_distance=clamp_distance, + length_scale=loss_unit_distance, + eps=eps, + ) + if use_clamped_fape is not None: + unclamped_fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + l1_clamp_distance=None, + length_scale=loss_unit_distance, + eps=eps, + ) + + fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (1 - use_clamped_fape) + + # Average over the batch dimension + fape_loss = torch.mean(fape_loss) + + return fape_loss + + +def sidechain_loss( + sidechain_frames: torch.Tensor, + sidechain_atom_pos: torch.Tensor, + rigidgroups_gt_frames: torch.Tensor, + rigidgroups_alt_gt_frames: torch.Tensor, + rigidgroups_gt_exists: torch.Tensor, + renamed_atom14_gt_positions: torch.Tensor, + renamed_atom14_gt_exists: torch.Tensor, + alt_naming_is_better: torch.Tensor, + clamp_distance: float = 10.0, + length_scale: float = 10.0, + eps: float = 1e-4, + **kwargs, +) -> torch.Tensor: + renamed_gt_frames = ( + 1.0 - alt_naming_is_better[..., None, None, None] + ) * rigidgroups_gt_frames + alt_naming_is_better[..., None, None, None] * rigidgroups_alt_gt_frames + + # Steamroll the inputs + sidechain_frames = sidechain_frames[-1] + batch_dims = sidechain_frames.shape[:-4] + sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4) + sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames) + renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) + renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames) + rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) + sidechain_atom_pos = sidechain_atom_pos[-1] + sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3) + 
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3) + renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1) + + fape = compute_fape( + sidechain_frames, + renamed_gt_frames, + rigidgroups_gt_exists, + sidechain_atom_pos, + renamed_atom14_gt_positions, + renamed_atom14_gt_exists, + l1_clamp_distance=clamp_distance, + length_scale=length_scale, + eps=eps, + ) + + return fape + + +def fape_loss( + out: Dict[str, torch.Tensor], + batch: Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, +) -> torch.Tensor: + bb_loss = backbone_loss( + traj=out["sm"]["frames"], + **{**batch, **config.backbone}, + ) + + sc_loss = sidechain_loss( + out["sm"]["sidechain_frames"], + out["sm"]["positions"], + **{**batch, **config.sidechain}, + ) + + loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def supervised_chi_loss( + angles_sin_cos: torch.Tensor, + unnormalized_angles_sin_cos: torch.Tensor, + aatype: torch.Tensor, + seq_mask: torch.Tensor, + chi_mask: torch.Tensor, + chi_angles_sin_cos: torch.Tensor, + chi_weight: float, + angle_norm_weight: float, + eps=1e-6, + **kwargs, +) -> torch.Tensor: + """Implements Algorithm 27 (torsionAngleLoss) + + Args: + angles_sin_cos: + [*, N, 7, 2] predicted angles + unnormalized_angles_sin_cos: + The same angles, but unnormalized + aatype: + [*, N] residue indices + seq_mask: + [*, N] sequence mask + chi_mask: + [*, N, 7] angle mask + chi_angles_sin_cos: + [*, N, 7, 2] ground truth angles + chi_weight: + Weight for the angle component of the loss + angle_norm_weight: + Weight for the normalization component of the loss + Returns: + [*] loss tensor + """ + pred_angles = angles_sin_cos[..., 3:, :] + residue_type_one_hot = torch.nn.functional.one_hot( + aatype, + residue_constants.restype_num + 1, + ) + chi_pi_periodic = torch.einsum( + "...ij,jk->ik", + 
residue_type_one_hot.type(angles_sin_cos.dtype), + angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), + ) + + true_chi = chi_angles_sin_cos[None] + + shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1) + true_chi_shifted = shifted_mask * true_chi + sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1) + sq_chi_error_shifted = torch.sum((true_chi_shifted - pred_angles) ** 2, dim=-1) + sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted) + + # The ol' switcheroo + sq_chi_error = sq_chi_error.permute(*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1) + + sq_chi_loss = masked_mean(chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)) + + loss = chi_weight * sq_chi_loss + + angle_norm = torch.sqrt(torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps) + norm_error = torch.abs(angle_norm - 1.0) + norm_error = norm_error.permute(*range(len(norm_error.shape))[1:-2], 0, -2, -1) + angle_norm_loss = masked_mean(seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)) + + loss = loss + angle_norm_weight * angle_norm_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def compute_plddt(logits: torch.Tensor) -> torch.Tensor: + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bounds = torch.arange(start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device) + probs = torch.nn.functional.softmax(logits, dim=-1) + pred_lddt_ca = torch.sum( + probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return pred_lddt_ca * 100 + + +def lddt( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + n = all_atom_mask.shape[-2] + dmat_true = torch.sqrt( + eps + + torch.sum( + (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :]) ** 2, + dim=-1, + ) + ) + + dmat_pred = torch.sqrt( + eps + + torch.sum( + 
(all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2, + dim=-1, + ) + ) + dists_to_score = ( + (dmat_true < cutoff) + * all_atom_mask + * permute_final_dims(all_atom_mask, (1, 0)) + * (1.0 - torch.eye(n, device=all_atom_mask.device)) + ) + + dist_l1 = torch.abs(dmat_true - dmat_pred) + + score = ( + (dist_l1 < 0.5).type(dist_l1.dtype) + + (dist_l1 < 1.0).type(dist_l1.dtype) + + (dist_l1 < 2.0).type(dist_l1.dtype) + + (dist_l1 < 4.0).type(dist_l1.dtype) + ) + score = score * 0.25 + + dims = (-1,) if per_residue else (-2, -1) + norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) + score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) + + return score + + +def lddt_ca( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + + return lddt( + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + cutoff=cutoff, + eps=eps, + per_residue=per_residue, + ) + + +def lddt_loss( + logits: torch.Tensor, + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + resolution: torch.Tensor, + cutoff: float = 15.0, + no_bins: int = 50, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps: float = 1e-10, + **kwargs, +) -> torch.Tensor: + n = all_atom_mask.shape[-2] + + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim + + score = lddt(all_atom_pred_pos, all_atom_positions, all_atom_mask, cutoff=cutoff, eps=eps) + + score = score.detach() 
+ + bin_index = torch.floor(score * no_bins).long() + bin_index = torch.clamp(bin_index, max=(no_bins - 1)) + lddt_ca_one_hot = torch.nn.functional.one_hot(bin_index, num_classes=no_bins) + + errors = softmax_cross_entropy(logits, lddt_ca_one_hot) + all_atom_mask = all_atom_mask.squeeze(-1) + loss = torch.sum(errors * all_atom_mask, dim=-1) / (eps + torch.sum(all_atom_mask, dim=-1)) + + """ + loss = loss * ( + (resolution >= min_resolution) & (resolution <= max_resolution) + ) + """ + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def distogram_loss( + logits, + pseudo_beta, + pseudo_beta_mask, + min_bin=2.3125, + max_bin=21.6875, + no_bins=64, + eps=1e-6, + **kwargs, +): + boundaries = torch.linspace( + min_bin, + max_bin, + no_bins - 1, + device=logits.device, + ) + boundaries = boundaries**2 + + dists = torch.sum( + (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + true_bins = torch.sum(dists > boundaries, dim=-1) + + errors = softmax_cross_entropy( + logits, + torch.nn.functional.one_hot(true_bins, no_bins), + ) + + square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :] + + # FP16-friendly sum. 
Equivalent to: + # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / + # (eps + torch.sum(square_mask, dim=(-1, -2)))) + denom = eps + torch.sum(square_mask, dim=(-1, -2)) + mean = errors * square_mask + mean = torch.sum(mean, dim=-1) + mean = mean / denom[..., None] + mean = torch.sum(mean, dim=-1) + + # Average over the batch dimensions + mean = torch.mean(mean) + + return mean + + +def _calculate_bin_centers(boundaries: torch.Tensor): + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. 
+ """ + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) + ( + predicted_aligned_error, + max_predicted_aligned_error, + ) = _calculate_expected_aligned_error( + alignment_confidence_breaks=boundaries, + aligned_distance_error_probs=aligned_confidence_probs, + ) + + return { + "aligned_confidence_probs": aligned_confidence_probs, + "predicted_aligned_error": predicted_aligned_error, + "max_predicted_aligned_error": max_predicted_aligned_error, + } + + +def compute_tm( + logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + no_bins: int = 64, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + if residue_weights is None: + residue_weights = logits.new_ones(logits.shape[-2]) + + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + bin_centers = _calculate_bin_centers(boundaries) + clipped_n = max(torch.sum(residue_weights), 19) + + d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 + + probs = torch.nn.functional.softmax(logits, dim=-1) + + tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2)) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + normed_residue_mask = residue_weights / (eps + residue_weights.sum()) + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + + weighted = per_alignment * residue_weights + + argmax = (weighted == torch.max(weighted)).nonzero()[0] + return per_alignment[tuple(argmax)] + + +def tm_loss( + logits, + final_affine_tensor, + backbone_rigid_tensor, + backbone_rigid_mask, + resolution, + max_bin=31, + no_bins=64, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps=1e-8, + **kwargs, +): + pred_affine = Rigid.from_tensor_7(final_affine_tensor) + backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor) + + def _points(affine): + pts = affine.get_trans()[..., None, :, :] + return affine.invert()[..., 
None].apply(pts) + + sq_diff = torch.sum((_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1) + + sq_diff = sq_diff.detach() + + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + boundaries = boundaries**2 + true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1) + + errors = softmax_cross_entropy(logits, torch.nn.functional.one_hot(true_bins, no_bins)) + + square_mask = backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :] + + loss = torch.sum(errors * square_mask, dim=-1) + scale = 0.5 # hack to help FP16 training along + denom = eps + torch.sum(scale * square_mask, dim=(-1, -2)) + loss = loss / denom[..., None] + loss = torch.sum(loss, dim=-1) + loss = loss * scale + + loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) + + # Average over the loss dimension + loss = torch.mean(loss) + + return loss + + +def between_residue_bond_loss( + pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3) + pred_atom_mask: torch.Tensor, # (*, N, 37/14) + residue_index: torch.Tensor, # (*, N) + aatype: torch.Tensor, # (*, N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0, + eps=1e-6, +) -> Dict[str, torch.Tensor]: + """Flat-bottom loss to penalize structural violations between residues. + + This is a loss penalizing any violation of the geometry around the peptide + bond between consecutive amino acids. This loss corresponds to + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. 
+ aatype: Amino acid type of given residue + tolerance_factor_soft: soft tolerance factor measured in standard deviations + of pdb distributions + tolerance_factor_hard: hard tolerance factor measured in standard deviations + of pdb distributions + + Returns: + Dict containing: + * 'c_n_loss_mean': Loss for peptide bond length violations + * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned + by CA, C, N + * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned + by C, N, CA + * 'per_residue_loss_sum': sum of all losses for each residue + * 'per_residue_violation_mask': mask denoting all residues with violation + present. + """ + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[..., :-1, 1, :] + this_ca_mask = pred_atom_mask[..., :-1, 1] + this_c_pos = pred_atom_positions[..., :-1, 2, :] + this_c_mask = pred_atom_mask[..., :-1, 2] + next_n_pos = pred_atom_positions[..., 1:, 0, :] + next_n_mask = pred_atom_mask[..., 1:, 0] + next_ca_pos = pred_atom_positions[..., 1:, 1, :] + next_ca_mask = pred_atom_mask[..., 1:, 1] + has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 + + # Compute loss for the C--N bond. + c_n_bond_length = torch.sqrt(eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)) + + # The C-N bond to proline has slightly different length because of the ring. 
+ next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"] + gt_length = (~next_is_proline) * residue_constants.between_res_bond_length_c_n[ + 0 + ] + next_is_proline * residue_constants.between_res_bond_length_c_n[1] + gt_stddev = (~next_is_proline) * residue_constants.between_res_bond_length_stddev_c_n[ + 0 + ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1] + c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2) + c_n_loss_per_residue = torch.nn.functional.relu(c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) + c_n_violation_mask = mask * (c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + ca_c_bond_length = torch.sqrt(eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)) + n_ca_bond_length = torch.sqrt(eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)) + + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None] + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None] + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None] + + ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = torch.sqrt(eps + (ca_c_n_cos_angle - gt_angle) ** 2) + ca_c_n_loss_per_residue = torch.nn.functional.relu(ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * 
n_ca_unit_vec, dim=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = torch.sqrt(eps + torch.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = torch.nn.functional.relu(c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (torch.sum(mask, dim=-1) + eps) + c_n_ca_violation_mask = mask * (c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue + per_residue_loss_sum = 0.5 * ( + torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) + torch.nn.functional.pad(per_residue_loss_sum, (1, 0)) + ) + + # Compute hard violations. + violation_mask = torch.max( + torch.stack( + [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask], + dim=-2, + ), + dim=-2, + )[0] + violation_mask = torch.maximum( + torch.nn.functional.pad(violation_mask, (0, 1)), + torch.nn.functional.pad(violation_mask, (1, 0)), + ) + + return { + "c_n_loss_mean": c_n_loss, + "ca_c_n_loss_mean": ca_c_n_loss, + "c_n_ca_loss_mean": c_n_ca_loss, + "per_residue_loss_sum": per_residue_loss_sum, + "per_residue_violation_mask": violation_mask, + } + + +def between_residue_clash_loss( + atom14_pred_positions: torch.Tensor, + atom14_atom_exists: torch.Tensor, + atom14_atom_radius: torch.Tensor, + residue_index: torch.Tensor, + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """Loss to penalize steric clashes between residues. + + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. 
This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + fp_type = atom14_pred_positions.dtype + + # Create the distance matrix. + # (N, N, 14, 14) + dists = torch.sqrt( + eps + + torch.sum( + (atom14_pred_positions[..., :, None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :]) ** 2, + dim=-1, + ) + ) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom14_atom_exists[..., :, None, :, None] * atom14_atom_exists[..., None, :, None, :]).type(fp_type) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask = dists_mask * (residue_index[..., :, None, None, None] < residue_index[..., None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. 
+ c_one_hot = torch.nn.functional.one_hot(residue_index.new_tensor(2), num_classes=14) + c_one_hot = c_one_hot.reshape(*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape) + c_one_hot = c_one_hot.type(fp_type) + n_one_hot = torch.nn.functional.one_hot(residue_index.new_tensor(0), num_classes=14) + n_one_hot = n_one_hot.reshape(*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape) + n_one_hot = n_one_hot.type(fp_type) + + neighbour_mask = (residue_index[..., :, None, None, None] + 1) == residue_index[..., None, :, None, None] + c_n_bonds = neighbour_mask * c_one_hot[..., None, None, :, None] * n_one_hot[..., None, None, None, :] + dists_mask = dists_mask * (1.0 - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys = residue_constants.restype_name_to_atom14_names["CYS"] + cys_sg_idx = cys.index("SG") + cys_sg_idx = residue_index.new_tensor(cys_sg_idx) + cys_sg_idx = cys_sg_idx.reshape(*((1,) * len(residue_index.shape[:-1])), 1).squeeze(-1) + cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = cys_sg_one_hot[..., None, None, :, None] * cys_sg_one_hot[..., None, None, None, :] + dists_mask = dists_mask * (1.0 - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * ( + atom14_atom_radius[..., :, None, :, None] + atom14_atom_radius[..., None, :, None, :] + ) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * torch.nn.functional.relu(dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask)) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(dists_to_low_error, axis=(-3, -1)) + + # Compute the hard clash mask. 
+ # shape (N, N, 14, 14) + clash_mask = dists_mask * (dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. + # shape (N, 14) + per_atom_clash_mask = torch.maximum( + torch.amax(clash_mask, axis=(-4, -2)), + torch.amax(clash_mask, axis=(-3, -1)), + ) + + return { + "mean_loss": mean_loss, # shape () + "per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14) + "per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14) + } + + +def within_residue_violations( + atom14_pred_positions: torch.Tensor, + atom14_atom_exists: torch.Tensor, + atom14_dists_lower_bound: torch.Tensor, + atom14_dists_upper_bound: torch.Tensor, + tighten_bounds_for_loss=0.0, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """Loss to penalize steric clashes within residues. + + This is a loss penalizing any steric violations or clashes of non-bonded atoms + in a given peptide. This loss corresponds to the part with + the same residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions ([*, N, 14, 3]): + Predicted positions of atoms in global prediction frame. + atom14_atom_exists ([*, N, 14]): + Mask denoting whether atom at positions exists for given + amino acid type + atom14_dists_lower_bound ([*, N, 14]): + Lower bound on allowed distances. + atom14_dists_upper_bound ([*, N, 14]): + Upper bound on allowed distances + tighten_bounds_for_loss ([*, N]): + Extra factor to tighten loss + + Returns: + Dict containing: + * 'per_atom_loss_sum' ([*, N, 14]): + sum of all clash losses per atom, shape + * 'per_atom_clash_mask' ([*, N, 14]): + mask whether atom clashes with any other atom shape + """ + # Compute the mask for each residue. 
+ dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None] + dists_masks = dists_masks.reshape(*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape) + dists_masks = atom14_atom_exists[..., :, :, None] * atom14_atom_exists[..., :, None, :] * dists_masks + + # Distance matrix + dists = torch.sqrt( + eps + + torch.sum( + (atom14_pred_positions[..., :, :, None, :] - atom14_pred_positions[..., :, None, :, :]) ** 2, + dim=-1, + ) + ) + + # Compute the loss. + dists_to_low_error = torch.nn.functional.relu(atom14_dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = torch.nn.functional.relu(dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1) + + # Compute the violations mask. + violations = dists_masks * ((dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)) + + # Compute the per atom violations. + per_atom_violations = torch.maximum(torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]) + + return { + "per_atom_loss_sum": per_atom_loss_sum, + "per_atom_violations": per_atom_violations, + } + + +def find_structural_violations( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, + violation_tolerance_factor: float, + clash_overlap_tolerance: float, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes several checks for structural violations.""" + # Compute between residue backbone violations of bonds and angles. 
+ connection_violations = between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch["atom14_atom_exists"], + residue_index=batch["residue_index"], + aatype=batch["aatype"], + tolerance_factor_soft=violation_tolerance_factor, + tolerance_factor_hard=violation_tolerance_factor, + ) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (N, 14). + atomtype_radius = [residue_constants.van_der_waals_radius[name[0]] for name in residue_constants.atom_types] + atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) + atom14_atom_radius = batch["atom14_atom_exists"] * atomtype_radius[batch["residx_atom14_to_atom37"]] + + # Compute the between residue clash loss. + between_residue_clashes = between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_atom_radius=atom14_atom_radius, + residue_index=batch["residue_index"], + overlap_tolerance_soft=clash_overlap_tolerance, + overlap_tolerance_hard=clash_overlap_tolerance, + ) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). 
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=clash_overlap_tolerance, + bond_length_tolerance_factor=violation_tolerance_factor, + ) + atom14_atom_exists = batch["atom14_atom_exists"] + atom14_dists_lower_bound = atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[batch["aatype"]] + atom14_dists_upper_bound = atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[batch["aatype"]] + residue_violations = within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch["atom14_atom_exists"], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0, + ) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = torch.max( + torch.stack( + [ + connection_violations["per_residue_violation_mask"], + torch.max(between_residue_clashes["per_atom_clash_mask"], dim=-1)[0], + torch.max(residue_violations["per_atom_violations"], dim=-1)[0], + ], + dim=-1, + ), + dim=-1, + )[0] + + return { + "between_residues": { + "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # () + "angles_ca_c_n_loss_mean": connection_violations["ca_c_n_loss_mean"], # () + "angles_c_n_ca_loss_mean": connection_violations["c_n_ca_loss_mean"], # () + "connections_per_residue_loss_sum": connection_violations["per_residue_loss_sum"], # (N) + "connections_per_residue_violation_mask": connection_violations["per_residue_violation_mask"], # (N) + "clashes_mean_loss": between_residue_clashes["mean_loss"], # () + "clashes_per_atom_loss_sum": between_residue_clashes["per_atom_loss_sum"], # (N, 14) + "clashes_per_atom_clash_mask": between_residue_clashes["per_atom_clash_mask"], # (N, 14) + }, + "within_residues": { + "per_atom_loss_sum": residue_violations["per_atom_loss_sum"], # (N, 14) + "per_atom_violations": residue_violations["per_atom_violations"], # (N, 
14), + }, + "total_per_residue_violations_mask": per_residue_violations_mask, # (N) + } + + +def find_structural_violations_np( + batch: Dict[str, np.ndarray], + atom14_pred_positions: np.ndarray, + config: ml_collections.ConfigDict, +) -> Dict[str, np.ndarray]: + to_tensor = lambda x: torch.tensor(x) + batch = tree_map(to_tensor, batch, np.ndarray) + atom14_pred_positions = to_tensor(atom14_pred_positions) + + out = find_structural_violations(batch, atom14_pred_positions, **config) + + to_np = lambda x: np.array(x) + np_out = tensor_tree_map(to_np, out) + + return np_out + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: torch.Tensor, # (N, 37(14), 3) + pred_atom_mask: torch.Tensor, # (N, 37(14)) + residue_index: torch.Tensor, # (N) + max_angstrom_tolerance=1.5, + eps=1e-6, +) -> torch.Tensor: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + + Returns: + Fraction of consecutive CA-CA pairs with violation. 
+ """ + this_ca_pos = pred_atom_positions[..., :-1, 1, :] + this_ca_mask = pred_atom_mask[..., :-1, 1] + next_ca_pos = pred_atom_positions[..., 1:, 1, :] + next_ca_mask = pred_atom_mask[..., 1:, 1] + has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 + ca_ca_distance = torch.sqrt(eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)) + violations = (ca_ca_distance - residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + mean = masked_mean(mask, violations, -1) + return mean + + +def compute_violation_metrics( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, # (N, 14, 3) + violations: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + """Compute several metrics to assess the structural violations.""" + ret = {} + extreme_ca_ca_violations = extreme_ca_ca_distance_violations( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch["atom14_atom_exists"], + residue_index=batch["residue_index"], + ) + ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations + ret["violations_between_residue_bond"] = masked_mean( + batch["seq_mask"], + violations["between_residues"]["connections_per_residue_violation_mask"], + dim=-1, + ) + ret["violations_between_residue_clash"] = masked_mean( + mask=batch["seq_mask"], + value=torch.max( + violations["between_residues"]["clashes_per_atom_clash_mask"], + dim=-1, + )[0], + dim=-1, + ) + ret["violations_within_residue"] = masked_mean( + mask=batch["seq_mask"], + value=torch.max(violations["within_residues"]["per_atom_violations"], dim=-1)[0], + dim=-1, + ) + ret["violations_per_residue"] = masked_mean( + mask=batch["seq_mask"], + value=violations["total_per_residue_violations_mask"], + dim=-1, + ) + return ret + + +def compute_violation_metrics_np( + batch: Dict[str, np.ndarray], + atom14_pred_positions: np.ndarray, + violations: Dict[str, np.ndarray], +) -> Dict[str, np.ndarray]: + to_tensor = lambda x: 
torch.tensor(x) + batch = tree_map(to_tensor, batch, np.ndarray) + atom14_pred_positions = to_tensor(atom14_pred_positions) + violations = tree_map(to_tensor, violations, np.ndarray) + + out = compute_violation_metrics(batch, atom14_pred_positions, violations) + + to_np = lambda x: np.array(x) + return tree_map(to_np, out, torch.Tensor) + + +def violation_loss( + violations: Dict[str, torch.Tensor], + atom14_atom_exists: torch.Tensor, + eps=1e-6, + **kwargs, +) -> torch.Tensor: + num_atoms = torch.sum(atom14_atom_exists) + l_clash = torch.sum( + violations["between_residues"]["clashes_per_atom_loss_sum"] + + violations["within_residues"]["per_atom_loss_sum"] + ) + l_clash = l_clash / (eps + num_atoms) + loss = ( + violations["between_residues"]["bonds_c_n_loss_mean"] + + violations["between_residues"]["angles_ca_c_n_loss_mean"] + + violations["between_residues"]["angles_c_n_ca_loss_mean"] + + l_clash + ) + + # Average over the batch dimension + mean = torch.mean(loss) + + return mean + + +def compute_renamed_ground_truth( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """Find optimal renaming of ground truth based on the predicted positions. + + Alg. 26 "renameSymmetricGroundTruthAtoms" + + This renamed ground truth is then used for all losses, + such that each loss moves the atoms in the same direction. + + Args: + batch: Dictionary containing: + * atom14_gt_positions: Ground truth positions. + * atom14_alt_gt_positions: Ground truth positions with renaming swaps. + * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by + renaming swaps. + * atom14_gt_exists: Mask for which atoms exist in ground truth. + * atom14_alt_gt_exists: Mask for which atoms exist in ground truth + after renaming. + * atom14_atom_exists: Mask for whether each atom is part of the given + amino acid type. 
+ atom14_pred_positions: Array of atom positions in global frame with shape + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. + """ + pred_dists = torch.sqrt( + eps + + torch.sum( + (atom14_pred_positions[..., None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :]) ** 2, + dim=-1, + ) + ) + + atom14_gt_positions = batch["atom14_gt_positions"] + gt_dists = torch.sqrt( + eps + + torch.sum( + (atom14_gt_positions[..., None, :, None, :] - atom14_gt_positions[..., None, :, None, :, :]) ** 2, + dim=-1, + ) + ) + + atom14_alt_gt_positions = batch["atom14_alt_gt_positions"] + alt_gt_dists = torch.sqrt( + eps + + torch.sum( + (atom14_alt_gt_positions[..., None, :, None, :] - atom14_alt_gt_positions[..., None, :, None, :, :]) ** 2, + dim=-1, + ) + ) + + lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2) + alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2) + + atom14_gt_exists = batch["atom14_gt_exists"] + atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"] + mask = ( + atom14_gt_exists[..., None, :, None] + * atom14_atom_is_ambiguous[..., None, :, None] + * atom14_gt_exists[..., None, :, None, :] + * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :]) + ) + + per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3)) + alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3)) + + fp_type = atom14_pred_positions.dtype + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type) + + renamed_atom14_gt_positions = ( + 1.0 - alt_naming_is_better[..., None, None] + ) * atom14_gt_positions + alt_naming_is_better[..., None, None] * atom14_alt_gt_positions + + renamed_atom14_gt_mask = (1.0 - alt_naming_is_better[..., None]) * atom14_gt_exists + alt_naming_is_better[ + ..., None + ] * 
batch["atom14_alt_gt_exists"] + + return { + "alt_naming_is_better": alt_naming_is_better, + "renamed_atom14_gt_positions": renamed_atom14_gt_positions, + "renamed_atom14_gt_exists": renamed_atom14_gt_mask, + } + + +def experimentally_resolved_loss( + logits: torch.Tensor, + atom37_atom_exists: torch.Tensor, + all_atom_mask: torch.Tensor, + resolution: torch.Tensor, + min_resolution: float, + max_resolution: float, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + errors = sigmoid_cross_entropy(logits, all_atom_mask) + loss = torch.sum(errors * atom37_atom_exists, dim=-1) + loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1)) + loss = torch.sum(loss, dim=-1) + + loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution)) + + loss = torch.mean(loss) + + return loss + + +def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs): + """Computes BERT-style masked MSA loss. Implements subsection 1.9.9. + + Args: + logits: [*, N_seq, N_res, 23] predicted residue distribution + true_msa: [*, N_seq, N_res] true MSA + bert_mask: [*, N_seq, N_res] MSA mask + Returns: + Masked MSA loss + """ + errors = softmax_cross_entropy(logits, torch.nn.functional.one_hot(true_msa, num_classes=23)) + + # FP16-friendly averaging. 
Equivalent to: + # loss = ( + # torch.sum(errors * bert_mask, dim=(-1, -2)) / + # (eps + torch.sum(bert_mask, dim=(-1, -2))) + # ) + loss = errors * bert_mask + loss = torch.sum(loss, dim=-1) + scale = 0.5 + denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2)) + loss = loss / denom[..., None] + loss = torch.sum(loss, dim=-1) + loss = loss * scale + + loss = torch.mean(loss) + + return loss + + +class AlphaFoldLoss(nn.Module): + """Aggregation of the various losses described in the supplement""" + + def __init__(self, config): + super(AlphaFoldLoss, self).__init__() + self.config = config + + def forward(self, out, batch, _return_breakdown=False): + """If "violation" not in out.keys(): + out["violation"] = find_structural_violations( + batch, + out["sm"]["positions"][-1], + **self.config.violation, + ) + """ + if "renamed_atom14_gt_positions" not in out.keys(): + batch.update( + compute_renamed_ground_truth( + batch, + out["sm"]["positions"][-1], + ) + ) + + loss_fns = { + "fape": lambda: fape_loss( + out, + batch, + self.config.fape, + ), + "supervised_chi": lambda: supervised_chi_loss( + out["sm"]["angles"], + out["sm"]["unnormalized_angles"], + **{**batch, **self.config.supervised_chi}, + ), + "plddt_loss": lambda: lddt_loss( + logits=out["lddt_logits"], + all_atom_pred_pos=out["final_atom_positions"], + **{**batch, **self.config.plddt_loss}, + ), + } + + """ + "distogram": lambda: distogram_loss( + logits=out["distogram_logits"], + **{**batch, **self.config.distogram}, + ), + "experimentally_resolved": lambda: experimentally_resolved_loss( + logits=out["experimentally_resolved_logits"], + **{**batch, **self.config.experimentally_resolved}, + ), + """ + + """ + "masked_msa": lambda: masked_msa_loss( + logits=out["masked_msa_logits"], + **{**batch, **self.config.masked_msa}, + ), + + "violation": lambda: violation_loss( + out["violation"], + **batch, + ), + """ + + if self.config.tm.enabled: + loss_fns["tm"] = lambda: tm_loss( + logits=out["tm_logits"], + 
**{**batch, **out, **self.config.tm}, + ) + + cum_loss = 0.0 + losses = {} + for loss_name, loss_fn in loss_fns.items(): + weight = self.config[loss_name].weight + loss = loss_fn() + if torch.isnan(loss) or torch.isinf(loss): + logging.warning(f"{loss_name} loss is NaN. Skipping...") + loss = loss.new_tensor(0.0, requires_grad=True) + cum_loss = cum_loss + weight * loss + losses[loss_name] = loss.detach().clone() + + losses["unscaled_loss"] = cum_loss.detach().clone() + + # Scale the loss by the square root of the minimum of the crop size and + # the (average) sequence length. See subsection 1.9. + seq_len = torch.mean(batch["seq_length"].float()) + crop_len = batch["aatype"].shape[-1] + cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len)) + + losses["loss"] = cum_loss.detach().clone() + + if not _return_breakdown: + return cum_loss + + return cum_loss, losses diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/__init__.py b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/__init__.py new file mode 100644 index 0000000000..d59459359d --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from minifold_utils.init import ( + bias_init_one_, + bias_init_zero_, + final_init_, + gating_init_, + he_normal_init_, + lecun_normal_init_, +) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/feats.py b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/feats.py new file mode 100644 index 0000000000..3e0c0c1b7e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/feats.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn as nn + +from minifold_utils.rigid_utils import Rigid, Rotation +from minifold_utils.tensor_utils import ( + batched_gather, +) + + +def atom14_to_atom37(atom14, batch): + atom37_data = batched_gather( + atom14, + batch["residx_atom37_to_atom14"], + dim=-2, + no_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] + + return atom37_data + + +def torsion_angles_to_frames( + r: Rigid, + alpha: torch.Tensor, + aatype: torch.Tensor, + rrgdf: torch.Tensor, +): + # [*, N, 8, 4, 4] + default_4x4 = rrgdf[aatype, ...] + + # [*, N, 8] transformations, i.e. + # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) + + bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + # [*, N, 8, 2] + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. 
+ + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Rigid(Rotation(rot_mats=all_rots), None) + + all_frames = default_r.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Rigid, + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + # [*, N, 14, 4, 4] + default_4x4 = default_frames[aatype, ...] + + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + group_mask = nn.functional.one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) + + # [*, N, 14, 1] + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] 
+ pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/init.py b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/init.py new file mode 100644 index 0000000000..972c5943dd --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/init.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math + +import numpy as np +import torch +from scipy.stats import truncnorm + + +def _prod(nums): + out = 1 + for n in nums: + out = out * n + return out + + +def _calculate_fan(linear_weight_shape, fan="fan_in"): + fan_out, fan_in = linear_weight_shape + + if fan == "fan_in": + f = fan_in + elif fan == "fan_out": + f = fan_out + elif fan == "fan_avg": + f = (fan_in + fan_out) / 2 + else: + raise ValueError("Invalid fan option") + + return f + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + f = _calculate_fan(shape, fan) + scale = scale / max(1, f) + a = -2 + b = 2 + std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) + size = _prod(shape) + samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) + samples = np.reshape(samples, shape) + with torch.no_grad(): + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def lecun_normal_init_(weights): + trunc_normal_init_(weights, scale=1.0) + + +def he_normal_init_(weights): + trunc_normal_init_(weights, scale=2.0) + + +def glorot_uniform_init_(weights): + torch.nn.init.xavier_uniform_(weights, gain=1) + + +def final_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def gating_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def bias_init_zero_(bias): + with torch.no_grad(): + bias.fill_(0.0) + + +def bias_init_one_(bias): + with torch.no_grad(): + bias.fill_(1.0) + + +def normal_init_(weights): + torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/residue_constants.py b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/residue_constants.py new file mode 100644 index 0000000000..8f9bb7fa1f --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/residue_constants.py 
@@ -0,0 +1,1286 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +from importlib import resources +from typing import List, Mapping, Tuple + +import numpy as np +import tree + + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. 
+ "ARG": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "NE"], + ["CG", "CD", "NE", "CZ"], + ], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLU": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "CE"], + ["CG", "CD", "CE", "NZ"], + ], + "MET": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "SD"], + ["CB", "CG", "SD", "CE"], + ], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). 
+chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. 
The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + "ALA": [ + ["N", 0, (-0.525, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.529, -0.774, -1.205)], + ["O", 3, (0.627, 1.062, 0.000)], + ], + "ARG": [ + ["N", 0, (-0.524, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.524, -0.778, -1.209)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.616, 1.390, -0.000)], + ["CD", 5, (0.564, 1.414, 0.000)], + ["NE", 6, (0.539, 1.357, -0.000)], + ["NH1", 7, (0.206, 2.301, 0.000)], + ["NH2", 7, (2.078, 0.978, -0.000)], + ["CZ", 7, (0.758, 1.093, -0.000)], + ], + "ASN": [ + ["N", 0, (-0.536, 1.357, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.531, -0.787, -1.200)], + ["O", 3, (0.625, 1.062, 0.000)], + ["CG", 4, (0.584, 1.399, 0.000)], + ["ND2", 5, (0.593, -1.188, 0.001)], + ["OD1", 5, (0.633, 1.059, 0.000)], + ], + "ASP": [ + ["N", 0, (-0.525, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, 0.000, -0.000)], + ["CB", 0, (-0.526, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.593, 1.398, -0.000)], + ["OD1", 5, (0.610, 1.091, 0.000)], + ["OD2", 5, (0.592, -1.101, -0.003)], + ], + "CYS": [ + ["N", 0, (-0.522, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, 0.000)], + ["CB", 0, (-0.519, -0.773, -1.212)], + ["O", 3, (0.625, 1.062, -0.000)], + ["SG", 4, (0.728, 1.653, 0.000)], + ], + "GLN": [ + ["N", 0, (-0.526, 1.361, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.779, -1.207)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.615, 1.393, 
0.000)], + ["CD", 5, (0.587, 1.399, -0.000)], + ["NE2", 6, (0.593, -1.189, -0.001)], + ["OE1", 6, (0.634, 1.060, 0.000)], + ], + "GLU": [ + ["N", 0, (-0.528, 1.361, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.526, -0.781, -1.207)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.615, 1.392, 0.000)], + ["CD", 5, (0.600, 1.397, 0.000)], + ["OE1", 6, (0.607, 1.095, -0.000)], + ["OE2", 6, (0.589, -1.104, -0.001)], + ], + "GLY": [ + ["N", 0, (-0.572, 1.337, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.517, -0.000, -0.000)], + ["O", 3, (0.626, 1.062, -0.000)], + ], + "HIS": [ + ["N", 0, (-0.527, 1.360, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.778, -1.208)], + ["O", 3, (0.625, 1.063, 0.000)], + ["CG", 4, (0.600, 1.370, -0.000)], + ["CD2", 5, (0.889, -1.021, 0.003)], + ["ND1", 5, (0.744, 1.160, -0.000)], + ["CE1", 5, (2.030, 0.851, 0.002)], + ["NE2", 5, (2.145, -0.466, 0.004)], + ], + "ILE": [ + ["N", 0, (-0.493, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.536, -0.793, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.534, 1.437, -0.000)], + ["CG2", 4, (0.540, -0.785, -1.199)], + ["CD1", 5, (0.619, 1.391, 0.000)], + ], + "LEU": [ + ["N", 0, (-0.520, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.773, -1.214)], + ["O", 3, (0.625, 1.063, -0.000)], + ["CG", 4, (0.678, 1.371, 0.000)], + ["CD1", 5, (0.530, 1.430, -0.000)], + ["CD2", 5, (0.535, -0.774, 1.200)], + ], + "LYS": [ + ["N", 0, (-0.526, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.524, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.619, 1.390, 0.000)], + ["CD", 5, (0.559, 1.417, 0.000)], + ["CE", 6, (0.560, 1.416, 0.000)], + ["NZ", 7, (0.554, 1.387, 0.000)], + ], + "MET": 
[ + ["N", 0, (-0.521, 1.364, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.210)], + ["O", 3, (0.625, 1.062, -0.000)], + ["CG", 4, (0.613, 1.391, -0.000)], + ["SD", 5, (0.703, 1.695, 0.000)], + ["CE", 6, (0.320, 1.786, -0.000)], + ], + "PHE": [ + ["N", 0, (-0.518, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, -0.000)], + ["CB", 0, (-0.525, -0.776, -1.212)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.607, 1.377, 0.000)], + ["CD1", 5, (0.709, 1.195, -0.000)], + ["CD2", 5, (0.706, -1.196, 0.000)], + ["CE1", 5, (2.102, 1.198, -0.000)], + ["CE2", 5, (2.098, -1.201, -0.000)], + ["CZ", 5, (2.794, -0.003, -0.001)], + ], + "PRO": [ + ["N", 0, (-0.566, 1.351, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, 0.000)], + ["CB", 0, (-0.546, -0.611, -1.293)], + ["O", 3, (0.621, 1.066, 0.000)], + ["CG", 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + "SER": [ + ["N", 0, (-0.529, 1.360, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.518, -0.777, -1.211)], + ["O", 3, (0.626, 1.062, -0.000)], + ["OG", 4, (0.503, 1.325, 0.000)], + ], + "THR": [ + ["N", 0, (-0.517, 1.364, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, -0.000)], + ["CB", 0, (-0.516, -0.793, -1.215)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG2", 4, (0.550, -0.718, -1.228)], + ["OG1", 4, (0.472, 1.353, 0.000)], + ], + "TRP": [ + ["N", 0, (-0.521, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.212)], + ["O", 3, (0.627, 1.062, 0.000)], + ["CG", 4, (0.609, 1.370, -0.000)], + ["CD1", 5, (0.824, 1.091, 0.000)], + ["CD2", 5, (0.854, -1.148, -0.005)], + ["CE2", 5, (2.186, -0.678, -0.007)], + ["CE3", 5, (0.622, -2.530, -0.007)], + ["NE1", 5, (2.140, 0.690, 
-0.004)], + ["CH2", 5, (3.028, -2.890, -0.013)], + ["CZ2", 5, (3.283, -1.543, -0.011)], + ["CZ3", 5, (1.715, -3.389, -0.011)], + ], + "TYR": [ + ["N", 0, (-0.522, 1.362, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.776, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG", 4, (0.607, 1.382, -0.000)], + ["CD1", 5, (0.716, 1.195, -0.000)], + ["CD2", 5, (0.713, -1.194, -0.001)], + ["CE1", 5, (2.107, 1.200, -0.002)], + ["CE2", 5, (2.104, -1.201, -0.003)], + ["OH", 5, (4.168, -0.002, -0.005)], + ["CZ", 5, (2.791, -0.001, -0.003)], + ], + "VAL": [ + ["N", 0, (-0.494, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.533, -0.795, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.540, 1.429, -0.000)], + ["CG2", 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": [ + "C", + "CA", + "CB", + "CG", + "CD1", + 
"CD2", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + "N", + "NE1", + "O", + ], + "TYR": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "N", + "O", + "OH", + ], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +# Because for LEU, VAL and ARG, no ambiguous exist when the prediction output is chi angle instead of the location of individual atoms. +# For the rest, ASP and others, when you rotate the bond 180 degree, you get the same configuraiton due to symmetry. + +residue_atom_renaming_swaps = { + "ASP": {"OD1": "OD2"}, + "GLU": {"OE1": "OE2"}, + "PHE": {"CD1": "CD2", "CE1": "CE2"}, + "TYR": {"CD1": "CD2", "CE1": "CE2"}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + "C": 1.7, + "N": 1.55, + "O": 1.52, + "S": 1.8, +} + +Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"]) +BondAngle = collections.namedtuple( + "BondAngle", + ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"], +) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[ + Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]], +]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). 
+ + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples + residue_virtual_bonds: dict that maps resname --> list of Bond tuples + residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + # TODO: this file should be downloaded in a setup script + stereo_chemical_props = resources.read_text("minifold_utils", "stereo_chemical_props.txt") + + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split("-") + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds["UNK"] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split("-") + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + ) + ) + residue_bond_angles["UNK"] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return "-".join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. 
+ bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt( + (dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2 + ) + residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). 
atom_types = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
    "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
    "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
    "CZ3", "NZ", "OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.

# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
    "ALA": ["N", "CA", "C", "O", "CB"],
    "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
    "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
    "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
    "CYS": ["N", "CA", "C", "O", "CB", "SG"],
    "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
    "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
    "GLY": ["N", "CA", "C", "O"],
    "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
    "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
    "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
    "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
    "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
    "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
    "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
    "SER": ["N", "CA", "C", "O", "CB", "OG"],
    "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
    "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],
    "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
    "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
    "UNK": [],
}
# Right-pad every entry with "" so each residue has exactly 14 atom slots.
for _atom14_names in restype_name_to_atom14_names.values():
    _atom14_names.extend([""] * (14 - len(_atom14_names)))
del _atom14_names
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace


# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
    "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
    "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
restype_order_with_x_inverse = {v: k for k, v in restype_order_with_x.items()}

# Heavy-atom counts per one-letter residue code (includes ambiguity codes
# Z/B/U and the catch-all X).
num_atoms = {
    "A": 5,
    "R": 11,
    "N": 8,
    "D": 8,
    "C": 6,
    "Q": 9,
    "E": 9,
    "G": 4,
    "H": 10,
    "I": 8,
    "L": 8,
    "K": 9,
    "M": 8,
    "F": 11,
    "P": 7,
    "S": 6,
    "T": 7,
    "W": 14,
    "Y": 12,
    "V": 7,
    "X": 0,
    "Z": 9,
    "B": 8,
    "U": 6,
}


def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
        sequence: An amino acid sequence.
        mapping: A dictionary mapping amino acids to integers.
        map_unknown_to_x: If True, any amino acid that is not in the mapping will be
            mapped to the unknown amino acid 'X'. If the mapping doesn't contain
            amino acid 'X', an error will be thrown. If False, any amino acid not in
            the mapping will throw an error.

    Returns:
        A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
        the sequence.

    Raises:
        ValueError: If the mapping doesn't contain values from 0 to
            num_unique_aas - 1 without any gaps.
    """
    num_entries = max(mapping.values()) + 1

    # The mapping must be a gap-free enumeration so it can index matrix columns.
    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s"
            % sorted(mapping.values())
        )

    one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)

    for aa_index, aa_type in enumerate(sequence):
        if map_unknown_to_x:
            # Only plain upper-case letters are acceptable; unknown ones fall
            # back to the 'X' column.
            if not (aa_type.isalpha() and aa_type.isupper()):
                raise ValueError(f"Invalid character in the sequence: {aa_type}")
            aa_id = mapping.get(aa_type, mapping["X"])
        else:
            aa_id = mapping[aa_type]
        one_hot_arr[aa_index, aa_id] = 1

    return one_hot_arr


restype_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
}


# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# Define a restype name for all unknown residues.
unk_restype = "UNK"

resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}


# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
+HHBLITS_AA_TO_ID = { + "A": 0, + "B": 2, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "J": 20, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "O": 20, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "U": 1, + "V": 17, + "W": 18, + "X": 20, + "Y": 19, + "Z": 3, + "-": 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: "A", + 1: "C", # Also U. + 2: "D", # Also B. + 3: "E", # Also Z. + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", # Includes J and O. + 21: "-", +} + +restypes_with_x_and_gap = restypes + ["X", "-"] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap)) +) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. 
+def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure(lambda atom_name: atom_order[atom_name], chi_angles_atom_indices) +chi_angles_atom_indices = np.array( + [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices] +) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = {name: 
np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]} + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["N"] - atom_positions["CA"], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions["N"], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["C"] - atom_positions["CA"], + ey=atom_positions["CA"] - atom_positions["N"], + translation=atom_positions["C"], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15): + 
"""Compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return { + 
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14) + "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14) + "stddev": restype_atom14_bond_stddev, # shape (21,14,14) + } + + +restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) +restype_atom14_ambiguous_atoms_swap_idx = np.tile(np.arange(14, dtype=int), (21, 1)) + + +def _make_atom14_ambiguity_feats(): + for res, pairs in residue_atom_renaming_swaps.items(): + res_idx = restype_order[restype_3to1[res]] + for atom1, atom2 in pairs.items(): + atom1_idx = restype_name_to_atom14_names[res].index(atom1) + atom2_idx = restype_name_to_atom14_names[res].index(atom2) + restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1 + restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1 + restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx + restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx + + +_make_atom14_ambiguity_feats() + + +def aatype_to_str_sequence(aatype): + return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))]) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/rigid_utils.py b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/rigid_utils.py new file mode 100644 index 0000000000..fd6326c91b --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/rigid_utils.py @@ -0,0 +1,1348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +# from functools import lru_cache +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +import torch + + +def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Performs matrix multiplication of two rotation matrix tensors. Written + out by hand to avoid AMP downcasting. + + Args: + a: [*, 3, 3] left multiplicand + b: [*, 3, 3] right multiplicand + Returns: + The product ab + """ + + def row_mul(i): + return torch.stack( + [ + a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0], + a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1], + a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2], + ], + dim=-1, + ) + + return torch.stack( + [ + row_mul(0), + row_mul(1), + row_mul(2), + ], + dim=-2, + ) + + +def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Applies a rotation to a vector. Written out by hand to avoid transfer + to avoid AMP downcasting. 
+ + Args: + r: [*, 3, 3] rotation matrices + t: [*, 3] coordinate tensors + Returns: + [*, 3] rotated coordinates + """ + x, y, z = torch.unbind(t, dim=-1) + return torch.stack( + [ + r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z, + r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z, + r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z, + ], + dim=-1, + ) + + +# @lru_cache(maxsize=None) +def identity_rot_mats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad) + rots = rots.view(*((1,) * len(batch_dims)), 3, 3) + rots = rots.expand(*batch_dims, -1, -1) + rots = rots.contiguous() + + return rots + + +# @lru_cache(maxsize=None) +def identity_trans( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + trans = torch.zeros((*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad) + return trans + + +# @lru_cache(maxsize=None) +def identity_quats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad) + + with torch.no_grad(): + quat[..., 0] = 1 + + return quat + + +_quat_elements = ["a", "b", "c", "d"] +_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] +_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} + + +def _to_mat(pairs): + mat = np.zeros((4, 4)) + for pair in pairs: + key, value = pair + ind = _qtr_ind_dict[key] + mat[ind // 4][ind % 4] = value + + return mat + + +_QTR_MAT = np.zeros((4, 4, 3, 3)) +_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) +_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), 
("ad", -2)]) +_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) +_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)]) +_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)]) +_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)]) +_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)]) +_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)]) +_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)]) + + +def quat_to_rot(quat: torch.Tensor) -> torch.Tensor: + """Converts a quaternion to a rotation matrix. + + Args: + quat: [*, 4] quaternions + Returns: + [*, 3, 3] rotation matrices + """ + # [*, 4, 4] + quat = quat[..., None] * quat[..., None, :] + + # [4, 4, 3, 3] + mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device) + + # [*, 4, 4, 3, 3] + shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape) + quat = quat[..., None, None] * shaped_qtr_mat + + # [*, 3, 3] + return torch.sum(quat, dim=(-3, -4)) + + +def rot_to_quat( + rot: torch.Tensor, +): + if rot.shape[-2:] != (3, 3): + raise ValueError("Input rotation is incorrectly shaped") + + rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + k = [ + [ + xx + yy + zz, + zy - yz, + xz - zx, + yx - xy, + ], + [ + zy - yz, + xx - yy - zz, + xy + yx, + xz + zx, + ], + [ + xz - zx, + xy + yx, + yy - xx - zz, + yz + zy, + ], + [ + yx - xy, + xz + zx, + yz + zy, + zz - xx - yy, + ], + ] + + k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) + + _, vectors = torch.linalg.eigh(k) + + return vectors[..., -1] + + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]] + +_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], 
[1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] + +_CACHED_QUATS = { + "_QTR_MAT": _QTR_MAT, + "_QUAT_MULTIPLY": _QUAT_MULTIPLY, + "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC, +} + + +# @lru_cache(maxsize=None) +def _get_quat(quat_key, dtype, device): + return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device) + + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device) + reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], + dim=(-3, -2), + ) + + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device) + reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) + return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2)) + + +def invert_rot_mat(rot_mat: torch.Tensor): + return rot_mat.transpose(-1, -2) + + +def invert_quat(quat: torch.Tensor): + quat_prime = quat.clone() + quat_prime[..., 1:] *= -1 + inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True) + return inv + + +class Rotation: + """A 3D rotation. Depending on how the object is initialized, the + rotation is represented by either a rotation matrix or a + quaternion, though both formats are made available by helper functions. + To simplify gradient computation, the underlying format of the + rotation cannot be changed in-place. Like Rigid, the class is designed + to mimic the behavior of a torch Tensor, almost as if each Rotation + object were a tensor of rotations, in one format or another. 
+ """ + + def __init__( + self, + rot_mats: Optional[torch.Tensor] = None, + quats: Optional[torch.Tensor] = None, + normalize_quats: bool = True, + ): + """Args: + rot_mats: + A [*, 3, 3] rotation matrix tensor. Mutually exclusive with + quats + quats: + A [*, 4] quaternion. Mutually exclusive with rot_mats. If + normalize_quats is not True, must be a unit quaternion + normalize_quats: + If quats is specified, whether to normalize quats + """ + if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None): + raise ValueError("Exactly one input argument must be specified") + + if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4): + raise ValueError("Incorrectly shaped rotation matrix or quaternion") + + # Force full-precision + if quats is not None: + quats = quats.to(dtype=torch.float32) + if rot_mats is not None: + rot_mats = rot_mats.to(dtype=torch.float32) + + if quats is not None and normalize_quats: + quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) + + self._rot_mats = rot_mats + self._quats = quats + + @staticmethod + def identity( + shape, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rotation: + """Returns an identity Rotation. + + Args: + shape: + The "shape" of the resulting Rotation object. See documentation + for the shape property + dtype: + The torch dtype for the rotation + device: + The torch device for the new rotation + requires_grad: + Whether the underlying tensors in the new rotation object + should require gradient computation + fmt: + One of "quat" or "rot_mat". 
Determines the underlying format + of the new object's rotation + Returns: + A new identity rotation + """ + if fmt == "rot_mat": + rot_mats = identity_rot_mats( + shape, + dtype, + device, + requires_grad, + ) + return Rotation(rot_mats=rot_mats, quats=None) + elif fmt == "quat": + quats = identity_quats(shape, dtype, device, requires_grad) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError(f"Invalid format: f{fmt}") + + # Magic methods + + def __getitem__(self, index: Any) -> Rotation: + """Allows torch-style indexing over the virtual shape of the rotation + object. See documentation for the shape property. + + Args: + index: + A torch index. E.g. (1, 3, 2), or (slice(None,)) + + Returns: + The indexed rotation + """ + if type(index) != tuple: + index = (index,) + + if self._rot_mats is not None: + rot_mats = self._rot_mats[index + (slice(None), slice(None))] + return Rotation(rot_mats=rot_mats) + elif self._quats is not None: + quats = self._quats[index + (slice(None),)] + return Rotation(quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def __mul__( + self, + right: torch.Tensor, + ) -> Rotation: + """Pointwise left multiplication of the rotation with a tensor. Can be + used to e.g. mask the Rotation. + + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not (isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + if self._rot_mats is not None: + rot_mats = self._rot_mats * right[..., None, None] + return Rotation(rot_mats=rot_mats, quats=None) + elif self._quats is not None: + quats = self._quats * right[..., None] + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def __rmul__( + self, + left: torch.Tensor, + ) -> Rotation: + """Reverse pointwise multiplication of the rotation with a tensor. 
+ + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + # Properties + + @property + def shape(self) -> torch.Size: + """Returns the virtual shape of the rotation object. This shape is + defined as the batch dimensions of the underlying rotation matrix + or quaternion. If the Rotation was initialized with a [10, 3, 3] + rotation matrix tensor, for example, the resulting shape would be + [10]. + + Returns: + The virtual shape of the rotation object + """ + s = None + if self._quats is not None: + s = self._quats.shape[:-1] + else: + s = self._rot_mats.shape[:-2] + + return s + + @property + def dtype(self) -> torch.dtype: + """Returns the dtype of the underlying rotation. + + Returns: + The dtype of the underlying rotation + """ + if self._rot_mats is not None: + return self._rot_mats.dtype + elif self._quats is not None: + return self._quats.dtype + else: + raise ValueError("Both rotations are None") + + @property + def device(self) -> torch.device: + """The device of the underlying rotation + + Returns: + The device of the underlying rotation + """ + if self._rot_mats is not None: + return self._rot_mats.device + elif self._quats is not None: + return self._quats.device + else: + raise ValueError("Both rotations are None") + + @property + def requires_grad(self) -> bool: + """Returns the requires_grad property of the underlying rotation + + Returns: + The requires_grad property of the underlying tensor + """ + if self._rot_mats is not None: + return self._rot_mats.requires_grad + elif self._quats is not None: + return self._quats.requires_grad + else: + raise ValueError("Both rotations are None") + + def get_rot_mats(self) -> torch.Tensor: + """Returns the underlying rotation as a rotation matrix tensor. 
+ + Returns: + The rotation as a rotation matrix tensor + """ + rot_mats = self._rot_mats + if rot_mats is None: + if self._quats is None: + raise ValueError("Both rotations are None") + else: + rot_mats = quat_to_rot(self._quats) + + return rot_mats + + def get_quats(self) -> torch.Tensor: + """Returns the underlying rotation as a quaternion tensor. + + Depending on whether the Rotation was initialized with a + quaternion, this function may call torch.linalg.eigh. + + Returns: + The rotation as a quaternion tensor. + """ + quats = self._quats + if quats is None: + if self._rot_mats is None: + raise ValueError("Both rotations are None") + else: + quats = rot_to_quat(self._rot_mats) + + return quats + + def get_cur_rot(self) -> torch.Tensor: + """Return the underlying rotation in its current form + + Returns: + The stored rotation + """ + if self._rot_mats is not None: + return self._rot_mats + elif self._quats is not None: + return self._quats + else: + raise ValueError("Both rotations are None") + + # Rotation functions + + def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation: + """Returns a new quaternion Rotation after updating the current + object's underlying rotation with a quaternion update, formatted + as a [*, 3] tensor whose final three columns represent x, y, z such + that (1, x, y, z) is the desired (not necessarily unit) quaternion + update. + + Args: + q_update_vec: + A [*, 3] quaternion update tensor + normalize_quats: + Whether to normalize the output quaternion + Returns: + An updated Rotation + """ + quats = self.get_quats() + new_quats = quats + quat_multiply_by_vec(quats, q_update_vec) + return Rotation( + rot_mats=None, + quats=new_quats, + normalize_quats=normalize_quats, + ) + + def compose_r(self, r: Rotation) -> Rotation: + """Compose the rotation matrices of the current Rotation object with + those of another. 
+ + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + r1 = self.get_rot_mats() + r2 = r.get_rot_mats() + new_rot_mats = rot_matmul(r1, r2) + return Rotation(rot_mats=new_rot_mats, quats=None) + + def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation: + """Compose the quaternions of the current Rotation object with those + of another. + + Depending on whether either Rotation was initialized with + quaternions, this function may call torch.linalg.eigh. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + q1 = self.get_quats() + q2 = r.get_quats() + new_quats = quat_multiply(q1, q2) + return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + """Apply the current Rotation as a rotation matrix to a set of 3D + coordinates. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] rotated points + """ + rot_mats = self.get_rot_mats() + return rot_vec_mul(rot_mats, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """The inverse of the apply() method. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] inverse-rotated points + """ + rot_mats = self.get_rot_mats() + inv_rot_mats = invert_rot_mat(rot_mats) + return rot_vec_mul(inv_rot_mats, pts) + + def invert(self) -> Rotation: + """Returns the inverse of the current Rotation. + + Returns: + The inverse of the current Rotation + """ + if self._rot_mats is not None: + return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None) + elif self._quats is not None: + return Rotation( + rot_mats=None, + quats=invert_quat(self._quats), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + # "Tensor" stuff + + def unsqueeze( + self, + dim: int, + ) -> Rigid: + """Analogous to torch.unsqueeze. The dimension is relative to the + shape of the Rotation object. 
+ + Args: + dim: A positive or negative dimension index. + + Returns: + The unsqueezed Rotation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + + if self._rot_mats is not None: + rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(rot_mats=rot_mats, quats=None) + elif self._quats is not None: + quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + @staticmethod + def cat( + rs: Sequence[Rotation], + dim: int, + ) -> Rigid: + """Concatenates rotations along one of the batch dimensions. Analogous + to torch.cat(). + + Note that the output of this operation is always a rotation matrix, + regardless of the format of input rotations. + + Args: + rs: + A list of rotation objects + dim: + The dimension along which the rotations should be + concatenated + Returns: + A concatenated Rotation object in rotation matrix format + """ + rot_mats = [r.get_rot_mats() for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(rot_mats=rot_mats, quats=None) + + def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation: + """Apply a Tensor -> Tensor function to underlying rotation tensors, + mapping over the rotation dimension(s). Can be used e.g. to sum out + a one-hot batch dimension. 
+ + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rotation + Returns: + The transformed Rotation object + """ + if self._rot_mats is not None: + rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,)) + rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1) + rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3)) + return Rotation(rot_mats=rot_mats, quats=None) + elif self._quats is not None: + quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def cuda(self) -> Rotation: + """Moves the transformation object to GPU memory (CUDA or MPS if available) + + Returns: + A version of the transformation on GPU (CUDA or MPS) + """ + # Prefer CUDA if available, else MPS, else CPU + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + return self.to(device) + + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Rigid: + """Moves the transformation object to the specified device and dtype. + + Args: + device: torch.device or None + dtype: torch.dtype or None + Returns: + A version of the transformation on the specified device/dtype + """ + rots = self._rots + trans = self._trans + if rots is not None: + rots = rots.to(device=device, dtype=dtype) + if trans is not None: + trans = trans.to(device=device, dtype=dtype) + return Rigid(rots, trans) + + def detach(self) -> Rotation: + """Returns a copy of the Rotation whose underlying Tensor has been + detached from its torch graph. 
+ + Returns: + A copy of the Rotation whose underlying Tensor has been detached + from its torch graph + """ + if self._rot_mats is not None: + return Rotation(rot_mats=self._rot_mats.detach(), quats=None) + elif self._quats is not None: + return Rotation( + rot_mats=None, + quats=self._quats.detach(), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + +class Rigid: + """A class representing a rigid transformation. Little more than a wrapper + around two objects: a Rotation object and a [*, 3] translation + Designed to behave approximately like a single torch tensor with the + shape of the shared batch dimensions of its component parts. + """ + + def __init__( + self, + rots: Optional[Rotation], + trans: Optional[torch.Tensor], + ): + """Args: + rots: A [*, 3, 3] rotation tensor + trans: A corresponding [*, 3] translation tensor + """ + # (we need device, dtype, etc. from at least one input) + + batch_dims, dtype, device, requires_grad = None, None, None, None + if trans is not None: + batch_dims = trans.shape[:-1] + dtype = trans.dtype + device = trans.device + requires_grad = trans.requires_grad + elif rots is not None: + batch_dims = rots.shape + dtype = rots.dtype + device = rots.device + requires_grad = rots.requires_grad + else: + raise ValueError("At least one input argument must be specified") + + if rots is None: + rots = Rotation.identity( + batch_dims, + dtype, + device, + requires_grad, + ) + elif trans is None: + trans = identity_trans( + batch_dims, + dtype, + device, + requires_grad, + ) + + if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device): + raise ValueError("Rots and trans incompatible") + + # Force full precision. Happens to the rotations automatically. 
+ trans = trans.to(dtype=torch.float32) + + self._rots = rots + self._trans = trans + + @staticmethod + def identity( + shape: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rigid: + """Constructs an identity transformation. + + Args: + shape: + The desired shape + dtype: + The dtype of both internal tensors + device: + The device of both internal tensors + requires_grad: + Whether grad should be enabled for the internal tensors + Returns: + The identity transformation + """ + return Rigid( + Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt), + identity_trans(shape, dtype, device, requires_grad), + ) + + def __getitem__( + self, + index: Any, + ) -> Rigid: + """Indexes the affine transformation with PyTorch-style indices. + The index is applied to the shared dimensions of both the rotation + and the translation. + + E.g.:: + + r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) + t = Rigid(r, torch.rand(10, 10, 3)) + indexed = t[3, 4:6] + assert(indexed.shape == (2,)) + assert(indexed.get_rots().shape == (2,)) + assert(indexed.get_trans().shape == (2, 3)) + + Args: + index: A standard torch tensor index. E.g. 8, (10, None, 3), + or (3, slice(0, 1, None)) + + Returns: + The indexed tensor + """ + if type(index) != tuple: + index = (index,) + + return Rigid( + self._rots[index], + self._trans[index + (slice(None),)], + ) + + def __mul__( + self, + right: torch.Tensor, + ) -> Rigid: + """Pointwise left multiplication of the transformation with a tensor. + Can be used to e.g. mask the Rigid. 
+ + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not (isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + new_rots = self._rots * right + new_trans = self._trans * right[..., None] + + return Rigid(new_rots, new_trans) + + def __rmul__( + self, + left: torch.Tensor, + ) -> Rigid: + """Reverse pointwise multiplication of the transformation with a + tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + """Returns the shape of the shared dimensions of the rotation and + the translation. + + Returns: + The shape of the transformation + """ + s = self._trans.shape[:-1] + return s + + @property + def device(self) -> torch.device: + """Returns the device on which the Rigid's tensors are located. + + Returns: + The device on which the Rigid's tensors are located + """ + return self._trans.device + + def get_rots(self) -> Rotation: + """Getter for the rotation. + + Returns: + The rotation object + """ + return self._rots + + def get_trans(self) -> torch.Tensor: + """Getter for the translation. + + Returns: + The stored translation + """ + return self._trans + + def compose_q_update_vec( + self, + q_update_vec: torch.Tensor, + ) -> Rigid: + """Composes the transformation with a quaternion update vector of + shape [*, 6], where the final 6 columns represent the x, y, and + z values of a quaternion of form (1, x, y, z) followed by a 3D + translation. + + Args: + q_vec: The quaternion update vector. + + Returns: + The composed transformation. + """ + q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:] + new_rots = self._rots.compose_q_update_vec(q_vec) + + trans_update = self._rots.apply(t_vec) + new_translation = self._trans + trans_update + + return Rigid(new_rots, new_translation) + + def compose( + self, + r: Rigid, + ) -> Rigid: + """Composes the current rigid object with another. 
+ + Args: + r: + Another Rigid object + Returns: + The composition of the two transformations + """ + new_rot = self._rots.compose_r(r._rots) + new_trans = self._rots.apply(r._trans) + self._trans + return Rigid(new_rot, new_trans) + + def apply( + self, + pts: torch.Tensor, + ) -> torch.Tensor: + """Applies the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor. + + Returns: + The transformed points. + """ + rotated = self._rots.apply(pts) + return rotated + self._trans + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """Applies the inverse of the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor + Returns: + The transformed points. + """ + pts = pts - self._trans + return self._rots.invert_apply(pts) + + def invert(self) -> Rigid: + """Inverts the transformation. + + Returns: + The inverse transformation. + """ + rot_inv = self._rots.invert() + trn_inv = rot_inv.apply(self._trans) + + return Rigid(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + """Apply a Tensor -> Tensor function to underlying translation and + rotation tensors, mapping over the translation/rotation dimensions + respectively. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rigid + Returns: + The transformed Rigid object + """ + new_rots = self._rots.map_tensor_fn(fn) + new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1) + + return Rigid(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + """Converts a transformation to a homogenous transformation tensor. 
+ + Returns: + A [*, 4, 4] homogenous transformation tensor + """ + tensor = self._trans.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._rots.get_rot_mats() + tensor[..., :3, 3] = self._trans + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4(t: torch.Tensor) -> Rigid: + """Constructs a transformation from a homogenous transformation + tensor. + + Args: + t: [*, 4, 4] homogenous transformation tensor + Returns: + T object with shape [*] + """ + if t.shape[-2:] != (4, 4): + raise ValueError("Incorrectly shaped input tensor") + + rots = Rotation(rot_mats=t[..., :3, :3], quats=None) + trans = t[..., :3, 3] + + return Rigid(rots, trans) + + def to_tensor_7(self) -> torch.Tensor: + """Converts a transformation to a tensor with 7 final columns, four + for the quaternion followed by three for the translation. + + Returns: + A [*, 7] tensor representation of the transformation + """ + tensor = self._trans.new_zeros((*self.shape, 7)) + tensor[..., :4] = self._rots.get_quats() + tensor[..., 4:] = self._trans + + return tensor + + @staticmethod + def from_tensor_7( + t: torch.Tensor, + normalize_quats: bool = False, + ) -> Rigid: + if t.shape[-1] != 7: + raise ValueError("Incorrectly shaped input tensor") + + quats, trans = t[..., :4], t[..., 4:] + + rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats) + + return Rigid(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, + origin: torch.Tensor, + p_xy_plane: torch.Tensor, + eps: float = 1e-5, + ) -> Rigid: + """Implements algorithm 21. Constructs transformations from sets of 3 + points using the Gram-Schmidt algorithm. 
+ + Args: + p_neg_x_axis: [*, 3] coordinates + origin: [*, 3] coordinates used as frame origins + p_xy_plane: [*, 3] coordinates + eps: Small epsilon value + Returns: + A transformation object of shape [*] + """ + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze( + self, + dim: int, + ) -> Rigid: + """Analogous to torch.unsqueeze. The dimension is relative to the + shared dimensions of the rotation/translation. + + Args: + dim: A positive or negative dimension index. + + Returns: + The unsqueezed transformation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + rots = self._rots.unsqueeze(dim) + trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + @staticmethod + def cat( + ts: Sequence[Rigid], + dim: int, + ) -> Rigid: + """Concatenates transformations along a new dimension. 
+ + Args: + ts: + A list of T objects + dim: + The dimension along which the transformations should be + concatenated + Returns: + A concatenated transformation object + """ + rots = Rotation.cat([t._rots for t in ts], dim) + trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid: + """Applies a Rotation -> Rotation function to the stored rotation + object. + + Args: + fn: A function of type Rotation -> Rotation + Returns: + A transformation object with a transformed rotation. + """ + return Rigid(fn(self._rots), self._trans) + + def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + """Applies a Tensor -> Tensor function to the stored translation. + + Args: + fn: + A function of type Tensor -> Tensor to be applied to the + translation + Returns: + A transformation object with a transformed translation. + """ + return Rigid(self._rots, fn(self._trans)) + + def scale_translation(self, trans_scale_factor: float) -> Rigid: + """Scales the translation by a constant factor. + + Args: + trans_scale_factor: + The constant factor + Returns: + A transformation object with a scaled translation. + """ + fn = lambda t: t * trans_scale_factor + return self.apply_trans_fn(fn) + + def stop_rot_gradient(self) -> Rigid: + """Detaches the underlying rotation object + + Returns: + A transformation object with detached rotations + """ + fn = lambda r: r.detach() + return self.apply_rot_fn(fn) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + """Returns a transformation object from reference coordinates. + + Note that this method does not take care of symmetries. If you + provide the atom positions in the non-standard way, the N atom will + end up not at [-0.527250, 1.359329, 0.0] but instead at + [-0.527250, -1.359329, 0.0]. You need to take care of such cases in + your code. 
+ + Args: + n_xyz: A [*, 3] tensor of nitrogen xyz coordinates. + ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates. + c_xyz: A [*, 3] tensor of carbon xyz coordinates. + + Returns: + A transformation object. After applying the translation and + rotation to the reference backbone, the coordinates will + approximately equal to the input coordinates. + """ + translation = -1 * ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x**2 + c_y**2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + zeros = sin_c1.new_zeros(sin_c1.shape) + ones = sin_c1.new_ones(sin_c1.shape) + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2) + sin_c2 = c_z / norm + cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = rot_matmul(c2_rots, c1_rots) + n_xyz = rot_vec_mul(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y**2 + n_z**2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = rot_matmul(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + translation = -1 * translation + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, translation) + + def cuda(self) -> Rigid: + """Moves the transformation object to GPU memory (CUDA or MPS if available) + + Returns: + A version of the transformation on GPU (CUDA or MPS) + """ + # Prefer 
CUDA if available, else MPS, else CPU + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + return self.to(device) + + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Rigid: + """Moves the transformation object to the specified device and dtype. + + Args: + device: torch.device or None + dtype: torch.dtype or None + Returns: + A version of the transformation on the specified device/dtype + """ + rots = self._rots + trans = self._trans + if rots is not None: + rots = rots.to(device=device, dtype=dtype) + if trans is not None: + trans = trans.to(device=device, dtype=dtype) + return Rigid(rots, trans) + + def detach(self) -> Rotation: + """Returns a copy of the Rotation whose underlying Tensor has been + detached from its torch graph. + + Returns: + A copy of the Rotation whose underlying Tensor has been detached + from its torch graph + """ + if self._rot_mats is not None: + return Rotation(rot_mats=self._rot_mats.detach(), quats=None) + elif self._quats is not None: + return Rotation( + rot_mats=None, + quats=self._quats.detach(), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/stereo_chemical_props.txt b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/stereo_chemical_props.txt new file mode 100644 index 0000000000..25262efd76 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/stereo_chemical_props.txt @@ -0,0 +1,345 @@ +Bond Residue Mean StdDev +CA-CB ALA 1.520 0.021 +N-CA ALA 1.459 0.020 +CA-C ALA 1.525 0.026 +C-O ALA 1.229 0.019 +CA-CB ARG 1.535 0.022 +CB-CG ARG 1.521 0.027 +CG-CD ARG 1.515 0.025 +CD-NE ARG 1.460 0.017 +NE-CZ ARG 1.326 0.013 +CZ-NH1 ARG 1.326 0.013 +CZ-NH2 ARG 1.326 0.013 +N-CA ARG 1.459 0.020 +CA-C ARG 1.525 0.026 +C-O ARG 1.229 0.019 +CA-CB ASN 
1.527 0.026 +CB-CG ASN 1.506 0.023 +CG-OD1 ASN 1.235 0.022 +CG-ND2 ASN 1.324 0.025 +N-CA ASN 1.459 0.020 +CA-C ASN 1.525 0.026 +C-O ASN 1.229 0.019 +CA-CB ASP 1.535 0.022 +CB-CG ASP 1.513 0.021 +CG-OD1 ASP 1.249 0.023 +CG-OD2 ASP 1.249 0.023 +N-CA ASP 1.459 0.020 +CA-C ASP 1.525 0.026 +C-O ASP 1.229 0.019 +CA-CB CYS 1.526 0.013 +CB-SG CYS 1.812 0.016 +N-CA CYS 1.459 0.020 +CA-C CYS 1.525 0.026 +C-O CYS 1.229 0.019 +CA-CB GLU 1.535 0.022 +CB-CG GLU 1.517 0.019 +CG-CD GLU 1.515 0.015 +CD-OE1 GLU 1.252 0.011 +CD-OE2 GLU 1.252 0.011 +N-CA GLU 1.459 0.020 +CA-C GLU 1.525 0.026 +C-O GLU 1.229 0.019 +CA-CB GLN 1.535 0.022 +CB-CG GLN 1.521 0.027 +CG-CD GLN 1.506 0.023 +CD-OE1 GLN 1.235 0.022 +CD-NE2 GLN 1.324 0.025 +N-CA GLN 1.459 0.020 +CA-C GLN 1.525 0.026 +C-O GLN 1.229 0.019 +N-CA GLY 1.456 0.015 +CA-C GLY 1.514 0.016 +C-O GLY 1.232 0.016 +CA-CB HIS 1.535 0.022 +CB-CG HIS 1.492 0.016 +CG-ND1 HIS 1.369 0.015 +CG-CD2 HIS 1.353 0.017 +ND1-CE1 HIS 1.343 0.025 +CD2-NE2 HIS 1.415 0.021 +CE1-NE2 HIS 1.322 0.023 +N-CA HIS 1.459 0.020 +CA-C HIS 1.525 0.026 +C-O HIS 1.229 0.019 +CA-CB ILE 1.544 0.023 +CB-CG1 ILE 1.536 0.028 +CB-CG2 ILE 1.524 0.031 +CG1-CD1 ILE 1.500 0.069 +N-CA ILE 1.459 0.020 +CA-C ILE 1.525 0.026 +C-O ILE 1.229 0.019 +CA-CB LEU 1.533 0.023 +CB-CG LEU 1.521 0.029 +CG-CD1 LEU 1.514 0.037 +CG-CD2 LEU 1.514 0.037 +N-CA LEU 1.459 0.020 +CA-C LEU 1.525 0.026 +C-O LEU 1.229 0.019 +CA-CB LYS 1.535 0.022 +CB-CG LYS 1.521 0.027 +CG-CD LYS 1.520 0.034 +CD-CE LYS 1.508 0.025 +CE-NZ LYS 1.486 0.025 +N-CA LYS 1.459 0.020 +CA-C LYS 1.525 0.026 +C-O LYS 1.229 0.019 +CA-CB MET 1.535 0.022 +CB-CG MET 1.509 0.032 +CG-SD MET 1.807 0.026 +SD-CE MET 1.774 0.056 +N-CA MET 1.459 0.020 +CA-C MET 1.525 0.026 +C-O MET 1.229 0.019 +CA-CB PHE 1.535 0.022 +CB-CG PHE 1.509 0.017 +CG-CD1 PHE 1.383 0.015 +CG-CD2 PHE 1.383 0.015 +CD1-CE1 PHE 1.388 0.020 +CD2-CE2 PHE 1.388 0.020 +CE1-CZ PHE 1.369 0.019 +CE2-CZ PHE 1.369 0.019 +N-CA PHE 1.459 0.020 +CA-C PHE 1.525 0.026 +C-O PHE 1.229 0.019 
+CA-CB PRO 1.531 0.020 +CB-CG PRO 1.495 0.050 +CG-CD PRO 1.502 0.033 +CD-N PRO 1.474 0.014 +N-CA PRO 1.468 0.017 +CA-C PRO 1.524 0.020 +C-O PRO 1.228 0.020 +CA-CB SER 1.525 0.015 +CB-OG SER 1.418 0.013 +N-CA SER 1.459 0.020 +CA-C SER 1.525 0.026 +C-O SER 1.229 0.019 +CA-CB THR 1.529 0.026 +CB-OG1 THR 1.428 0.020 +CB-CG2 THR 1.519 0.033 +N-CA THR 1.459 0.020 +CA-C THR 1.525 0.026 +C-O THR 1.229 0.019 +CA-CB TRP 1.535 0.022 +CB-CG TRP 1.498 0.018 +CG-CD1 TRP 1.363 0.014 +CG-CD2 TRP 1.432 0.017 +CD1-NE1 TRP 1.375 0.017 +NE1-CE2 TRP 1.371 0.013 +CD2-CE2 TRP 1.409 0.012 +CD2-CE3 TRP 1.399 0.015 +CE2-CZ2 TRP 1.393 0.017 +CE3-CZ3 TRP 1.380 0.017 +CZ2-CH2 TRP 1.369 0.019 +CZ3-CH2 TRP 1.396 0.016 +N-CA TRP 1.459 0.020 +CA-C TRP 1.525 0.026 +C-O TRP 1.229 0.019 +CA-CB TYR 1.535 0.022 +CB-CG TYR 1.512 0.015 +CG-CD1 TYR 1.387 0.013 +CG-CD2 TYR 1.387 0.013 +CD1-CE1 TYR 1.389 0.015 +CD2-CE2 TYR 1.389 0.015 +CE1-CZ TYR 1.381 0.013 +CE2-CZ TYR 1.381 0.013 +CZ-OH TYR 1.374 0.017 +N-CA TYR 1.459 0.020 +CA-C TYR 1.525 0.026 +C-O TYR 1.229 0.019 +CA-CB VAL 1.543 0.021 +CB-CG1 VAL 1.524 0.021 +CB-CG2 VAL 1.524 0.021 +N-CA VAL 1.459 0.020 +CA-C VAL 1.525 0.026 +C-O VAL 1.229 0.019 +- + +Angle Residue Mean StdDev +N-CA-CB ALA 110.1 1.4 +CB-CA-C ALA 110.1 1.5 +N-CA-C ALA 111.0 2.7 +CA-C-O ALA 120.1 2.1 +N-CA-CB ARG 110.6 1.8 +CB-CA-C ARG 110.4 2.0 +CA-CB-CG ARG 113.4 2.2 +CB-CG-CD ARG 111.6 2.6 +CG-CD-NE ARG 111.8 2.1 +CD-NE-CZ ARG 123.6 1.4 +NE-CZ-NH1 ARG 120.3 0.5 +NE-CZ-NH2 ARG 120.3 0.5 +NH1-CZ-NH2 ARG 119.4 1.1 +N-CA-C ARG 111.0 2.7 +CA-C-O ARG 120.1 2.1 +N-CA-CB ASN 110.6 1.8 +CB-CA-C ASN 110.4 2.0 +CA-CB-CG ASN 113.4 2.2 +CB-CG-ND2 ASN 116.7 2.4 +CB-CG-OD1 ASN 121.6 2.0 +ND2-CG-OD1 ASN 121.9 2.3 +N-CA-C ASN 111.0 2.7 +CA-C-O ASN 120.1 2.1 +N-CA-CB ASP 110.6 1.8 +CB-CA-C ASP 110.4 2.0 +CA-CB-CG ASP 113.4 2.2 +CB-CG-OD1 ASP 118.3 0.9 +CB-CG-OD2 ASP 118.3 0.9 +OD1-CG-OD2 ASP 123.3 1.9 +N-CA-C ASP 111.0 2.7 +CA-C-O ASP 120.1 2.1 +N-CA-CB CYS 110.8 1.5 +CB-CA-C CYS 111.5 1.2 +CA-CB-SG 
CYS 114.2 1.1 +N-CA-C CYS 111.0 2.7 +CA-C-O CYS 120.1 2.1 +N-CA-CB GLU 110.6 1.8 +CB-CA-C GLU 110.4 2.0 +CA-CB-CG GLU 113.4 2.2 +CB-CG-CD GLU 114.2 2.7 +CG-CD-OE1 GLU 118.3 2.0 +CG-CD-OE2 GLU 118.3 2.0 +OE1-CD-OE2 GLU 123.3 1.2 +N-CA-C GLU 111.0 2.7 +CA-C-O GLU 120.1 2.1 +N-CA-CB GLN 110.6 1.8 +CB-CA-C GLN 110.4 2.0 +CA-CB-CG GLN 113.4 2.2 +CB-CG-CD GLN 111.6 2.6 +CG-CD-OE1 GLN 121.6 2.0 +CG-CD-NE2 GLN 116.7 2.4 +OE1-CD-NE2 GLN 121.9 2.3 +N-CA-C GLN 111.0 2.7 +CA-C-O GLN 120.1 2.1 +N-CA-C GLY 113.1 2.5 +CA-C-O GLY 120.6 1.8 +N-CA-CB HIS 110.6 1.8 +CB-CA-C HIS 110.4 2.0 +CA-CB-CG HIS 113.6 1.7 +CB-CG-ND1 HIS 123.2 2.5 +CB-CG-CD2 HIS 130.8 3.1 +CG-ND1-CE1 HIS 108.2 1.4 +ND1-CE1-NE2 HIS 109.9 2.2 +CE1-NE2-CD2 HIS 106.6 2.5 +NE2-CD2-CG HIS 109.2 1.9 +CD2-CG-ND1 HIS 106.0 1.4 +N-CA-C HIS 111.0 2.7 +CA-C-O HIS 120.1 2.1 +N-CA-CB ILE 110.8 2.3 +CB-CA-C ILE 111.6 2.0 +CA-CB-CG1 ILE 111.0 1.9 +CB-CG1-CD1 ILE 113.9 2.8 +CA-CB-CG2 ILE 110.9 2.0 +CG1-CB-CG2 ILE 111.4 2.2 +N-CA-C ILE 111.0 2.7 +CA-C-O ILE 120.1 2.1 +N-CA-CB LEU 110.4 2.0 +CB-CA-C LEU 110.2 1.9 +CA-CB-CG LEU 115.3 2.3 +CB-CG-CD1 LEU 111.0 1.7 +CB-CG-CD2 LEU 111.0 1.7 +CD1-CG-CD2 LEU 110.5 3.0 +N-CA-C LEU 111.0 2.7 +CA-C-O LEU 120.1 2.1 +N-CA-CB LYS 110.6 1.8 +CB-CA-C LYS 110.4 2.0 +CA-CB-CG LYS 113.4 2.2 +CB-CG-CD LYS 111.6 2.6 +CG-CD-CE LYS 111.9 3.0 +CD-CE-NZ LYS 111.7 2.3 +N-CA-C LYS 111.0 2.7 +CA-C-O LYS 120.1 2.1 +N-CA-CB MET 110.6 1.8 +CB-CA-C MET 110.4 2.0 +CA-CB-CG MET 113.3 1.7 +CB-CG-SD MET 112.4 3.0 +CG-SD-CE MET 100.2 1.6 +N-CA-C MET 111.0 2.7 +CA-C-O MET 120.1 2.1 +N-CA-CB PHE 110.6 1.8 +CB-CA-C PHE 110.4 2.0 +CA-CB-CG PHE 113.9 2.4 +CB-CG-CD1 PHE 120.8 0.7 +CB-CG-CD2 PHE 120.8 0.7 +CD1-CG-CD2 PHE 118.3 1.3 +CG-CD1-CE1 PHE 120.8 1.1 +CG-CD2-CE2 PHE 120.8 1.1 +CD1-CE1-CZ PHE 120.1 1.2 +CD2-CE2-CZ PHE 120.1 1.2 +CE1-CZ-CE2 PHE 120.0 1.8 +N-CA-C PHE 111.0 2.7 +CA-C-O PHE 120.1 2.1 +N-CA-CB PRO 103.3 1.2 +CB-CA-C PRO 111.7 2.1 +CA-CB-CG PRO 104.8 1.9 +CB-CG-CD PRO 106.5 3.9 +CG-CD-N PRO 103.2 1.5 
+CA-N-CD PRO 111.7 1.4 +N-CA-C PRO 112.1 2.6 +CA-C-O PRO 120.2 2.4 +N-CA-CB SER 110.5 1.5 +CB-CA-C SER 110.1 1.9 +CA-CB-OG SER 111.2 2.7 +N-CA-C SER 111.0 2.7 +CA-C-O SER 120.1 2.1 +N-CA-CB THR 110.3 1.9 +CB-CA-C THR 111.6 2.7 +CA-CB-OG1 THR 109.0 2.1 +CA-CB-CG2 THR 112.4 1.4 +OG1-CB-CG2 THR 110.0 2.3 +N-CA-C THR 111.0 2.7 +CA-C-O THR 120.1 2.1 +N-CA-CB TRP 110.6 1.8 +CB-CA-C TRP 110.4 2.0 +CA-CB-CG TRP 113.7 1.9 +CB-CG-CD1 TRP 127.0 1.3 +CB-CG-CD2 TRP 126.6 1.3 +CD1-CG-CD2 TRP 106.3 0.8 +CG-CD1-NE1 TRP 110.1 1.0 +CD1-NE1-CE2 TRP 109.0 0.9 +NE1-CE2-CD2 TRP 107.3 1.0 +CE2-CD2-CG TRP 107.3 0.8 +CG-CD2-CE3 TRP 133.9 0.9 +NE1-CE2-CZ2 TRP 130.4 1.1 +CE3-CD2-CE2 TRP 118.7 1.2 +CD2-CE2-CZ2 TRP 122.3 1.2 +CE2-CZ2-CH2 TRP 117.4 1.0 +CZ2-CH2-CZ3 TRP 121.6 1.2 +CH2-CZ3-CE3 TRP 121.2 1.1 +CZ3-CE3-CD2 TRP 118.8 1.3 +N-CA-C TRP 111.0 2.7 +CA-C-O TRP 120.1 2.1 +N-CA-CB TYR 110.6 1.8 +CB-CA-C TYR 110.4 2.0 +CA-CB-CG TYR 113.4 1.9 +CB-CG-CD1 TYR 121.0 0.6 +CB-CG-CD2 TYR 121.0 0.6 +CD1-CG-CD2 TYR 117.9 1.1 +CG-CD1-CE1 TYR 121.3 0.8 +CG-CD2-CE2 TYR 121.3 0.8 +CD1-CE1-CZ TYR 119.8 0.9 +CD2-CE2-CZ TYR 119.8 0.9 +CE1-CZ-CE2 TYR 119.8 1.6 +CE1-CZ-OH TYR 120.1 2.7 +CE2-CZ-OH TYR 120.1 2.7 +N-CA-C TYR 111.0 2.7 +CA-C-O TYR 120.1 2.1 +N-CA-CB VAL 111.5 2.2 +CB-CA-C VAL 111.4 1.9 +CA-CB-CG1 VAL 110.9 1.5 +CA-CB-CG2 VAL 110.9 1.5 +CG1-CB-CG2 VAL 110.9 1.6 +N-CA-C VAL 111.0 2.7 +CA-C-O VAL 120.1 2.1 +- + +Non-bonded distance Minimum Dist Tolerance +C-C 3.4 1.5 +C-N 3.25 1.5 +C-S 3.5 1.5 +C-O 3.22 1.5 +N-N 3.1 1.5 +N-S 3.35 1.5 +N-O 3.07 1.5 +O-S 3.32 1.5 +O-O 3.04 1.5 +S-S 2.03 1.0 +- diff --git a/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/tensor_utils.py b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/tensor_utils.py new file mode 100644 index 0000000000..1f3e9876f7 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/minifold_utils/tensor_utils.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Small tensor/pytree helpers ported from OpenFold, used by the MiniFold TE recipe."""

from functools import partial
from typing import List

import torch
import torch.nn as nn


def add(m1, m2, inplace):
    """Add ``m2`` to ``m1``, optionally in place.

    The first operation in a checkpointed region can't be in-place, but
    in-place addition is nice to have during inference. Callers pick the
    behavior via ``inplace``; the (possibly new) tensor is returned either way.
    """
    if not inplace:
        m1 = m1 + m2
    else:
        m1 += m2

    return m1


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute the last ``len(inds)`` dimensions of ``tensor`` according to ``inds``.

    ``inds`` indexes into the final dimensions only; leading (batch) dims are
    left untouched. E.g. ``inds=[1, 0]`` swaps the last two dimensions.
    """
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Flatten the last ``no_dims`` dimensions of ``t`` into one."""
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    """Mean of ``value`` over ``dim``, weighted by a (broadcastable) 0/1 ``mask``.

    ``eps`` guards against division by zero when the mask is empty along ``dim``.
    """
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bin pairwise Euclidean distances of points into ``no_bins`` distogram buckets.

    Args:
        pts: Point coordinates; pairwise distances are taken over the
            second-to-last dimension.
        min_bin: Lower edge of the first interior boundary.
        max_bin: Upper edge of the last interior boundary.
        no_bins: Number of buckets (``no_bins - 1`` boundaries are used).

    Returns:
        Integer bucket indices with shape ``pts.shape[:-1] + (n_points,)``.
    """
    boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    """Apply ``fn`` to the list of values gathered per-key from parallel dicts.

    All dicts must share the structure of ``dicts[0]``; nested (exact) dicts
    are recursed into. ``fn`` receives the list of leaf values for each key.
    """
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        # Exact-type check on purpose: mirrors dict_map below and avoids
        # descending into dict subclasses with custom semantics.
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict


def one_hot(x, v_bins):
    """Soft-binning: one-hot encode each element of ``x`` by its nearest bin in ``v_bins``.

    Returns a float tensor of shape ``x.shape + (len(v_bins),)``.
    """
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather ``data`` along ``dim`` with per-batch indices ``inds``.

    The first ``no_batch_dims`` dimensions are treated as batch dimensions and
    indexed with broadcast ``arange`` ranges so that ``inds`` selects
    independently within each batch element.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        # Shape the arange so it broadcasts against inds across batch dims.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    return data[ranges]


# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """Apply ``fn`` to every ``leaf_type`` leaf in a (possibly nested) dict."""
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    """Recursively apply ``fn`` to every ``leaf_type`` leaf of a dict/list/tuple pytree.

    Raises:
        ValueError: If a node is neither a supported container nor a
            ``leaf_type`` instance.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        # Fold the offending type into the exception instead of printing it to
        # stdout and raising an uninformative message (original behavior).
        raise ValueError(f"tree_map not supported for type {type(tree)}")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer Engine (TE) port of the MiniFormer pair-representation stack."""

import warnings
from contextlib import nullcontext
from typing import ContextManager

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformer_engine.pytorch as te
from torch import Tensor

from minifold_utils import init
from quantization import ComponentPrecisionConfig
from te_utils import te_layernorm_nd, te_linear_nd, tri_mul_bmm


class TransitionUpdateTE(nn.Module):
    """TE version of TransitionUpdate: LayerNorm -> Linear -> ReLU -> Linear.

    Replaces raw nn.Parameter + F.linear with te.LayerNorm + te.Linear modules.
    Note: the residual connection is applied by the caller (see BlockTE), not
    inside this module.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        params_dtype: torch.dtype = torch.float32,
        component_precision: ComponentPrecisionConfig | None = None,
    ):
        """Build the two-layer transition MLP.

        Args:
            dim: Input/output feature dimension.
            hidden: Hidden dimension of the MLP.
            params_dtype: Parameter dtype for the TE layers.
            component_precision: Optional per-component precision overrides;
                the "ffn" context is applied around the whole forward pass.
        """
        super().__init__()
        self._component_precision = component_precision
        self.norm = te.LayerNorm(dim, eps=1e-5, params_dtype=params_dtype)
        self.fc1 = te.Linear(dim, hidden, params_dtype=params_dtype)
        self.fc2 = te.Linear(hidden, dim, params_dtype=params_dtype)

        # Match original (non-TE) MiniFold initialization: He init for the
        # expanding layer, zero ("final") init for the contracting layer so the
        # block starts as an identity residual branch.
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)
        init.he_normal_init_(self.fc1.weight)
        init.bias_init_zero_(self.fc1.bias)
        init.final_init_(self.fc2.weight)
        init.bias_init_zero_(self.fc2.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (B, N, N, D).

        Returns:
            Output tensor of shape (B, N, N, D) (residual NOT added here).
        """
        ctx = self._component_precision.get_context("ffn") if self._component_precision else nullcontext()
        with ctx:
            x = te_layernorm_nd(self.norm, x)
            x = te_linear_nd(self.fc1, x)
            x = F.relu(x)
            x = te_linear_nd(self.fc2, x)
        return x


class TriangularUpdateTE(nn.Module):
    """TE version of TriangularUpdate (fused outgoing + incoming edge update).

    Replaces raw nn.Parameter + F.linear/F.layer_norm with te.LayerNorm + te.Linear.
    The einsum triangular multiplication operations remain in FP32 unless a
    tri_einsum mode is set via component_precision.
    """

    def __init__(
        self,
        dim: int = 128,
        params_dtype: torch.dtype = torch.float32,
        component_precision: ComponentPrecisionConfig | None = None,
    ):
        """Build input/output gating layers for the triangular update.

        Args:
            dim: Pair feature dimension D.
            params_dtype: Parameter dtype for the TE layers.
            component_precision: Optional per-component precision overrides
                ("tri_proj", "tri_gate", and the tri_einsum mode).
        """
        super().__init__()
        self._component_precision = component_precision

        # Input gating: LayerNorm + two parallel linears (projection and gate)
        self.input_norm = te.LayerNorm(dim, eps=1e-5, params_dtype=params_dtype)
        self.pi = te.Linear(dim, dim, params_dtype=params_dtype)  # input projection
        self.gi = te.Linear(dim, dim, params_dtype=params_dtype)  # input gate (sigmoid)

        # Output gating: LayerNorm + two parallel linears. Operates on D/2
        # because the triangular product halves the channel count (see forward).
        self.output_norm = te.LayerNorm(dim // 2, eps=1e-5, params_dtype=params_dtype)
        self.po = te.Linear(dim // 2, dim, params_dtype=params_dtype)  # output projection
        self.go = te.Linear(dim // 2, dim, params_dtype=params_dtype)  # output gate (sigmoid)

        # Match original initialization: gates biased open (bias=1, weight~0),
        # output projection zero-initialized so the residual branch starts at 0.
        init.bias_init_one_(self.input_norm.weight)
        init.bias_init_zero_(self.input_norm.bias)

        init.lecun_normal_init_(self.pi.weight)
        init.bias_init_zero_(self.pi.bias)
        init.gating_init_(self.gi.weight)
        init.bias_init_one_(self.gi.bias)

        init.bias_init_one_(self.output_norm.weight)
        init.bias_init_zero_(self.output_norm.bias)

        init.final_init_(self.po.weight)
        init.bias_init_zero_(self.po.bias)
        init.gating_init_(self.go.weight)
        init.bias_init_one_(self.go.bias)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (B, N, N, D).
            mask: Mask tensor of shape (B, N, N).

        Returns:
            Output tensor of shape (B, N, N, D).
        """
        cp = self._component_precision

        def _proj_ctx():
            return cp.get_context("tri_proj") if cp else nullcontext()

        def _gate_ctx():
            return cp.get_context("tri_gate") if cp else nullcontext()

        # Input gating: D -> D
        x = te_layernorm_nd(self.input_norm, x)
        with _proj_ctx():
            pi_out = te_linear_nd(self.pi, x)
        with _gate_ctx():
            gi_out = te_linear_nd(self.gi, x).sigmoid()
        x = pi_out * gi_out

        # Apply mask
        x = x * mask.unsqueeze(-1)

        # Triangular multiplication via batched GEMM. When tri_einsum is "off",
        # the product is computed in FP32 for numerical stability.
        tri_mode = cp.tri_einsum if cp else "off"
        use_fp32 = tri_mode == "off"
        x_in = x.float() if use_fp32 else x
        # Split D into 4 chunks: (a1, b1) feed the "outgoing" product and
        # (a2, b2) the "incoming" product; each product yields D/4 channels,
        # so the concatenation below has D/2 channels.
        a1, b1, a2, b2 = torch.chunk(x_in, 4, dim=-1)
        x1 = tri_mul_bmm(a1, b1, k_dim=2, mode=tri_mode)  # "bikd,bjkd->bijd"
        x2 = tri_mul_bmm(a2, b2, k_dim=1, mode=tri_mode)  # "bkid,bkjd->bijd"
        x = torch.cat([x1, x2], dim=-1)
        if use_fp32:
            # NOTE(review): casts back to the mask's floating dtype as a proxy
            # for the activation dtype (the trunk converts mask with
            # `mask.to(s_z)`, so they normally agree) — confirm for new callers.
            x = x.to(mask.dtype if mask.is_floating_point() else torch.float32)

        # Output gating: D/2 -> D
        x = te_layernorm_nd(self.output_norm, x)
        with _proj_ctx():
            po_out = te_linear_nd(self.po, x)
        with _gate_ctx():
            go_out = te_linear_nd(self.go, x).sigmoid()
        x = po_out * go_out

        return x


class BlockTE(nn.Module):
    """TE version of a MiniFormer block: TriangularUpdate + TransitionUpdate,
    each applied as a residual branch."""

    def __init__(
        self,
        dim: int = 128,
        params_dtype: torch.dtype = torch.float32,
        component_precision: ComponentPrecisionConfig | None = None,
    ):
        """Build one residual block of triangular + transition updates."""
        super().__init__()
        self.triangular = TriangularUpdateTE(dim, params_dtype=params_dtype, component_precision=component_precision)
        self.transition = TransitionUpdateTE(
            dim, dim * 4, params_dtype=params_dtype, component_precision=component_precision
        )

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (B, N, N, D).
            mask: Mask tensor of shape (B, N, N).

        Returns:
            Output tensor of shape (B, N, N, D).
        """
        x = x + self.triangular(x, mask)
        x = x + self.transition(x)
        return x


class MiniFormerTE(nn.Module):
    """TE version of the MiniFormer module with optional per-block FP8/FP4 precision."""

    def __init__(
        self,
        dim: int = 128,
        blocks: int = 48,
        params_dtype: torch.dtype = torch.float32,
        block_precision: list[str | None] | None = None,
        fp8_recipe=None,
        fp4_recipe=None,
        component_precision: ComponentPrecisionConfig | None = None,
    ):
        """Build the stack of MiniFormer blocks.

        Args:
            dim: Pair feature dimension.
            blocks: Number of BlockTE layers.
            params_dtype: Parameter dtype for the TE layers.
            block_precision: Optional per-block precision list ("fp8", "fp4",
                or None) — must be exactly `blocks` long if given.
            fp8_recipe: TE recipe used for FP8 blocks (default recipe if None).
            fp4_recipe: TE recipe used for FP4 blocks (required for FP4).
            component_precision: Per-component precision overrides.

        Raises:
            ValueError: If block_precision has the wrong length.
        """
        super().__init__()
        self.blocks = nn.ModuleList(
            [BlockTE(dim, params_dtype=params_dtype, component_precision=component_precision) for _ in range(blocks)]
        )
        self._block_precision = block_precision
        self._fp8_recipe = fp8_recipe
        self._fp4_recipe = fp4_recipe

        if block_precision is not None and len(block_precision) != blocks:
            raise ValueError(f"block_precision length ({len(block_precision)}) must match number of blocks ({blocks})")

    def get_autocast_context(self, block_number: int | None, outer: bool = False) -> ContextManager:
        """Return the appropriate TE autocast context manager for a given block.

        Args:
            block_number: The 0-indexed block number (ignored when outer=True).
            outer: Whether to return a global te.autocast() context to wrap the
                entire block stack; the outer context is only enabled when at
                least one block is FP8, and per-block contexts then nest inside
                it (enabled=False for full-precision blocks).
        """
        if self._block_precision is None:
            return nullcontext()

        if outer:
            if "fp8" not in self._block_precision:
                return nullcontext()
            if self._fp8_recipe is None:
                warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
            return te.autocast(enabled=True, recipe=self._fp8_recipe)

        precision = self._block_precision[block_number]
        recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision)

        if precision == "fp8":
            # FP8 tolerates a missing recipe (TE default); FP4 does not.
            if recipe is None:
                warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
            return te.autocast(enabled=True, recipe=recipe)
        if precision == "fp4":
            if recipe is None:
                raise RuntimeError("No FP4 recipe provided, but block precision is set to FP4.")
            return te.autocast(enabled=True, recipe=recipe)
        return te.autocast(enabled=False)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (B, N, N, D).
            mask: Mask tensor of shape (B, N, N).

        Returns:
            Output tensor of shape (B, N, N, D).
        """
        with self.get_autocast_context(None, outer=True):
            for block_idx, block in enumerate(self.blocks):
                with self.get_autocast_context(block_idx):
                    x = block(x, mask)
        return x
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TE folding-trunk modules: relative positions, sequence<->pair projections, trunk."""

from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformer_engine.pytorch as te

from miniformer_te import MiniFormerTE
from te_utils import te_layernorm_nd, te_linear_nd


class RelativePosition(nn.Module):
    """Embeds clipped relative residue-index differences into pair features."""

    def __init__(self, bins, pairwise_state_dim):
        """Args:
        bins: Relative offsets are clipped to [-bins, bins].
        pairwise_state_dim: Embedding dimension of the output pair features.
        """
        super().__init__()
        self.bins = bins

        # Note an additional offset is used so that the 0th position
        # is reserved for masked pairs.
        self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)

    def forward(self, residue_index, mask):
        """Input:
        residue_index: B x L tensor of indices (dtype=torch.long)
        mask: B x L x L tensor of booleans (pairwise mask)

        Output:
        pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """
        diff = residue_index[:, None, :] - residue_index[:, :, None]
        diff = diff.clamp(-self.bins, self.bins)
        diff = diff + self.bins + 1  # Add 1 to adjust for padding index.
        # Masked pairs are routed to the reserved padding index 0.
        diff[mask == 0] = 0
        output = self.embedding(diff)
        return output


class SequenceToPairTE(nn.Module):
    """TE version of SequenceToPair: builds pair features from per-residue features
    via outer product and outer difference of projected sequence states."""

    def __init__(
        self,
        sequence_state_dim,
        inner_dim,
        pairwise_state_dim,
        params_dtype=torch.float32,
        component_precision=None,
    ):
        """Args:
        sequence_state_dim: Input per-residue feature dimension.
        inner_dim: Half of the projected dimension (q and k each get inner_dim).
        pairwise_state_dim: Output pair feature dimension.
        params_dtype: Parameter dtype for the TE layers.
        component_precision: Optional precision overrides ("seq_proj" context).
        """
        super().__init__()
        self._component_precision = component_precision
        self.layernorm = te.LayerNorm(sequence_state_dim, eps=1e-5, params_dtype=params_dtype)
        self.proj = te.Linear(sequence_state_dim, inner_dim * 2, bias=True, params_dtype=params_dtype)
        self.o_proj = te.Linear(2 * inner_dim, pairwise_state_dim, bias=True, params_dtype=params_dtype)

        torch.nn.init.zeros_(self.proj.bias)
        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, sequence_state):
        """Forward pass.

        Args:
            sequence_state: B x L x sequence_state_dim

        Returns:
            pairwise_state: B x L x L x pairwise_state_dim
        """
        cp = self._component_precision
        ctx = cp.get_context("seq_proj") if cp else nullcontext()
        with ctx:
            s = te_layernorm_nd(self.layernorm, sequence_state)
            s = te_linear_nd(self.proj, s)
            q, k = s.chunk(2, dim=-1)

            # Outer product and outer difference over residue pairs (i, j).
            prod = q[:, None, :, :] * k[:, :, None, :]
            diff = q[:, None, :, :] - k[:, :, None, :]

            x = torch.cat([prod, diff], dim=-1)
            x = te_linear_nd(self.o_proj, x)

        return x


class PairToSequenceTE(nn.Module):
    """TE version of PairToSequence: pools pair features (row/column masked means)
    and combines them with the input sequence representation."""

    def __init__(self, c_z=128, c_s=1024, c_s_out=1024, params_dtype=torch.float32):
        """Args:
        c_z: Pair feature dimension.
        c_s: Input sequence feature dimension.
        c_s_out: Output sequence feature dimension.
        params_dtype: Parameter dtype for the TE layers.
        """
        super().__init__()
        self.s_z_norm = te.LayerNorm(c_z, eps=1e-5, params_dtype=params_dtype)
        self.s_z_fc1 = te.Linear(c_z, c_z, params_dtype=params_dtype)
        self.s_z_fc2 = te.Linear(c_z, c_z, params_dtype=params_dtype)
        # Consumes [column-mean, row-mean, input sequence] concatenated.
        self.combiner = te.Linear(2 * c_z + c_s, c_s_out, params_dtype=params_dtype)

    def forward(self, s_z, s_s_in, pair_mask):
        """Forward pass.

        Args:
            s_z: Pair representation (B, L, L, c_z).
            s_s_in: Sequence representation (B, L, c_s).
            pair_mask: Pair mask (B, L, L).

        Returns:
            Sequence representation (B, L, c_s_out).
        """
        # MLP on pair features
        s_z = te_layernorm_nd(self.s_z_norm, s_z)
        s_z = te_linear_nd(self.s_z_fc1, s_z)
        s_z = F.relu(s_z)
        s_z = te_linear_nd(self.s_z_fc2, s_z)

        # Apply mask
        s_z = s_z * pair_mask[..., None]

        # Column average (clamp avoids division by zero on fully masked rows)
        norm = pair_mask.sum(dim=2).clamp(min=1)
        s_s_c = s_z.sum(dim=2) / norm[..., None]

        # Row average
        norm = pair_mask.sum(dim=1).clamp(min=1)
        s_s_r = s_z.sum(dim=1) / norm[..., None]

        # Combine with initial s_s
        s_s = te_linear_nd(self.combiner, torch.cat([s_s_c, s_s_r, s_s_in], dim=-1))
        return s_s


class FoldingTrunkTE(nn.Module):
    """TE version of FoldingTrunk: builds pair features, runs the MiniFormer stack
    with distogram-based recycling, and predicts distogram logits."""

    def __init__(
        self,
        c_s,
        c_z,
        bins,
        disto_bins=64,
        num_layers=1,
        params_dtype=torch.float32,
        block_precision=None,
        fp8_recipe=None,
        fp4_recipe=None,
        component_precision=None,
    ):
        """Args:
        c_s: Sequence feature dimension.
        c_z: Pair feature dimension.
        bins: Relative-position clipping range for positional embeddings.
        disto_bins: Number of distogram bins (also recycling feature size).
        num_layers: Number of MiniFormer blocks.
        params_dtype: Parameter dtype for the TE layers.
        block_precision: Optional per-block precision list for the MiniFormer.
        fp8_recipe: FP8 recipe for TE autocast.
        fp4_recipe: FP4 recipe for TE autocast.
        component_precision: Per-component precision overrides.
        """
        super().__init__()
        self._component_precision = component_precision
        self.disto_bins = disto_bins
        self.positional_embedding = RelativePosition(bins, c_z)
        self.seq_to_pair = SequenceToPairTE(
            c_s, c_z // 2, c_z, params_dtype=params_dtype, component_precision=component_precision
        )
        # Projects [input pair feats | seq-to-pair feats | positional feats].
        self.projection = te.Linear(c_z * 3, c_z, params_dtype=params_dtype)
        # Injects the previous iteration's binned distogram during recycling.
        self.recycle = te.Linear(disto_bins, c_z, params_dtype=params_dtype)
        self.miniformer = MiniFormerTE(
            c_z,
            blocks=num_layers,
            params_dtype=params_dtype,
            block_precision=block_precision,
            fp8_recipe=fp8_recipe,
            fp4_recipe=fp4_recipe,
            component_precision=component_precision,
        )
        self.fc_out_1 = te.Linear(c_z, c_z, params_dtype=params_dtype)
        self.fc_out_2 = te.Linear(c_z, disto_bins, params_dtype=params_dtype)

        # Zero-init the seq-to-pair output so the trunk starts from the raw
        # (projected) pair features only.
        torch.nn.init.zeros_(self.seq_to_pair.o_proj.weight)
        torch.nn.init.zeros_(self.seq_to_pair.o_proj.bias)

    def forward(self, s_s, s_z, mask, num_recycling=0):
        """Forward pass.

        Args:
            s_s: Sequence features (B, L, C).
            s_z: Pair features (B, L, L, C).
            mask: Residue mask (B, L).
            num_recycling: Number of recycling iterations.

        Returns:
            Tuple of (predictions, pair representation).
        """
        # Make pairwise mask
        pair_mask = mask[:, None, :] * mask[:, :, None]

        # Add positional embeddings
        residx = torch.arange(s_s.shape[1], device=s_s.device)
        residx = residx.unsqueeze(0).expand(s_s.shape[0], -1)

        # Concatenate and project
        s_z = torch.cat(
            [
                s_z,
                self.seq_to_pair(s_s),
                self.positional_embedding(residx, mask=pair_mask),
            ],
            dim=-1,
        )
        s_z = te_linear_nd(self.projection, s_z)

        # Set masks to floats
        mask = mask.to(s_z)
        pair_mask = pair_mask.to(s_z)

        # Initialize binned distance (all-zero one-hot on the first pass)
        shape = tuple(s_z.shape[:3]) + (self.disto_bins,)
        dists = torch.zeros(shape, device=s_z.device, dtype=s_z.dtype)

        # Perform folding rounds. Only the FINAL iteration tracks gradients;
        # earlier recycling iterations run under no_grad.
        for i in range(num_recycling + 1):
            with torch.set_grad_enabled(self.training and (i == num_recycling)):
                if self.training and (i == num_recycling) and torch.is_autocast_enabled():
                    # Drop cached no-grad autocast casts before the grad pass.
                    torch.clear_autocast_cache()

                # Compute blocks
                s_z_c = s_z + te_linear_nd(self.recycle, dists)
                s_z_c = self.miniformer(s_z_c, pair_mask)

                # Output MLP (symmetrized pair input: z + z^T)
                cp = self._component_precision
                dist_ctx = cp.get_context("dist_head") if cp else nullcontext()
                with dist_ctx:
                    fc_out = te_linear_nd(self.fc_out_1, s_z_c + s_z_c.transpose(1, 2))
                    fc_out = F.relu(fc_out)
                    preds = te_linear_nd(self.fc_out_2, fc_out)

                # Compute binned distance for recycling (detached argmax one-hot)
                dists = preds.detach().argmax(dim=-1)
                dists = nn.functional.one_hot(dists, self.disto_bins).to(s_z)

        return preds, s_z_c
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ESM2-MiniFold TE: End-to-end protein structure prediction model.

Combines a frozen HuggingFace ESM-2 backbone with a TE-based MiniFold folding head.
The ESM-2 backbone extracts per-residue embeddings and pairwise attention maps,
which are projected and fed into the FoldingTrunkTE for distogram prediction.

Optionally includes the StructureModuleTE for full 3D structure prediction (Stage 2).
"""

from contextlib import nullcontext

import torch
import torch.nn as nn
import transformer_engine.pytorch as te

from esm_backbone import ESM2Backbone
from heads_te import AuxiliaryHeadsTE
from minifold_utils.feats import atom14_to_atom37
from minifold_utils.tensor_utils import tensor_tree_map
from model_te import FoldingTrunkTE, PairToSequenceTE
from quantization import ComponentPrecisionConfig
from structure_te import StructureModuleTE
from te_utils import te_linear_nd


class ESM2MiniFoldTE(nn.Module):
    """ESM-2 backbone + MiniFold TE folding head.

    Stage 1: ESM-2 (frozen) -> projections -> FoldingTrunkTE -> distogram predictions
    Stage 2: + PairToSequenceTE -> StructureModuleTE -> 3D coordinates + pLDDT
    """

    def __init__(
        self,
        esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
        c_s: int = 1024,
        c_z: int = 128,
        num_blocks: int = 48,
        no_bins: int = 64,
        use_structure_module: bool = False,
        num_structure_blocks: int = 8,
        structure_config: dict | None = None,
        params_dtype: torch.dtype = torch.float32,
        block_precision: list[str | None] | None = None,
        fp8_recipe=None,
        fp4_recipe=None,
        component_precision: ComponentPrecisionConfig | None = None,
    ):
        """Initialize ESM2MiniFoldTE.

        Args:
            esm_model_name: HuggingFace ESM-2 model name.
            c_s: Sequence feature dimension after projection.
            c_z: Pair feature dimension after projection.
            num_blocks: Number of MiniFormer blocks in the folding trunk.
            no_bins: Number of distogram bins.
            use_structure_module: Whether to include the structure module (Stage 2).
            num_structure_blocks: Number of IPA blocks in the structure module.
            structure_config: Optional config dict for auxiliary heads.
            params_dtype: Data type for TE layer parameters.
            block_precision: Per-block quantization precision list from resolve_layer_precision().
            fp8_recipe: FP8 recipe for TE autocast.
            fp4_recipe: FP4 recipe for TE autocast.
            component_precision: Per-component precision overrides within FP8/FP4 blocks.
        """
        super().__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.use_structure_module = use_structure_module
        self._component_precision = component_precision

        # ESM-2 backbone (frozen)
        self.backbone = ESM2Backbone(esm_model_name)
        embed_dim = self.backbone.embed_dim
        attn_dim = self.backbone.attn_dim

        # Sequence projection: embed_dim -> c_s (two-layer MLP, ReLU in forward)
        self.fc_s_1 = te.Linear(embed_dim, c_s, params_dtype=params_dtype)
        self.fc_s_2 = te.Linear(c_s, c_s, params_dtype=params_dtype)

        # Pairwise projection: attn_dim -> c_z
        self.fc_z_1 = te.Linear(attn_dim, c_z, params_dtype=params_dtype)
        self.fc_z_2 = te.Linear(c_z, c_z, params_dtype=params_dtype)

        # Folding trunk
        self.fold = FoldingTrunkTE(
            c_s=c_s,
            c_z=c_z,
            bins=32,
            disto_bins=no_bins,
            num_layers=num_blocks,
            params_dtype=params_dtype,
            block_precision=block_precision,
            fp8_recipe=fp8_recipe,
            fp4_recipe=fp4_recipe,
            component_precision=component_precision,
        )

        # Optional structure module (Stage 2)
        if use_structure_module:
            self.sz_project = PairToSequenceTE(c_z=c_z, c_s=c_s, params_dtype=params_dtype)
            self.structure_module = StructureModuleTE(
                c_s=c_s,
                c_z=c_z,
                c_resnet=128,
                head_dim=64,
                no_heads=16,
                no_blocks=num_structure_blocks,
                no_resnet_blocks=2,
                no_angles=7,
                trans_scale_factor=10,
                epsilon=1e-5,
                inf=1e5,
                params_dtype=params_dtype,
            )
            # aux_heads is only created when a structure_config is supplied;
            # forward() checks hasattr(self, "aux_heads") accordingly.
            if structure_config is not None:
                self.aux_heads = AuxiliaryHeadsTE(structure_config["heads"])

    def forward(self, batch: dict, num_recycling: int = 0) -> dict:
        """Forward pass.

        Args:
            batch: Dictionary with:
                "input_ids": ESM-2 token IDs (B, L).
                "attention_mask": Optional attention mask (B, L).
                "mask": Residue validity mask (B, L).
                "batch_of": Optional OpenFold features for structure module.
            num_recycling: Number of recycling iterations.

        Returns:
            Dictionary with predictions:
                "preds": Distogram logits (B, L, L, no_bins).
                "pair": Final pair representation (B, L, L, c_z).
                + structure module outputs if use_structure_module.
        """
        # Extract ESM-2 embeddings and attention maps
        esm_out = self.backbone(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask"),
        )

        # Project sequence embeddings: embed_dim -> c_s
        cp = self._component_precision
        seq_ctx = cp.get_context("seq_proj") if cp else nullcontext()
        with seq_ctx:
            s_s = esm_out["representations"]
            s_s = te_linear_nd(self.fc_s_1, s_s)
            s_s = torch.relu(s_s)
            s_s = te_linear_nd(self.fc_s_2, s_s)

            # Project attention maps: attn_dim -> c_z
            s_z = esm_out["attentions"]
            s_z = te_linear_nd(self.fc_z_1, s_z)
            s_z = torch.relu(s_z)
            s_z = te_linear_nd(self.fc_z_2, s_z)

        # Run folding trunk
        preds, s_z = self.fold(
            s_s,
            s_z,
            mask=batch["mask"],
            num_recycling=num_recycling,
        )

        r_dict = {"preds": preds, "pair": s_z}

        # Optional structure module
        if self.use_structure_module:
            mask = batch["mask"]
            pair_mask = mask[:, None, :] * mask[:, :, None]
            r_dict["single"] = self.sz_project(s_z, s_s, pair_mask)

            # batch_of features carry a trailing recycling dim; take slice 0.
            feats = tensor_tree_map(lambda t: t[..., 0], batch["batch_of"])

            r_dict["sm"] = self.structure_module(
                s=r_dict["single"],
                z=r_dict["pair"],
                aatype=feats["aatype"],
                mask=feats["seq_mask"].to(dtype=r_dict["single"].dtype),
            )

            # Convert atom14 positions of the last IPA iteration to atom37.
            r_dict["final_atom_positions"] = atom14_to_atom37(r_dict["sm"]["positions"][-1], feats)
            r_dict["final_atom_mask"] = feats["atom37_atom_exists"]
            r_dict["final_affine_tensor"] = r_dict["sm"]["frames"][-1]

            if hasattr(self, "aux_heads"):
                r_dict.update(self.aux_heads(r_dict))

        return r_dict

    def get_folding_head_params(self):
        """Return parameters for the folding head (for optimizer param groups).

        Yields every trainable parameter NOT belonging to the backbone,
        structure module, aux heads, or pair-to-sequence projection.
        """
        excluded = {"backbone", "structure_module", "aux_heads", "sz_project"}
        for name, param in self.named_parameters():
            if not any(name.startswith(prefix) for prefix in excluded):
                if param.requires_grad:
                    yield param

    def get_structure_module_params(self):
        """Return parameters for the structure module (for optimizer param groups)."""
        for name, param in self.named_parameters():
            if any(name.startswith(prefix) for prefix in ("structure_module", "aux_heads", "sz_project")):
                if param.requires_grad:
                    yield param

    def get_backbone_params(self):
        """Return backbone parameters (typically frozen)."""
        return self.backbone.parameters()
+ """ + + def __init__(self, dist_config: DistributedConfig, args: DictConfig, quant_logger=None): + self._dist_config = dist_config + self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) + + self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}")) + self.logging_frequency = args.logger.frequency + self.quant_stats_enabled = args.quant_stats_config.enabled + self._quant_logger = quant_logger + self._log_heatmap = args.quant_stats_config.get("log_heatmap", False) + + metrics_dict = { + "train/loss": torchmetrics.MeanMetric(), + "train/disto_loss": torchmetrics.MeanMetric(), + "train/grad_norm": torchmetrics.MeanMetric(), + "train/learning_rate": torchmetrics.MeanMetric(), + "train/step_time": torchmetrics.MeanMetric(), + "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), + "train/distogram_acc": torchmetrics.MeanMetric(), + "train/contact_precision_8A": torchmetrics.MeanMetric(), + "train/contact_recall_8A": torchmetrics.MeanMetric(), + "train/lddt_from_distogram": torchmetrics.MeanMetric(), + "train/mean_distance_error": torchmetrics.MeanMetric(), + "train/unpadded_tokens_per_sec": torchmetrics.MeanMetric(), + } + + self.metrics = torchmetrics.MetricCollection(metrics_dict) + self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.previous_step_time = time.perf_counter() + + if self._dist_config.is_main_process(): + wandb.init(**args.wandb_init_args, config=self._run_config) + self._progress_bar = tqdm(total=args.num_train_steps, desc="Training") + + def log_step( + self, + step: int, + loss: torch.Tensor, + disto_loss: torch.Tensor | None = None, + grad_norm: torch.Tensor | DTensor | float = 0.0, + lr: float = 0.0, + structure_metrics: dict[str, torch.Tensor] | None = None, + unpadded_tokens: float = 0.0, + ): + """Log a training step.""" + with torch.no_grad(): + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.to_local() + + if step % 
self.logging_frequency == 0 and step > 0: + self.min_loss = torch.minimum(self.min_loss, loss) + elapsed_time, self.previous_step_time = ( + time.perf_counter() - self.previous_step_time, + time.perf_counter(), + ) + step_time = elapsed_time / self.logging_frequency + + self.metrics["train/loss"].update(loss) + if disto_loss is not None: + self.metrics["train/disto_loss"].update(disto_loss) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + if unpadded_tokens > 0 and step_time > 0: + self.metrics["train/unpadded_tokens_per_sec"].update(unpadded_tokens / step_time) + + if structure_metrics is not None: + for key, value in structure_metrics.items(): + metric_key = f"train/{key}" + if metric_key in self.metrics: + self.metrics[metric_key].update(value) + + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics = { + k: v.detach().cpu().item() if isinstance(v, torch.Tensor) and v.dim() == 0 else v + for k, v in metrics.items() + } + metrics["train/global_step"] = step + + if self._dist_config.is_main_process(): + wandb.log(metrics, step=step) + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + + if self.quant_stats_enabled: + debug_api.step() + if self._log_heatmap and self._quant_logger is not None and self._dist_config.is_main_process(): + import matplotlib.pyplot as plt + + fig = self._quant_logger.generate_heatmap() + if fig is not None: + wandb.log({"quant/gradient_underflow_heatmap": wandb.Image(fig)}, step=step) + plt.close(fig) + + def finish(self): + """Finish the logger.""" + if self.quant_stats_enabled: + 
debug_api.end_debug() + if not self._dist_config.is_main_process(): + return + wandb.finish() + self._progress_bar.close() diff --git a/bionemo-recipes/recipes/esm2_minifold_te/quantization.py b/bionemo-recipes/recipes/esm2_minifold_te/quantization.py new file mode 100644 index 0000000000..15d9c7c413 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/quantization.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for block-wise quantization configuration (FP8/FP4) for the MiniFold folding head. + +Adapted from esm2_native_te/quantization.py. Uses the same API (fp8_layers/fp4_layers) +but applied to MiniFormer blocks instead of transformer layers. 
+""" + +import logging +import re +import tempfile +from collections import defaultdict +from contextlib import nullcontext +from dataclasses import dataclass +from pathlib import Path +from typing import ContextManager + +import matplotlib +import numpy as np +import transformer_engine.pytorch as te +import yaml +from nvdlfw_inspect.logging import BaseLogger + + +matplotlib.use("Agg") + + +logger = logging.getLogger(__name__) + + +def resolve_layer_precision( + num_layers: int, + fp8_enabled: bool, + fp4_enabled: bool, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, +) -> list[str | None]: + """Resolve block-wise quantization assignments from user config. + + Takes 1-indexed block lists (as specified by the user in YAML config) and returns a per-block + precision list (0-indexed by position). When a quantization format is enabled but no block list + is provided, all blocks default to that format. When one format has explicit blocks and the other + is enabled without a block list, the unspecified format defaults to the remaining (unclaimed) blocks. + + Args: + num_layers: Total number of MiniFormer blocks in the folding head. + fp8_enabled: Whether FP8 quantization is enabled. + fp4_enabled: Whether FP4 quantization is enabled. + fp8_layers: 1-indexed list of blocks for FP8, or None if not specified. + fp4_layers: 1-indexed list of blocks for FP4, or None if not specified. + + Returns: + A list of length ``num_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` + (BF16 fallback), indexed by block position (0-indexed). + + Raises: + ValueError: If both formats are enabled with no block lists, or if block lists overlap. + """ + all_layers = set(range(1, num_layers + 1)) + + if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None: + raise ValueError( + "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. 
" + "When both are enabled, you must explicitly provide layer lists to indicate which blocks use which format." + ) + + # When one format has explicit layers and the other defaults, fill in the remaining layers. + if fp8_enabled and fp8_layers is None: + claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set() + fp8_layers = sorted(all_layers - claimed_by_fp4) + if claimed_by_fp4: + logger.warning( + f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} " + f"are already claimed by FP4. Defaulting FP8 to the remaining blocks: {fp8_layers}" + ) + else: + logger.info( + f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} blocks to FP8" + ) + + if fp4_enabled and fp4_layers is None: + claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set() + fp4_layers = sorted(all_layers - claimed_by_fp8) + if claimed_by_fp8: + logger.warning( + f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} " + f"are already claimed by FP8. Defaulting FP4 to the remaining blocks: {fp4_layers}" + ) + else: + logger.info( + f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} blocks to FP4" + ) + + # Disable layer lists when corresponding config is not enabled. + if not fp8_enabled: + fp8_layers = None + if not fp4_enabled: + fp4_layers = None + + # Validate no overlap between FP8 and FP4 layer assignments. + if fp8_layers is not None and fp4_layers is not None: + overlap = set(fp8_layers) & set(fp4_layers) + if overlap: + raise ValueError( + f"fp8_layers and fp4_layers cannot have overlapping block numbers. Found overlap: {sorted(overlap)}" + ) + + # Build per-block precision list (0-indexed by position, 1-indexed for lookup). 
+ fp8_set = set(fp8_layers) if fp8_layers is not None else set() + fp4_set = set(fp4_layers) if fp4_layers is not None else set() + return [ + "fp8" if layer_1indexed in fp8_set else "fp4" if layer_1indexed in fp4_set else None + for layer_1indexed in range(1, num_layers + 1) + ] + + +@dataclass +class ComponentPrecisionConfig: + """Per-component precision overrides within FP8/FP4 blocks. + + When a block runs in FP8/FP4 via te.autocast, these flags control which sub-components + participate. Components set to True run in the block's precision; components set to False + are wrapped in te.autocast(enabled=False) to stay in BF16. + + Only meaningful when block-level FP8/FP4 is enabled. When all blocks are BF16, + these flags have no effect. + + Attributes: + tri_proj: Triangular update input/output projections (pi, po). + tri_gate: Triangular update sigmoid gates (gi, go). + tri_einsum: Triangular multiplication matmuls (reshaped einsum). + "off" = forced FP32 (default). "bf16" = ambient dtype (recommended). + ffn: Transition update FFN layers (fc1, fc2). + struct_attn: Structure module attention projections (proj, o_proj, g_proj). + struct_ffn: Structure module transition MLP layers. + seq_proj: Sequence and pair feature projections (fc_s, fc_z, seq_to_pair). + dist_head: Distogram output head (fc_out_1, fc_out_2). 
+ """ + + tri_proj: bool = True + tri_gate: bool = True + tri_einsum: str = "off" + ffn: bool = True + struct_attn: bool = True + struct_ffn: bool = True + seq_proj: bool = True + dist_head: bool = True + + def __post_init__(self): + """Normalize tri_einsum for backward compatibility with bool configs.""" + if isinstance(self.tri_einsum, bool): + self.tri_einsum = "bf16" if self.tri_einsum else "off" + + def get_context(self, component: str) -> ContextManager: + """Return te.autocast(enabled=False) if the component is disabled, else nullcontext.""" + if getattr(self, component, True): + return nullcontext() + return te.autocast(enabled=False) + + +class WandBQuantLogger(BaseLogger): + """Forward nvdlfw_inspect quant stats to WandB as scalars. + + Each stat is logged under the ``quant/`` prefix so it appears alongside + training metrics (loss, lDDT, etc.) in a single WandB dashboard. + """ + + def log_scalar(self, name: str, value: float | int, iteration: int, **kwargs): + """Log a single quant stat to WandB.""" + import wandb + + wandb.log({f"quant/{name}": value}) + + +_MINIFOLD_UNDERFLOW_PATTERN = re.compile(r"blocks\.(\d+)\.(\w+)\.(\w+)_gradient_underflows%") + + +class BufferedQuantLogger(BaseLogger): + """Buffer gradient underflow stats in memory and optionally forward all stats to WandB. + + Accumulates gradient_underflows% values keyed by metric name and iteration, + enabling periodic heatmap generation without post-hoc log parsing. + """ + + def __init__(self): + self._underflow_buffer: dict[str, list[tuple[int, float]]] = defaultdict(list) + + def log_scalar(self, name: str, value: float | int, iteration: int, **kwargs): + """Buffer gradient_underflows% for heatmaps. Scalar stats are logged via file logger.""" + if "gradient_underflows%" in name: + self._underflow_buffer[name].append((iteration, value)) + + def generate_heatmap(self): + """Create a heatmap figure from buffered gradient underflow data. 
+ + Returns: + matplotlib.figure.Figure or None if no data has been buffered. + """ + import matplotlib.pyplot as plt + import seaborn as sns + + if not self._underflow_buffer: + return None + + # Parse metric names into (block_num, module, sublayer) tuples + components = [] + for metric_name in self._underflow_buffer: + match = _MINIFOLD_UNDERFLOW_PATTERN.search(metric_name) + if match: + block = int(match.group(1)) + module = match.group(2) + sublayer = match.group(3) + sort_key = (block, module, sublayer) + label = f"B{block} {sublayer}" + components.append((sort_key, label, metric_name)) + + if not components: + return None + + components.sort(key=lambda x: x[0]) + + # Collect all unique iterations + all_iterations = sorted({it for data in self._underflow_buffer.values() for it, _ in data}) + + # Build 2D array + iter_to_col = {it: i for i, it in enumerate(all_iterations)} + matrix = np.full((len(components), len(all_iterations)), np.nan) + labels = [] + + for row_idx, (_, label, metric_name) in enumerate(components): + labels.append(label) + for iteration, value in self._underflow_buffer[metric_name]: + col = iter_to_col[iteration] + matrix[row_idx, col] = value + + # Create figure + fig, ax = plt.subplots(figsize=(14, max(6, len(components) * 0.3))) + cmap = sns.color_palette("rocket_r", as_cmap=True) + max_val = min(6.0, float(np.nanmax(matrix))) if not np.all(np.isnan(matrix)) else 6.0 + + ax.imshow(matrix, aspect="auto", cmap=cmap, interpolation="nearest", vmin=0, vmax=max_val) + + # Y-axis + ax.set_yticks(range(len(labels))) + ax.set_yticklabels(labels, fontsize=max(6, min(10, 200 // max(len(labels), 1)))) + ax.set_ylabel("MiniFold Block / Component") + + # X-axis + n_ticks = min(12, len(all_iterations)) + tick_positions = np.linspace(0, len(all_iterations) - 1, n_ticks).astype(int) + ax.set_xticks(tick_positions) + ax.set_xticklabels([str(all_iterations[i]) for i in tick_positions]) + ax.set_xlabel("Training Iteration") + + ax.set_title("FP8 Gradient 
Underflows: MiniFold Blocks") + + # Block separator lines + prev_block = None + for idx, (sort_key, _, _) in enumerate(components): + block = sort_key[0] + if prev_block is not None and block != prev_block: + ax.axhline(y=idx - 0.5, color="white", linewidth=2) + prev_block = block + + fig.tight_layout() + return fig + + +def generate_layer_regex( + block_numbers: list[int] | None, + component_precision: ComponentPrecisionConfig | None = None, +) -> str: + """Generate a regex pattern to match specific MiniFormer block numbers (1-indexed). + + The debug API (nvdlfw_inspect) uses layer names assigned by ``infer_and_assign_layer_names``. + Block numbers in the user config are 1-indexed, but module names are 0-indexed, so this + function converts accordingly. Only includes sublayers whose component_precision is enabled. + + Args: + block_numbers: List of block numbers (1-indexed, as specified in fp8_layers config). + If empty or None, returns a pattern that matches nothing. + component_precision: Per-component precision config. Only sublayers with enabled components + are included in the regex. If None, all sublayers are included. + + Returns: + Regex pattern string for matching those blocks' te.Linear sublayers. 
+ """ + if not block_numbers: + return r"fold\.miniformer\.blocks\.DISABLED_NO_BLOCKS_SPECIFIED" + + # Determine which sublayers are actually running in FP8 based on component_precision + sublayers = [] + if component_precision is None or component_precision.tri_proj: + sublayers.extend(["pi", "po"]) + if component_precision is None or component_precision.tri_gate: + sublayers.extend(["gi", "go"]) + if component_precision is None or component_precision.ffn: + sublayers.extend(["fc1", "fc2"]) + + if not sublayers: + return r"fold\.miniformer\.blocks\.DISABLED_NO_COMPONENTS_ENABLED" + + # Convert 1-indexed (user config) to 0-indexed (module names) + block_pattern = "|".join(str(n - 1) for n in sorted(block_numbers)) + sublayer_pattern = "|".join(sublayers) + return rf"fold\.miniformer\.blocks\.({block_pattern})\..*({sublayer_pattern})" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, + component_precision: ComponentPrecisionConfig | None = None, +) -> str: + """Update the quant stats YAML config with block-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. + fp4_layers: List of block numbers for FP4 (1-indexed). + fp8_layers: List of block numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (a temp file). 
+ """ + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + if "example_fp4_tensor_stat_collection" in config: + fp4_regex = generate_layer_regex(fp4_layers, component_precision=component_precision) + config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + if fp4_layers: + logger.info(f"Updated FP4 block regex to match blocks: {fp4_layers}") + else: + logger.info("FP4 blocks empty - regex set to match nothing") + + if "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers, component_precision=component_precision) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + if fp8_layers: + logger.info(f"Updated FP8 block regex to match blocks: {fp8_layers}") + else: + logger.info("FP8 blocks empty - regex set to match nothing") + + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + +def initialize_quant_stats_logging( + quant_stats_file: str, + quant_log_dir: str, + rank: int, + layer_precision: list[str | None], + statistics_logger: BaseLogger | None = None, + component_precision: ComponentPrecisionConfig | None = None, +) -> None: + """Set up quantization stats logging via nvdlfw_inspect. + + Args: + quant_stats_file: Path to the base quant stats YAML config file. + quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created). + rank: The global rank of this process. + layer_precision: Per-block precision list (0-indexed by position). Each element is + ``"fp8"``, ``"fp4"``, or ``None``. + statistics_logger: Optional custom logger (e.g. 
:class:`WandBQuantLogger`) that receives + every ``log_scalar`` call from the debug API. + component_precision: Per-component precision config. Only sublayers with enabled components + are included in the stats regex to avoid inspecting layers not running in FP8. + """ + import nvdlfw_inspect.api as debug_api + import transformer_engine + + fp8_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp8"] or None + fp4_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp4"] or None + updated_config = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=fp4_layers_1indexed, + fp8_layers=fp8_layers_1indexed, + component_precision=component_precision, + ) + + rank_log_dir = Path(quant_log_dir) / f"rank_{rank}" + rank_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {rank_log_dir}") + + te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") + debug_api.initialize( + config_file=updated_config, + feature_dirs=[te_features_dir], + log_dir=rank_log_dir, + statistics_logger=statistics_logger, + default_logging_enabled=True, + ) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/requirements.txt b/bionemo-recipes/recipes/esm2_minifold_te/requirements.txt new file mode 100644 index 0000000000..7915b4929e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/requirements.txt @@ -0,0 +1,18 @@ +biopython>=1.80 +datasets +einops +fair-esm +hydra-core +megatron-fsdp +ml_collections +pytest +scipy +torch +torchao!=0.14.0 +torchdata +torchmetrics +tqdm +transformer_engine[pytorch] +transformers>=4.44.0 +wandb +nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect diff --git a/bionemo-recipes/recipes/esm2_minifold_te/scheduler.py b/bionemo-recipes/recipes/esm2_minifold_te/scheduler.py new file mode 100644 index 0000000000..9f9da8da91 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/scheduler.py @@ -0,0 +1,45 @@ +# 
from torch.optim.lr_scheduler import LambdaLR


def get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=2_000,
    num_training_steps=500_000,
    last_epoch=-1,
):
    """Linear warmup and decay scheduler for ESM-2 pretraining.

    Implements the schedule described in Lin 2022: the learning rate warms up
    linearly over the first 2,000 steps to its peak value (4e-4; 1.6e-4 for the
    15B model), then decays linearly to one tenth of the peak over the first 90%
    of training and stays there. Internally we have found that a longer warmup
    helps convergence for larger models (3B+) with bf16 precision.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Steps of linear warmup from 0 to the peak LR.
        num_training_steps: Total planned training steps; decay ends at 90% of this.
        last_epoch: Index of the last step when resuming (-1 starts fresh).

    Returns:
        A ``LambdaLR`` applying the warmup/decay multiplier to the base LR.
    """
    decay_steps = int(num_training_steps * 0.9)

    def scale_for(step: int) -> float:
        """Multiplicative LR factor at a given optimizer step."""
        # Warmup: ramp linearly from 0 to 1.
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        # After the decay window: hold at one tenth of the peak.
        if step > decay_steps:
            return 0.1
        # Decay window: slide linearly from 1.0 down to 0.1.
        span = float(max(1, decay_steps - num_warmup_steps))
        return 1.0 - 0.9 * (step - num_warmup_steps) / span

    return LambdaLR(optimizer, scale_for, last_epoch)
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformer_engine.pytorch as te
from einops import rearrange, repeat
from torch import Tensor

from minifold_utils import init
from minifold_utils.feats import (
    frames_and_literature_positions_to_atom14_pos,
    torsion_angles_to_frames,
)
from minifold_utils.residue_constants import (
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
    restype_atom14_to_rigid_group,
    restype_rigid_group_default_frame,
)
from minifold_utils.rigid_utils import Rigid
from minifold_utils.tensor_utils import dict_multimap, permute_final_dims
from te_utils import te_layernorm_nd, te_linear_nd


class AttentionTE(nn.Module):
    """TE version of gated self-attention used in the StructureModule.

    Multi-head self-attention with an external additive bias and a sigmoid output
    gate, built from te.LayerNorm / te.Linear so the projections can run inside
    TransformerEngine autocast regions.
    """

    def __init__(self, dim: int, num_heads: int, head_width: int, params_dtype: torch.dtype = torch.float32):
        """Initialize the attention layer.

        Args:
            dim: Model width; must equal ``num_heads * head_width``.
            num_heads: Number of attention heads.
            head_width: Per-head channel width.
            params_dtype: Parameter storage dtype for the TE modules.
        """
        super().__init__()
        assert dim == num_heads * head_width

        self.dim = dim
        self.num_heads = num_heads
        self.head_width = head_width
        # Standard 1/sqrt(d_head) attention scaling.
        self.rescale_factor = self.head_width**-0.5

        # Cannot fuse LN+proj because g_proj also reads the LN output
        self.layer_norm = te.LayerNorm(dim, eps=1e-5, params_dtype=params_dtype)
        # Packed QKV projection (3x dim output, split in forward).
        self.proj = te.Linear(dim, dim * 3, bias=False, params_dtype=params_dtype)
        self.o_proj = te.Linear(dim, dim, bias=True, params_dtype=params_dtype)
        self.g_proj = te.Linear(dim, dim, bias=True, params_dtype=params_dtype)

        # Gate init W=0, b=1: every gate starts at the constant sigmoid(1).
        torch.nn.init.zeros_(self.o_proj.bias)
        torch.nn.init.zeros_(self.g_proj.weight)
        torch.nn.init.ones_(self.g_proj.bias)

    def forward(self, x: Tensor, bias: Tensor, mask: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor (B, N, D).
            bias: External attention bias (B, H, N, N).
            mask: Mask tensor (B, N).

        Returns:
            Output tensor (B, N, D).
        """
        x = te_layernorm_nd(self.layer_norm, x)

        # The rearranged per-head channel axis 'c' is 3 * head_width here, so the
        # subsequent chunk(3, dim=-1) yields per-head q, k, v of width head_width.
        t = rearrange(te_linear_nd(self.proj, x), "... l (h c) -> ... h l c", h=self.num_heads)
        q, k, v = t.chunk(3, dim=-1)

        # Scale queries (equivalent to scaling the logits).
        q = self.rescale_factor * q
        a = torch.einsum("...qc,...kc->...qk", q, k)

        # Add external attention bias
        a = a + bias

        # Mask padding tokens: -inf logits vanish under softmax.
        # NOTE(review): a query row whose keys are ALL masked produces NaN after
        # softmax — assumes at least one valid token per sequence; confirm upstream.
        mask = repeat(mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2])
        a = a.masked_fill(mask == 0, -np.inf)
        a = F.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads)
        # Sigmoid gate computed from the normalized input, applied before the output projection.
        y = te_linear_nd(self.g_proj, x).sigmoid() * y
        y = te_linear_nd(self.o_proj, y)

        return y
+ """ + x = te_layernorm_nd(self.norm, x) + x = te_linear_nd(self.fc1, x) + x = F.relu(x) + x = te_linear_nd(self.fc2, x) + return x + + +class AngleResnetBlockTE(nn.Module): + """TE version of AngleResnetBlock.""" + + def __init__(self, dim, params_dtype=torch.float32): + super().__init__() + self.fc1 = te.Linear(dim, dim, params_dtype=params_dtype) + self.fc2 = te.Linear(dim, dim, params_dtype=params_dtype) + + init.he_normal_init_(self.fc1.weight) + init.final_init_(self.fc2.weight) + init.bias_init_zero_(self.fc1.bias) + init.bias_init_zero_(self.fc2.bias) + + def forward(self, a: Tensor) -> Tensor: + x = F.relu(a) + x = te_linear_nd(self.fc1, x) + x = F.relu(x) + x = te_linear_nd(self.fc2, x) + return a + x + + +class AngleResnetTE(nn.Module): + """TE version of AngleResnet.""" + + def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon, params_dtype=torch.float32): + super().__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_blocks = no_blocks + self.no_angles = no_angles + self.eps = epsilon + + self.linear_in = te.Linear(self.c_in, self.c_hidden, params_dtype=params_dtype) + self.linear_initial = te.Linear(self.c_in, self.c_hidden, params_dtype=params_dtype) + + self.layers = nn.ModuleList() + for _ in range(self.no_blocks): + self.layers.append(AngleResnetBlockTE(dim=self.c_hidden, params_dtype=params_dtype)) + + self.linear_out = te.Linear(self.c_hidden, self.no_angles * 2, params_dtype=params_dtype) + + init.lecun_normal_init_(self.linear_in.weight) + init.lecun_normal_init_(self.linear_initial.weight) + init.final_init_(self.linear_out.weight) + + init.bias_init_zero_(self.linear_in.bias) + init.bias_init_zero_(self.linear_initial.bias) + init.bias_init_zero_(self.linear_out.bias) + + self.relu = nn.ReLU() + + def forward(self, s: Tensor, s_initial: Tensor) -> Tuple[Tensor, Tensor]: + """Forward pass. + + Args: + s: Single embedding [*, C_hidden]. + s_initial: Initial single embedding [*, C_hidden]. 
+ + Returns: + Tuple of (unnormalized_angles, normalized_angles), each [*, no_angles, 2]. + """ + s_initial = self.relu(s_initial) + s_initial = te_linear_nd(self.linear_initial, s_initial) + s = self.relu(s) + s = te_linear_nd(self.linear_in, s) + s = s + s_initial + + for layer in self.layers: + s = layer(s) + + s = self.relu(s) + s = te_linear_nd(self.linear_out, s) + + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.eps, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class StructureModuleTE(nn.Module): + """TE version of the StructureModule.""" + + def __init__( + self, + c_s: int, + c_z: int, + c_resnet: int, + head_dim: int, + no_heads: int, + no_blocks: int, + no_resnet_blocks: int, + no_angles: int, + trans_scale_factor: float, + epsilon: float, + inf: float, + params_dtype: torch.dtype = torch.float32, + ): + super().__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_resnet = c_resnet + self.no_heads = no_heads + self.head_dim = head_dim + self.no_blocks = no_blocks + self.no_resnet_blocks = no_resnet_blocks + self.no_angles = no_angles + self.trans_scale_factor = trans_scale_factor + self.epsilon = epsilon + self.inf = inf + + self.layer_norm_s = te.LayerNorm(self.c_s, eps=1e-5, params_dtype=params_dtype) + self.layer_norm_z = te.LayerNorm(self.c_z, eps=1e-5, params_dtype=params_dtype) + self.linear_in = te.Linear(self.c_s, self.c_s, params_dtype=params_dtype) + self.linear_b = te.Linear(self.c_z, self.no_blocks * self.no_heads, params_dtype=params_dtype) + + self.attn = nn.ModuleList( + [ + AttentionTE(self.c_s, self.no_heads, self.head_dim, params_dtype=params_dtype) + for _ in range(self.no_blocks) + ] + ) + self.transitions = nn.ModuleList( + [MLPTE(self.c_s, self.c_s, params_dtype=params_dtype) for _ in range(self.no_blocks)] + ) + + self.bb_update = te.Linear(self.c_s, 9, params_dtype=params_dtype) + self.angle_resnet = 
AngleResnetTE( + self.c_s, + self.c_resnet, + self.no_resnet_blocks, + self.no_angles, + self.epsilon, + params_dtype=params_dtype, + ) + + # Initialize weights + init.lecun_normal_init_(self.linear_in.weight) + init.bias_init_zero_(self.linear_in.bias) + init.lecun_normal_init_(self.bb_update.weight) + init.bias_init_zero_(self.bb_update.bias) + init.lecun_normal_init_(self.linear_b.weight) + init.bias_init_zero_(self.linear_b.bias) + + # Initialize buffers + frames = torch.tensor(restype_rigid_group_default_frame) + groups = torch.tensor(restype_atom14_to_rigid_group) + atom_mask = torch.tensor(restype_atom14_mask) + positions = torch.tensor(restype_atom14_rigid_group_positions) + + self.register_buffer("default_frames", frames, persistent=False) + self.register_buffer("group_idx", groups, persistent=False) + self.register_buffer("atom_mask", atom_mask, persistent=False) + self.register_buffer("lit_positions", positions, persistent=False) + + def forward(self, s, z, aatype, mask): + """Forward pass. + + Args: + s: Single representation (B, N, c_s). + z: Pair representation (B, N, N, c_z). + aatype: Amino acid types (B, N). + mask: Residue mask (B, N). + + Returns: + Dictionary with angles, frames, positions, states. 
+ """ + # Input projection + s = te_layernorm_nd(self.layer_norm_s, s) + s_initial = s + s = te_linear_nd(self.linear_in, s) + + # Pairwise bias + B, N = s.shape[:2] + z = te_layernorm_nd(self.layer_norm_z, z) + b = te_linear_nd(self.linear_b, z) + b = permute_final_dims(b, (2, 0, 1)) + b = b.reshape(B, self.no_blocks, self.no_heads, N, N) + + # Apply transformer layers + outputs = [] + for i in range(self.no_blocks): + s = s + self.attn[i](s, b[:, i], mask) + s = s + self.transitions[i](s) + + # Predict angles + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + # Predict positions (in FP32 via explicit .float() cast) + n, ca, c = te_linear_nd(self.bb_update, s.float()).chunk(3, dim=-1) + rigids = Rigid.make_transform_from_reference(n, ca, c, eps=1e-7) + scaled_rigids = rigids.scale_translation(self.trans_scale_factor) + + all_frames_to_global = torsion_angles_to_frames(scaled_rigids, angles, aatype, self.default_frames) + pred_xyz = frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) + outputs.append( + { + "angles": angles, + "unnormalized_angles": unnormalized_angles, + "frames": scaled_rigids.to_tensor_4x4(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "positions": pred_xyz, + "states": s, + } + ) + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + return outputs diff --git a/bionemo-recipes/recipes/esm2_minifold_te/te_utils.py b/bionemo-recipes/recipes/esm2_minifold_te/te_utils.py new file mode 100644 index 0000000000..552e0abeca --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/te_utils.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import transformer_engine.pytorch as te + + +def te_linear_nd(module: te.Linear, x: torch.Tensor) -> torch.Tensor: + """Apply a te.Linear module to an N-dimensional tensor (N >= 2). + + te.Linear is validated for 2D (B, D) and 3D (B, S, D) inputs. + For 4D+ inputs (e.g. pair representations with shape B, N, N, D), + we flatten leading dimensions to 2D, apply the linear, and reshape back. + + Args: + module: A transformer_engine.pytorch.Linear module. + x: Input tensor of shape (*leading_dims, in_features). + + Returns: + Tensor of shape (*leading_dims, out_features). + """ + if x.ndim <= 3: + return module(x) + leading = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + x = module(x) + return x.reshape(*leading, -1) + + +def te_layernorm_nd(module: te.LayerNorm, x: torch.Tensor) -> torch.Tensor: + """Apply a te.LayerNorm module to an N-dimensional tensor (N >= 2). + + Args: + module: A transformer_engine.pytorch.LayerNorm module. + x: Input tensor of shape (*leading_dims, normalized_shape). + + Returns: + Tensor of same shape as input. + """ + if x.ndim <= 3: + return module(x) + leading = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + x = module(x) + return x.reshape(*leading, -1) + + +def tri_mul_bmm(a: torch.Tensor, b: torch.Tensor, k_dim: int, mode: str = "off") -> torch.Tensor: + """Batched GEMM equivalent of triangular multiplication einsum. 
+ + Replaces: + k_dim=2: torch.einsum("bikd,bjkd->bijd", a, b) + k_dim=1: torch.einsum("bkid,bkjd->bijd", a, b) + + Args: + a: Tensor of shape (B, N, N, D). + b: Tensor of shape (B, N, N, D). + k_dim: Spatial dimension to contract over (1 or 2). + mode: Precision mode. + "off": FP32 bmm (caller upcasts via .float(), default). + "bf16": BF16 bmm (skip .float() upcast). + + Returns: + Tensor of shape (B, N, N, D). + """ + B, N1, N2, D = a.shape + # Move D to dim 1: (B, D, N, N), then merge B*D for batched mm + a = a.permute(0, 3, 1, 2).contiguous().reshape(B * D, N1, N2) + b = b.permute(0, 3, 1, 2).contiguous().reshape(B * D, N1, N2) + + if k_dim == 2: + # "bikd,bjkd->bijd": a is (batch, i, k), b is (batch, j, k) + # result = a @ b^T = (batch, i, j) + out = torch.bmm(a, b.transpose(1, 2)) + elif k_dim == 1: + # "bkid,bkjd->bijd": a is (batch, k, i), b is (batch, k, j) + # result = a^T @ b = (batch, i, j) + out = torch.bmm(a.transpose(1, 2), b) + else: + raise ValueError(f"k_dim must be 1 or 2, got {k_dim}") + + # Reshape back: (B*D, N, N) -> (B, D, N, N) -> (B, N, N, D) + return out.reshape(B, D, N1, N2).permute(0, 2, 3, 1) + + +def te_layernorm_linear_nd(module: te.LayerNormLinear, x: torch.Tensor) -> torch.Tensor: + """Apply a te.LayerNormLinear module to an N-dimensional tensor (N >= 2). + + Args: + module: A transformer_engine.pytorch.LayerNormLinear module. + x: Input tensor of shape (*leading_dims, in_features). + + Returns: + Tensor of shape (*leading_dims, out_features). 
+ """ + if x.ndim <= 3: + return module(x) + leading = x.shape[:-1] + x = x.reshape(-1, x.shape[-1]) + x = module(x) + return x.reshape(*leading, -1) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/tests/__init__.py b/bionemo-recipes/recipes/esm2_minifold_te/tests/__init__.py new file mode 100644 index 0000000000..1dd47a63cf --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/tests/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/bionemo-recipes/recipes/esm2_minifold_te/tests/conftest.py b/bionemo-recipes/recipes/esm2_minifold_te/tests/conftest.py new file mode 100644 index 0000000000..670a1cfbba --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/tests/conftest.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pytest fixtures for ESM2-MiniFold TE tests.""" + +import sys +from pathlib import Path + +import pytest +import torch + + +# Add recipe root to path so we can import recipe modules +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +SEED = 42 + + +@pytest.fixture(autouse=True) +def set_seed(): + """Set seed before each test for reproducibility.""" + torch.manual_seed(SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(SEED) + + +@pytest.fixture +def device(): + return DEVICE diff --git a/bionemo-recipes/recipes/esm2_minifold_te/tests/test_data_pipeline.py b/bionemo-recipes/recipes/esm2_minifold_te/tests/test_data_pipeline.py new file mode 100644 index 0000000000..b0317abc36 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/tests/test_data_pipeline.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the PDB data pipeline. + +Validates: +- mmCIF parsing correctness (BioPython) +- MmcifStructureDataset batch format +- ParquetStructureDataset batch format +- Equivalence between both dataset implementations + +Requires network access to download test structures from RCSB PDB. +Run with: pytest tests/test_data_pipeline.py -v +""" + +import sys +from pathlib import Path +from urllib.request import urlretrieve + +import pandas as pd +import pytest +import torch + + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dataset import MmcifStructureDataset, ParquetStructureDataset + + +# Test protein: 1CRN (crambin, 46 residues, very high resolution) +TEST_PDB_ID = "1CRN" +TEST_PDB_URL = f"https://files.rcsb.org/download/{TEST_PDB_ID}.cif" +TEST_SEQ_LENGTH = 46 +MAX_SEQ_LENGTH = 64 + + +@pytest.fixture(scope="session") +def tokenizer(): + """Load ESM-2 tokenizer (small model for speed).""" + from transformers import EsmTokenizer + + return EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") + + +@pytest.fixture(scope="session") +def cif_dir(tmp_path_factory): + """Download 1CRN.cif to a temp directory.""" + d = tmp_path_factory.mktemp("cif_files") + cif_path = d / f"{TEST_PDB_ID}.cif" + urlretrieve(TEST_PDB_URL, cif_path) + return str(d) + + +@pytest.fixture(scope="session") +def parsed_data(cif_dir): + """Parse 1CRN and return (sequence, coords, ca_mask).""" + from Bio.PDB.MMCIFParser import MMCIFParser + + aa_3to1 = { + "ALA": "A", + "CYS": "C", + "ASP": "D", + "GLU": "E", + "PHE": "F", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LYS": "K", + "LEU": "L", + "MET": "M", + "ASN": "N", + "PRO": "P", + "GLN": "Q", + "ARG": "R", + "SER": "S", + "THR": "T", + "VAL": "V", + "TRP": "W", + "TYR": "Y", + } + parser = MMCIFParser(QUIET=True) + structure = parser.get_structure(TEST_PDB_ID, str(Path(cif_dir) / f"{TEST_PDB_ID}.cif")) + model = structure[0] + chain = next(iter(model)) + + sequence, coords, ca_mask = [], [], [] + for res in 
chain.get_residues(): + if res.id[0] != " ": + continue + resname = res.get_resname().strip() + if resname not in aa_3to1: + continue + sequence.append(aa_3to1[resname]) + if "CA" in res: + ca = res["CA"].get_vector() + coords.append([float(ca[0]), float(ca[1]), float(ca[2])]) + ca_mask.append(1) + else: + coords.append([0.0, 0.0, 0.0]) + ca_mask.append(0) + + return "".join(sequence), coords, ca_mask + + +@pytest.fixture(scope="session") +def parquet_path(parsed_data, tmp_path_factory): + """Create a parquet file from parsed 1CRN data.""" + sequence, coords, ca_mask = parsed_data + d = tmp_path_factory.mktemp("parquet") + path = d / "test_structures.parquet" + df = pd.DataFrame( + [ + { + "pdb_id": TEST_PDB_ID, + "sequence": sequence, + "coords": coords, + "ca_mask": ca_mask, + "num_residues": len(sequence), + } + ] + ) + df.to_parquet(str(path), index=False) + return str(path) + + +# =========================================================================== +# CIF Parsing +# =========================================================================== + + +class TestCifParsing: + def test_sequence_length(self, parsed_data): + sequence, coords, ca_mask = parsed_data + assert len(sequence) == TEST_SEQ_LENGTH, f"Expected {TEST_SEQ_LENGTH} residues, got {len(sequence)}" + + def test_coords_count_matches_sequence(self, parsed_data): + sequence, coords, ca_mask = parsed_data + assert len(coords) == len(sequence) + assert len(ca_mask) == len(sequence) + + def test_ca_coords_finite(self, parsed_data): + _, coords, _ = parsed_data + for i, c in enumerate(coords): + assert all(abs(v) < 1e6 for v in c), f"Non-finite coords at residue {i}: {c}" + + def test_all_ca_present(self, parsed_data): + """1CRN is a high-quality structure - all Ca should be present.""" + _, _, ca_mask = parsed_data + assert all(m == 1 for m in ca_mask), "1CRN should have all Ca atoms resolved" + + def test_sequence_standard_amino_acids(self, parsed_data): + sequence, _, _ = parsed_data + valid_aa = 
set("ACDEFGHIKLMNPQRSTVWY") + for i, aa in enumerate(sequence): + assert aa in valid_aa, f"Non-standard amino acid '{aa}' at position {i}" + + def test_first_residue_is_threonine(self, parsed_data): + """1CRN starts with Thr-Thr-Cys.""" + sequence, _, _ = parsed_data + assert sequence[:3] == "TTC", f"Expected TTC, got {sequence[:3]}" + + +# =========================================================================== +# MmcifStructureDataset +# =========================================================================== + + +class TestMmcifStructureDataset: + def test_batch_keys(self, cif_dir, tokenizer): + ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20) + sample = ds[0] + assert set(sample.keys()) == {"input_ids", "attention_mask", "mask", "coords"} + + def test_batch_shapes(self, cif_dir, tokenizer): + ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20) + sample = ds[0] + assert sample["input_ids"].shape == (MAX_SEQ_LENGTH,) + assert sample["attention_mask"].shape == (MAX_SEQ_LENGTH,) + assert sample["mask"].shape == (MAX_SEQ_LENGTH,) + assert sample["coords"].shape == (MAX_SEQ_LENGTH, 3) + + def test_batch_dtypes(self, cif_dir, tokenizer): + ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20) + sample = ds[0] + assert sample["input_ids"].dtype == torch.long + assert sample["attention_mask"].dtype == torch.long + assert sample["mask"].dtype == torch.float32 + assert sample["coords"].dtype == torch.float32 + + def test_cls_eos_tokens(self, cif_dir, tokenizer): + ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20) + sample = ds[0] + assert sample["input_ids"][0].item() == 0, "First token should be CLS (0)" + # Find EOS position + real_len = sample["attention_mask"].sum().item() + assert sample["input_ids"][int(real_len) - 1].item() == 2, "Last real token should be EOS (2)" + + def 
test_padding_is_zero(self, cif_dir, tokenizer): + ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20) + sample = ds[0] + real_len = sample["attention_mask"].sum().item() + assert (sample["attention_mask"][int(real_len) :] == 0).all() + assert (sample["coords"][TEST_SEQ_LENGTH:] == 0).all() + + +# =========================================================================== +# ParquetStructureDataset +# =========================================================================== + + +class TestParquetStructureDataset: + def test_batch_keys(self, parquet_path, tokenizer): + ds = ParquetStructureDataset(parquet_path, tokenizer, max_seq_length=MAX_SEQ_LENGTH) + sample = ds[0] + assert set(sample.keys()) == {"input_ids", "attention_mask", "mask", "coords"} + + def test_batch_shapes(self, parquet_path, tokenizer): + ds = ParquetStructureDataset(parquet_path, tokenizer, max_seq_length=MAX_SEQ_LENGTH) + sample = ds[0] + assert sample["input_ids"].shape == (MAX_SEQ_LENGTH,) + assert sample["coords"].shape == (MAX_SEQ_LENGTH, 3) + + def test_batch_dtypes(self, parquet_path, tokenizer): + ds = ParquetStructureDataset(parquet_path, tokenizer, max_seq_length=MAX_SEQ_LENGTH) + sample = ds[0] + assert sample["input_ids"].dtype == torch.long + assert sample["coords"].dtype == torch.float32 + + +# =========================================================================== +# Dataset Equivalence +# =========================================================================== + + +class TestDatasetEquivalence: + """Both datasets must produce matching outputs for the same protein.""" + + def _get_samples(self, cif_dir, parquet_path, tokenizer): + ds_cif = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20) + ds_pq = ParquetStructureDataset(parquet_path, tokenizer, max_seq_length=MAX_SEQ_LENGTH) + return ds_cif[0], ds_pq[0] + + def test_same_input_ids(self, cif_dir, parquet_path, tokenizer): + s_cif, s_pq = 
self._get_samples(cif_dir, parquet_path, tokenizer) + assert torch.equal(s_cif["input_ids"], s_pq["input_ids"]), "input_ids mismatch" + + def test_same_attention_mask(self, cif_dir, parquet_path, tokenizer): + s_cif, s_pq = self._get_samples(cif_dir, parquet_path, tokenizer) + assert torch.equal(s_cif["attention_mask"], s_pq["attention_mask"]), "attention_mask mismatch" + + def test_same_mask(self, cif_dir, parquet_path, tokenizer): + s_cif, s_pq = self._get_samples(cif_dir, parquet_path, tokenizer) + assert torch.equal(s_cif["mask"], s_pq["mask"]), "mask mismatch" + + def test_same_coords(self, cif_dir, parquet_path, tokenizer): + s_cif, s_pq = self._get_samples(cif_dir, parquet_path, tokenizer) + assert torch.allclose(s_cif["coords"], s_pq["coords"], atol=1e-4), ( + f"coords max diff: {(s_cif['coords'] - s_pq['coords']).abs().max().item()}" + ) + + def test_distogram_loss_equivalence(self, cif_dir, parquet_path, tokenizer): + """Both datasets should produce the same distogram loss.""" + sys.path.insert(0, str(Path(__file__).parent.parent)) + from train_fsdp2 import compute_distogram_loss + + s_cif, s_pq = self._get_samples(cif_dir, parquet_path, tokenizer) + + # Fake preds (same for both) + torch.manual_seed(42) + preds = torch.randn(1, MAX_SEQ_LENGTH, MAX_SEQ_LENGTH, 64) + + loss_cif = compute_distogram_loss(preds, s_cif["coords"].unsqueeze(0), s_cif["mask"].unsqueeze(0)) + loss_pq = compute_distogram_loss(preds, s_pq["coords"].unsqueeze(0), s_pq["mask"].unsqueeze(0)) + + assert torch.allclose(loss_cif, loss_pq, atol=1e-4), ( + f"Loss mismatch: cif={loss_cif.item()}, pq={loss_pq.item()}" + ) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/tests/test_model.py b/bionemo-recipes/recipes/esm2_minifold_te/tests/test_model.py new file mode 100644 index 0000000000..9859b29146 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/tests/test_model.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ESM2-MiniFold TE model components. + +Tests model instantiation, forward pass shapes, and gradient flow. +Uses small model dimensions for fast testing. +""" + +import sys +from pathlib import Path + +import torch + + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from heads_te import PerResidueLDDTCaPredictorTE +from miniformer_te import BlockTE, MiniFormerTE, TransitionUpdateTE, TriangularUpdateTE +from model_te import FoldingTrunkTE, PairToSequenceTE, SequenceToPairTE +from quantization import ComponentPrecisionConfig, resolve_layer_precision +from structure_te import MLPTE, AngleResnetTE, AttentionTE +from te_utils import te_linear_nd + + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +B, N, DIM = 2, 16, 64 +HIDDEN = 256 + + +# =========================================================================== +# TE Module Shape Tests +# =========================================================================== + + +class TestTransitionUpdateTE: + def test_forward_shape(self): + mod = TransitionUpdateTE(dim=DIM, hidden=HIDDEN).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + out = mod(x) + assert out.shape == (B, N, N, DIM) + + def test_gradient_flow(self): + mod = TransitionUpdateTE(dim=DIM, hidden=HIDDEN).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE, requires_grad=True) + out = mod(x) + 
out.sum().backward() + assert x.grad is not None + + +class TestTriangularUpdateTE: + def test_forward_shape(self): + mod = TriangularUpdateTE(dim=DIM).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + +class TestBlockTE: + def test_forward_shape(self): + mod = BlockTE(dim=DIM).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + +class TestMiniFormerTE: + def test_forward_shape(self): + mod = MiniFormerTE(dim=DIM, blocks=2).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + +class TestSequenceToPairTE: + def test_forward_shape(self): + seq_dim, inner, pair_dim = 512, 32, DIM + mod = SequenceToPairTE(seq_dim, inner, pair_dim).to(DEVICE) + x = torch.randn(B, N, seq_dim, device=DEVICE) + out = mod(x) + assert out.shape == (B, N, N, pair_dim) + + +class TestPairToSequenceTE: + def test_forward_shape(self): + c_z, c_s, c_s_out = DIM, 512, 512 + mod = PairToSequenceTE(c_z=c_z, c_s=c_s, c_s_out=c_s_out).to(DEVICE) + s_z = torch.randn(B, N, N, c_z, device=DEVICE) + s_s = torch.randn(B, N, c_s, device=DEVICE) + pair_mask = torch.ones(B, N, N, device=DEVICE) + out = mod(s_z, s_s, pair_mask) + assert out.shape == (B, N, c_s_out) + + +class TestFoldingTrunkTE: + def test_forward_shape(self): + c_s, c_z = 512, DIM + mod = FoldingTrunkTE(c_s=c_s, c_z=c_z, bins=32, disto_bins=64, num_layers=2).to(DEVICE) + s_s = torch.randn(B, N, c_s, device=DEVICE) + s_z = torch.randn(B, N, N, c_z, device=DEVICE) + mask = torch.ones(B, N, device=DEVICE) + + mod.eval() + with torch.no_grad(): + preds, sz = mod(s_s, s_z, mask, num_recycling=0) + + assert preds.shape == (B, N, N, 64) + assert sz.shape[:3] == (B, N, N) + + +class TestAttentionTE: + def 
test_forward_shape(self): + dim, heads, head_width = 512, 8, 64 + mod = AttentionTE(dim, heads, head_width).to(DEVICE) + x = torch.randn(B, N, dim, device=DEVICE) + bias = torch.randn(B, heads, N, N, device=DEVICE) + mask = torch.ones(B, N, device=DEVICE) + out = mod(x, bias, mask) + assert out.shape == (B, N, dim) + + +class TestMLPTE: + def test_forward_shape(self): + mod = MLPTE(512, 512).to(DEVICE) + x = torch.randn(B, N, 512, device=DEVICE) + out = mod(x) + assert out.shape == (B, N, 512) + + +class TestAngleResnetTE: + def test_forward_shape(self): + mod = AngleResnetTE(512, DIM, no_blocks=2, no_angles=7, epsilon=1e-5).to(DEVICE) + s = torch.randn(B, N, 512, device=DEVICE) + s_init = torch.randn(B, N, 512, device=DEVICE) + unnorm, norm = mod(s, s_init) + assert unnorm.shape == (B, N, 7, 2) + assert norm.shape == (B, N, 7, 2) + + +class TestPerResidueLDDTCaPredictorTE: + def test_forward_shape(self): + mod = PerResidueLDDTCaPredictorTE(50, 512, DIM).to(DEVICE) + x = torch.randn(B, N, 512, device=DEVICE) + out = mod(x) + assert out.shape == (B, N, 50) + + +# =========================================================================== +# Block Precision / Quantization Tests +# =========================================================================== + + +class TestResolveLayerPrecision: + def test_neither_enabled(self): + result = resolve_layer_precision(6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None) + assert result == [None] * 6 + + def test_fp8_all_blocks(self): + result = resolve_layer_precision(4, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None) + assert result == ["fp8"] * 4 + + def test_fp4_all_blocks(self): + result = resolve_layer_precision(4, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None) + assert result == ["fp4"] * 4 + + def test_fp8_specific_blocks(self): + result = resolve_layer_precision(4, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3], fp4_layers=None) + assert 
result == ["fp8", None, "fp8", None] + + def test_mixed_fp8_fp4(self): + result = resolve_layer_precision(4, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2], fp4_layers=[3, 4]) + assert result == ["fp8", "fp8", "fp4", "fp4"] + + def test_fp8_explicit_fp4_fills_remaining(self): + result = resolve_layer_precision(4, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2], fp4_layers=None) + assert result == ["fp8", "fp8", "fp4", "fp4"] + + def test_both_enabled_no_layers_raises(self): + import pytest + + with pytest.raises(ValueError): + resolve_layer_precision(4, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None) + + def test_overlap_raises(self): + import pytest + + with pytest.raises(ValueError): + resolve_layer_precision(4, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2], fp4_layers=[2, 3]) + + +class TestMiniFormerTEPrecision: + def test_no_precision_config(self): + """MiniFormerTE works without block_precision (BF16 default).""" + mod = MiniFormerTE(dim=DIM, blocks=2).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_with_block_precision_none_list(self): + """MiniFormerTE works with all-None block_precision (explicit BF16).""" + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=[None, None]).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_block_precision_length_mismatch_raises(self): + import pytest + + with pytest.raises(ValueError): + MiniFormerTE(dim=DIM, blocks=2, block_precision=[None]) + + +class TestComponentPrecision: + """Test per-component precision overrides within FP8 blocks. + + These tests verify that te.autocast(enabled=False) is correctly applied + to keep specific sub-components (projections, gates, FFN, etc.) in BF16 + while the rest of the block runs in FP8. 
+ """ + + def test_all_components_enabled_fp8(self): + """All components in FP8 — forward pass produces valid output.""" + cp = ComponentPrecisionConfig() # all True by default + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", "fp8"], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_ffn_bf16_only(self): + """FFN in BF16, everything else in FP8.""" + cp = ComponentPrecisionConfig(ffn=False) + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", "fp8"], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_tri_proj_bf16_only(self): + """Triangular projections in BF16, everything else in FP8.""" + cp = ComponentPrecisionConfig(tri_proj=False) + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", "fp8"], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_tri_gate_bf16_only(self): + """Triangular gates in BF16, everything else in FP8.""" + cp = ComponentPrecisionConfig(tri_gate=False) + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", "fp8"], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_all_components_bf16(self): + """All components forced to BF16 within FP8 blocks.""" + cp = ComponentPrecisionConfig( + tri_proj=False, + tri_gate=False, + 
ffn=False, + struct_attn=False, + struct_ffn=False, + seq_proj=False, + dist_head=False, + ) + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", "fp8"], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_mixed_block_and_component(self): + """Block 1 in FP8 with FFN in BF16, block 2 fully in BF16.""" + cp = ComponentPrecisionConfig(ffn=False) + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", None], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + assert out.shape == (B, N, N, DIM) + + def test_gradient_flow_with_component_override(self): + """Gradients flow correctly through mixed-precision components.""" + cp = ComponentPrecisionConfig(ffn=False, tri_gate=False) + mod = MiniFormerTE(dim=DIM, blocks=2, block_precision=["fp8", "fp8"], component_precision=cp).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE, requires_grad=True) + mask = torch.ones(B, N, N, device=DEVICE) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = mod(x, mask) + out.sum().backward() + assert x.grad is not None + + def test_folding_trunk_with_dist_head_bf16(self): + """FoldingTrunkTE with dist_head forced to BF16 within FP8 blocks.""" + cp = ComponentPrecisionConfig(dist_head=False) + c_s, c_z = 512, DIM + mod = FoldingTrunkTE( + c_s=c_s, + c_z=c_z, + bins=32, + disto_bins=64, + num_layers=2, + block_precision=["fp8", "fp8"], + component_precision=cp, + ).to(DEVICE) + s_s = torch.randn(B, N, c_s, device=DEVICE) + s_z = torch.randn(B, N, N, c_z, device=DEVICE) + mask = torch.ones(B, N, device=DEVICE) + mod.eval() + with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16): + preds, sz = mod(s_s, 
s_z, mask, num_recycling=0) + assert preds.shape == (B, N, N, 64) + + def test_folding_trunk_with_seq_proj_bf16(self): + """FoldingTrunkTE with seq_proj forced to BF16 within FP8 blocks.""" + cp = ComponentPrecisionConfig(seq_proj=False) + c_s, c_z = 512, DIM + mod = FoldingTrunkTE( + c_s=c_s, + c_z=c_z, + bins=32, + disto_bins=64, + num_layers=2, + block_precision=["fp8", "fp8"], + component_precision=cp, + ).to(DEVICE) + s_s = torch.randn(B, N, c_s, device=DEVICE) + s_z = torch.randn(B, N, N, c_z, device=DEVICE) + mask = torch.ones(B, N, device=DEVICE) + mod.eval() + with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16): + preds, sz = mod(s_s, s_z, mask, num_recycling=0) + assert preds.shape == (B, N, N, 64) + + +# =========================================================================== +# te_utils Tests +# =========================================================================== + + +class TestTeUtils: + def test_te_linear_nd_2d(self): + import transformer_engine.pytorch as te + + linear = te.Linear(DIM, HIDDEN).to(DEVICE) + x = torch.randn(B * N, DIM, device=DEVICE) + out = te_linear_nd(linear, x) + assert out.shape == (B * N, HIDDEN) + + def test_te_linear_nd_3d(self): + import transformer_engine.pytorch as te + + linear = te.Linear(DIM, HIDDEN).to(DEVICE) + x = torch.randn(B, N, DIM, device=DEVICE) + out = te_linear_nd(linear, x) + assert out.shape == (B, N, HIDDEN) + + def test_te_linear_nd_4d(self): + import transformer_engine.pytorch as te + + linear = te.Linear(DIM, HIDDEN).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE) + out = te_linear_nd(linear, x) + assert out.shape == (B, N, N, HIDDEN) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/tests/test_precisions.py b/bionemo-recipes/recipes/esm2_minifold_te/tests/test_precisions.py new file mode 100644 index 0000000000..078b5e536f --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/tests/test_precisions.py @@ -0,0 +1,468 @@ +# SPDX-FileCopyrightText: Copyright (c) 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that verify actual precision of intermediate tensors during forward passes.

Checks that:
- Triangular multiplication batched GEMMs stay in FP32 by default
- tri_einsum toggle controls FP32 vs ambient dtype for triangular matmuls
- te.autocast correctly enables/disables FP8 for specific blocks
- Component precision overrides correctly keep components out of FP8
"""

import sys
from contextlib import contextmanager
from pathlib import Path

import pytest
import torch
import transformer_engine.pytorch as te


sys.path.insert(0, str(Path(__file__).parent.parent))

from transformer_engine.pytorch.quantization import FP8GlobalStateManager

from miniformer_te import MiniFormerTE, TransitionUpdateTE, TriangularUpdateTE
from model_te import FoldingTrunkTE
from quantization import ComponentPrecisionConfig
from te_utils import tri_mul_bmm


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
B, N, DIM = 2, 16, 64


@contextmanager
def _capture_bmm_dtypes(captured):
    """Temporarily monkey-patch torch.bmm to record the dtypes it sees.

    Records the first operand's dtype under both "dtype" and "input_dtype",
    and the result's dtype under "output_dtype", in the `captured` dict.
    The original torch.bmm is restored even if the forward raises.
    """
    orig_bmm = torch.bmm

    def patched_bmm(a, b):
        captured["dtype"] = captured["input_dtype"] = a.dtype
        result = orig_bmm(a, b)
        captured["output_dtype"] = result.dtype
        return result

    torch.bmm = patched_bmm
    try:
        yield captured
    finally:
        torch.bmm = orig_bmm


def _fp8_state_hook(states):
    """Return a forward hook that appends the current global FP8 autocast state to `states`."""

    def hook(module, inputs, output):
        states.append(FP8GlobalStateManager.is_fp8_enabled())

    return hook


def _register_per_block_fp8_hooks(blocks, states):
    """Register a forward hook on every block recording (block_idx, fp8_enabled) at call time."""
    for idx, block in enumerate(blocks):
        # Inner factory binds block_idx eagerly (avoids the late-binding closure pitfall).
        def make_hook(block_idx):
            def hook(module, inputs, output):
                states.append((block_idx, FP8GlobalStateManager.is_fp8_enabled()))

            return hook

        block.register_forward_hook(make_hook(idx))


class TestTriMulPrecision:
    """Verify the triangular multiplication batched GEMMs run in the expected precision."""

    @staticmethod
    def _captured_bmm_dtypes(mod, use_te_autocast=False):
        """Run one BF16-input forward of `mod` and return the dict of captured bmm dtypes.

        When `use_te_autocast` is True the forward runs under te.autocast(enabled=True),
        i.e. with FP8 enabled for the TE linears.
        """
        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        mask = torch.ones(B, N, N, device=DEVICE, dtype=torch.bfloat16)
        captured = {}
        with _capture_bmm_dtypes(captured):
            if use_te_autocast:
                with te.autocast(enabled=True):
                    mod(x, mask)
            else:
                mod(x, mask)
        return captured

    def test_bmm_intermediates_are_fp32_by_default(self):
        """Even with BF16 input, the bmm should compute in FP32 via .float() when tri_einsum="off"."""
        captured = self._captured_bmm_dtypes(TriangularUpdateTE(dim=DIM).to(DEVICE))
        assert captured["input_dtype"] == torch.float32, (
            f"BMM input should be FP32 but got {captured['input_dtype']}"
        )
        assert captured["output_dtype"] == torch.float32, (
            f"BMM output should be FP32 but got {captured['output_dtype']}"
        )

    def test_bmm_fp32_with_fp8_block(self):
        """BMM stays FP32 even when block is wrapped in te.autocast(enabled=True)."""
        captured = self._captured_bmm_dtypes(TriangularUpdateTE(dim=DIM).to(DEVICE), use_te_autocast=True)
        assert captured["dtype"] == torch.float32, (
            f"BMM should be FP32 inside te.autocast but got {captured['dtype']}"
        )

    def test_bmm_bf16_when_tri_einsum_bf16(self):
        """BMM runs in BF16 when tri_einsum="bf16" (ambient dtype, no .float() cast)."""
        cp = ComponentPrecisionConfig(tri_einsum="bf16")
        captured = self._captured_bmm_dtypes(TriangularUpdateTE(dim=DIM, component_precision=cp).to(DEVICE))
        assert captured["dtype"] == torch.bfloat16, (
            f"BMM should be BF16 when tri_einsum='bf16' but got {captured['dtype']}"
        )

    def test_bmm_bf16_backward_compat_bool_true(self):
        """Bool True normalizes to 'bf16' via __post_init__."""
        cp = ComponentPrecisionConfig(tri_einsum=True)
        assert cp.tri_einsum == "bf16"

    def test_bmm_bf16_backward_compat_bool_false(self):
        """Bool False normalizes to 'off' via __post_init__."""
        cp = ComponentPrecisionConfig(tri_einsum=False)
        assert cp.tri_einsum == "off"

    def test_bmm_fp32_when_tri_einsum_off(self):
        """BMM stays FP32 when tri_einsum="off" (explicit .float() cast)."""
        cp = ComponentPrecisionConfig(tri_einsum="off")
        captured = self._captured_bmm_dtypes(TriangularUpdateTE(dim=DIM, component_precision=cp).to(DEVICE))
        assert captured["dtype"] == torch.float32, (
            f"BMM should be FP32 when tri_einsum='off' but got {captured['dtype']}"
        )

    def test_bmm_fp32_with_no_component_precision(self):
        """Without ComponentPrecisionConfig, BMM defaults to FP32."""
        captured = self._captured_bmm_dtypes(TriangularUpdateTE(dim=DIM).to(DEVICE))
        assert captured["dtype"] == torch.float32, (
            f"BMM should default to FP32 without component_precision but got {captured['dtype']}"
        )


class TestFP8StateInBlocks:
    """Verify te.autocast correctly enables/disables FP8 per block."""

    @staticmethod
    def _run_and_collect(block_precision):
        """Build a MiniFormerTE with one block per entry of `block_precision`, run a
        single BF16 forward, and return the recorded (block_idx, fp8_enabled) pairs.
        """
        states = []
        mod = MiniFormerTE(dim=DIM, blocks=len(block_precision), block_precision=block_precision).to(DEVICE)
        _register_per_block_fp8_hooks(mod.blocks, states)
        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        mask = torch.ones(B, N, N, device=DEVICE, dtype=torch.bfloat16)
        mod(x, mask)
        return states

    def test_fp8_enabled_in_fp8_block(self):
        """FP8GlobalStateManager.is_fp8_enabled() should be True inside an FP8 block's forward."""
        fp8_states = self._run_and_collect(["fp8", None])
        assert fp8_states[0] == (0, True), f"Block 0 should have FP8 enabled but got {fp8_states[0]}"
        assert fp8_states[1] == (1, False), f"Block 1 should have FP8 disabled but got {fp8_states[1]}"

    def test_fp8_disabled_in_bf16_block(self):
        """All blocks BF16 — FP8 should never be enabled."""
        for block_idx, fp8_enabled in self._run_and_collect([None, None]):
            assert not fp8_enabled, f"Block {block_idx} should have FP8 disabled but it's enabled"

    def test_all_blocks_fp8(self):
        """All blocks FP8 — FP8 should be enabled in every block."""
        for block_idx, fp8_enabled in self._run_and_collect(["fp8", "fp8", "fp8"]):
            assert fp8_enabled, f"Block {block_idx} should have FP8 enabled but it's disabled"

    def test_mixed_precision_pattern(self):
        """Alternating FP8/BF16 blocks — verify correct per-block FP8 state."""
        fp8_states = self._run_and_collect(["fp8", None, "fp8", None])
        expected = [(0, True), (1, False), (2, True), (3, False)]
        for (idx, actual), (_, expect) in zip(fp8_states, expected):
            assert actual == expect, f"Block {idx}: expected FP8={expect}, got FP8={actual}"


class TestComponentPrecisionOverrides:
    """Verify component-level precision overrides actually change FP8 state for sub-operations."""

    def test_ffn_excluded_from_fp8(self):
        """FFN runs with FP8 disabled even when block is in FP8."""
        cp = ComponentPrecisionConfig(ffn=False)
        mod = TransitionUpdateTE(dim=DIM, hidden=DIM * 4, component_precision=cp).to(DEVICE)

        fp8_in_ffn = []
        mod.fc1.register_forward_hook(_fp8_state_hook(fp8_in_ffn))

        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        with te.autocast(enabled=True):
            mod(x)

        assert len(fp8_in_ffn) == 1
        assert not fp8_in_ffn[0], "FFN fc1 should have FP8 disabled when component_precision.ffn=False"

    def test_ffn_included_in_fp8(self):
        """FFN runs with FP8 enabled when component_precision.ffn=True."""
        cp = ComponentPrecisionConfig(ffn=True)
        mod = TransitionUpdateTE(dim=DIM, hidden=DIM * 4, component_precision=cp).to(DEVICE)

        fp8_in_ffn = []
        mod.fc1.register_forward_hook(_fp8_state_hook(fp8_in_ffn))

        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        with te.autocast(enabled=True):
            mod(x)

        assert len(fp8_in_ffn) == 1
        assert fp8_in_ffn[0], "FFN fc1 should have FP8 enabled when component_precision.ffn=True"

    def test_tri_proj_excluded_from_fp8(self):
        """Triangular projections run with FP8 disabled when tri_proj=False."""
        cp = ComponentPrecisionConfig(tri_proj=False, tri_gate=True)
        mod = TriangularUpdateTE(dim=DIM, component_precision=cp).to(DEVICE)

        fp8_states = {"proj": [], "gate": []}
        mod.pi.register_forward_hook(_fp8_state_hook(fp8_states["proj"]))
        mod.gi.register_forward_hook(_fp8_state_hook(fp8_states["gate"]))

        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        mask = torch.ones(B, N, N, device=DEVICE, dtype=torch.bfloat16)
        with te.autocast(enabled=True):
            mod(x, mask)

        assert not any(fp8_states["proj"]), "tri_proj should have FP8 disabled"
        assert all(fp8_states["gate"]), "tri_gate should have FP8 enabled"

    def test_tri_gate_excluded_from_fp8(self):
        """Triangular gates run with FP8 disabled when tri_gate=False."""
        cp = ComponentPrecisionConfig(tri_proj=True, tri_gate=False)
        mod = TriangularUpdateTE(dim=DIM, component_precision=cp).to(DEVICE)

        fp8_states = {"proj": [], "gate": []}
        mod.pi.register_forward_hook(_fp8_state_hook(fp8_states["proj"]))
        mod.gi.register_forward_hook(_fp8_state_hook(fp8_states["gate"]))

        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        mask = torch.ones(B, N, N, device=DEVICE, dtype=torch.bfloat16)
        with te.autocast(enabled=True):
            mod(x, mask)

        assert all(fp8_states["proj"]), "tri_proj should have FP8 enabled"
        assert not any(fp8_states["gate"]), "tri_gate should have FP8 disabled"

    def test_dist_head_excluded_from_fp8(self):
        """Distogram head runs with FP8 disabled when dist_head=False."""
        cp = ComponentPrecisionConfig(dist_head=False)
        c_s, c_z = 512, DIM
        mod = FoldingTrunkTE(
            c_s=c_s,
            c_z=c_z,
            bins=32,
            disto_bins=64,
            num_layers=2,
            block_precision=["fp8", "fp8"],
            component_precision=cp,
        ).to(DEVICE)

        fp8_in_dist = []
        mod.fc_out_1.register_forward_hook(_fp8_state_hook(fp8_in_dist))

        s_s = torch.randn(B, N, c_s, device=DEVICE, dtype=torch.bfloat16)
        s_z = torch.randn(B, N, N, c_z, device=DEVICE, dtype=torch.bfloat16)
        mask = torch.ones(B, N, device=DEVICE, dtype=torch.bfloat16)

        mod.eval()
        with torch.no_grad():
            mod(s_s, s_z, mask, num_recycling=0)

        assert len(fp8_in_dist) >= 1
        assert not any(fp8_in_dist), "dist_head fc_out_1 should have FP8 disabled"

    def test_no_component_precision_all_fp8(self):
        """Without component_precision, all te.Linear layers in FP8 block run in FP8."""
        mod = TransitionUpdateTE(dim=DIM, hidden=DIM * 4).to(DEVICE)

        fp8_in_ffn = []
        mod.fc1.register_forward_hook(_fp8_state_hook(fp8_in_ffn))

        x = torch.randn(B, N, N, DIM, device=DEVICE, dtype=torch.bfloat16)
        with te.autocast(enabled=True):
            mod(x)

        assert len(fp8_in_ffn) == 1
        assert fp8_in_ffn[0], "Without component_precision, FFN should run in FP8"


class TestTriMulBmmEquivalence:
    """Verify tri_mul_bmm produces identical results to the original einsum."""

    def test_outgoing_einsum_equivalence(self):
        """tri_mul_bmm(a, b, k_dim=2) == torch.einsum('bikd,bjkd->bijd', a, b)."""
        torch.manual_seed(42)
        a = torch.randn(B, N, N, DIM // 4, device=DEVICE, dtype=torch.float32)
        b = torch.randn(B, N, N, DIM // 4, device=DEVICE, dtype=torch.float32)

        expected = torch.einsum("bikd,bjkd->bijd", a, b)
        actual = tri_mul_bmm(a, b, k_dim=2)

        assert torch.allclose(actual, expected, atol=1e-5), (
            f"k_dim=2 mismatch: max diff {(actual - expected).abs().max()}"
        )

    def test_incoming_einsum_equivalence(self):
        """tri_mul_bmm(a, b, k_dim=1) == torch.einsum('bkid,bkjd->bijd', a, b)."""
        torch.manual_seed(42)
        a = torch.randn(B, N, N, DIM // 4, device=DEVICE, dtype=torch.float32)
        b = torch.randn(B, N, N, DIM // 4, device=DEVICE, dtype=torch.float32)

        expected = torch.einsum("bkid,bkjd->bijd", a, b)
        actual = tri_mul_bmm(a, b, k_dim=1)

        assert torch.allclose(actual, expected, atol=1e-5), (
            f"k_dim=1 mismatch: max diff {(actual - expected).abs().max()}"
        )

    def test_bf16_close_to_fp32(self):
        """BF16 mode produces results close to FP32 reference."""
        if not torch.cuda.is_available():
            # Was a bare `return`, which silently reported a pass; skip is explicit.
            pytest.skip("BF16 tri_mul_bmm comparison requires CUDA")
        torch.manual_seed(42)
        a = torch.randn(B, N, N, DIM // 4, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(B, N, N, DIM // 4, device="cuda", dtype=torch.bfloat16)

        ref = tri_mul_bmm(a.float(), b.float(), k_dim=2).to(torch.bfloat16)
        bf16_result = tri_mul_bmm(a, b, k_dim=2, mode="bf16")

        assert torch.allclose(bf16_result, ref, atol=0.5, rtol=0.05), (
            f"BF16 vs FP32 mismatch: max diff {(bf16_result - ref).abs().max()}"
        )

# --- end of file; next in patch: tests/test_quantization.py ---
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for block-wise quantization and quant stats logging for MiniFold TE.

Adapted from esm2_native_te/tests/test_quantization.py with regex patterns
updated for MiniFold's module hierarchy (fold.miniformer.blocks.N.{triangular,transition}).
"""

import re
import sys
from pathlib import Path

import pytest
import yaml


# Make the recipe root importable when pytest is invoked from tests/.
sys.path.insert(0, str(Path(__file__).parent.parent))

from quantization import BufferedQuantLogger, generate_layer_regex, resolve_layer_precision, update_quant_stats_config


# -- resolve_layer_precision --
# Convention throughout: user-facing layer lists are 1-indexed; the returned
# per-block precision list and module names are 0-indexed.


def test_fp8_enabled_no_layers_defaults_all():
    # fp8 enabled with no explicit layer list -> every block gets FP8.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None
    )
    assert result == ["fp8"] * 6


def test_fp4_enabled_no_layers_defaults_all():
    # fp4 enabled with no explicit layer list -> every block gets FP4.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None
    )
    assert result == ["fp4"] * 6


def test_fp8_explicit_layers():
    # 1-indexed [1, 3, 5] -> positions 0, 2, 4 of the result.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None
    )
    assert result == ["fp8", None, "fp8", None, "fp8", None]


def test_fp4_explicit_layers():
    # 1-indexed [2, 4, 6] -> positions 1, 3, 5 of the result.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6]
    )
    assert result == [None, "fp4", None, "fp4", None, "fp4"]


def test_mixed_fp8_fp4_explicit():
    # Disjoint explicit lists may mix both precisions; unlisted blocks stay None.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5]
    )
    assert result == ["fp8", "fp4", "fp8", "fp8", "fp4", None]


def test_both_enabled_no_layers_raises():
    # Ambiguous: both enabled but no layer assignment at all.
    with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"):
        resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None)


def test_overlapping_layers_raises():
    # A block cannot be both FP8 and FP4 (layer 3 appears in both lists).
    with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"):
        resolve_layer_precision(
            num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5]
        )


def test_disabled_ignores_layers():
    # Layer lists are irrelevant when the corresponding precision is disabled.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6]
    )
    assert result == [None] * 6


def test_both_disabled():
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None
    )
    assert result == [None] * 6


def test_48_block_model_defaults_all():
    # Production-sized model: default-all still covers every block.
    result = resolve_layer_precision(
        num_layers=48, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None
    )
    assert result == ["fp8"] * 48


def test_fp8_enabled_empty_list():
    # An explicit empty list means "no blocks", not "all blocks".
    result = resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None)
    assert result == [None] * 6


def test_both_enabled_fp8_specified_fp4_defaults_to_remaining():
    # When only fp8_layers is given, FP4 fills every remaining block.
    result = resolve_layer_precision(
        num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None
    )
    assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]


def test_returns_correct_length():
    # Result length always equals num_layers, for small and large models.
    for n in [1, 8, 48]:
        result = resolve_layer_precision(
            num_layers=n, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None
        )
        assert len(result) == n


# -- generate_layer_regex (MiniFold-specific patterns) --


def test_single_block():
    """Single block (1-indexed=3, 0-indexed=2) should match block 2 in module names."""
    regex = generate_layer_regex([3])
    assert re.search(regex, "fold.miniformer.blocks.2.triangular.pi")
    assert re.search(regex, "fold.miniformer.blocks.2.transition.fc1")
    assert not re.search(regex, "fold.miniformer.blocks.1.triangular.pi")
    assert not re.search(regex, "fold.miniformer.blocks.3.triangular.pi")


def test_multiple_blocks():
    """Multiple blocks should match any of them (converted to 0-indexed)."""
    regex = generate_layer_regex([1, 2, 3])
    # 1-indexed [1,2,3] -> 0-indexed [0,1,2]
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.pi")
    assert re.search(regex, "fold.miniformer.blocks.1.transition.fc1")
    assert re.search(regex, "fold.miniformer.blocks.2.triangular.go")
    assert not re.search(regex, "fold.miniformer.blocks.3.triangular.pi")


def test_matches_correct_sublayers():
    """Regex should match pi, gi, po, go, fc1, fc2."""
    regex = generate_layer_regex([1])
    # Block 0 (1-indexed=1)
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.pi")
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.gi")
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.po")
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.go")
    assert re.search(regex, "fold.miniformer.blocks.0.transition.fc1")
    assert re.search(regex, "fold.miniformer.blocks.0.transition.fc2")
    # Should not match unrelated names (norm layers are never quantized)
    assert not re.search(regex, "fold.miniformer.blocks.0.triangular.input_norm")
    assert not re.search(regex, "fold.miniformer.blocks.0.transition.norm")


def test_none_returns_disabled_pattern():
    # None -> sentinel pattern that can never match a real module name.
    regex = generate_layer_regex(None)
    assert "DISABLED" in regex
    assert not re.search(regex, "fold.miniformer.blocks.0.triangular.pi")


def test_empty_list_returns_disabled_pattern():
    # Empty list behaves like None: nothing selected.
    regex = generate_layer_regex([])
    assert "DISABLED" in regex


def test_1indexed_to_0indexed_conversion():
    """User specifies 1-indexed, but module names are 0-indexed."""
    regex = generate_layer_regex([1])
    # Should match block 0 (0-indexed)
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.pi")
    # Should NOT match block 1 (that would be user's block 2)
    assert not re.search(regex, "fold.miniformer.blocks.1.triangular.pi")


def test_large_block_numbers():
    """High block numbers (e.g., 48-block model) should convert correctly."""
    regex = generate_layer_regex([47, 48])
    # 1-indexed [47,48] -> 0-indexed [46,47]
    assert re.search(regex, "fold.miniformer.blocks.46.transition.fc2")
    assert re.search(regex, "fold.miniformer.blocks.47.triangular.gi")
    assert not re.search(regex, "fold.miniformer.blocks.45.transition.fc1")


# -- update_quant_stats_config --


@pytest.fixture
def fp8_only_config(tmp_path):
    """Create an FP8-only stats config file."""
    # Minimal shape of a TE debug-stats YAML: one section with a regex
    # placeholder that update_quant_stats_config is expected to fill in.
    config = {
        "example_fp8_tensor_stat_collection": {
            "enabled": True,
            "layers": {
                "layer_name_regex_pattern": "PLACEHOLDER",
            },
            "transformer_engine": {
                "LogFp8TensorStats": {
                    "enabled": True,
                    "tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}],
                }
            },
        }
    }
    config_path = tmp_path / "fp8_stats.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    return str(config_path)


@pytest.fixture
def fp4_fp8_config(tmp_path):
    """Create a combined FP4+FP8 stats config file."""
    config = {
        "example_fp4_tensor_stat_collection": {
            "enabled": True,
            "layers": {"layer_name_regex_pattern": "PLACEHOLDER"},
            "transformer_engine": {"LogNvfp4TensorStats": {"enabled": True}},
        },
        "example_fp8_tensor_stat_collection": {
            "enabled": True,
            "layers": {"layer_name_regex_pattern": "PLACEHOLDER"},
            "transformer_engine": {"LogFp8TensorStats": {"enabled": True}},
        },
    }
    config_path = tmp_path / "fp4_fp8_stats.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    return str(config_path)


def test_fp8_layers_updates_regex(fp8_only_config):
    """FP8 block list should update the regex in the output config."""
    output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3])
    with open(output_path) as f:
        result = yaml.safe_load(f)
    regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
    # 1-indexed [1,2,3] -> 0-indexed [0,1,2]
    assert re.search(regex, "fold.miniformer.blocks.0.triangular.pi")
    assert re.search(regex, "fold.miniformer.blocks.2.transition.fc2")
    assert not re.search(regex, "fold.miniformer.blocks.3.triangular.pi")


def test_none_layers_disables_matching(fp8_only_config):
    # No layers selected -> the written regex is the non-matching sentinel.
    output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=None)
    with open(output_path) as f:
        result = yaml.safe_load(f)
    regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
    assert "DISABLED" in regex


def test_fp4_and_fp8_both_updated(fp4_fp8_config):
    output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6])
    with open(output_path) as f:
        result = yaml.safe_load(f)

    # FP4 section should have regex for blocks 1-3 (0-indexed 0-2)
    fp4_regex = result["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
    assert re.search(fp4_regex, "fold.miniformer.blocks.0.transition.fc1")
    assert re.search(fp4_regex, "fold.miniformer.blocks.2.triangular.pi")
    assert not re.search(fp4_regex, "fold.miniformer.blocks.3.triangular.pi")

    # FP8 section should have regex for blocks 4-6 (0-indexed 3-5)
    fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
    assert re.search(fp8_regex, "fold.miniformer.blocks.4.triangular.pi")
    assert not re.search(fp8_regex, "fold.miniformer.blocks.1.triangular.pi")


def test_original_file_not_modified(fp8_only_config):
    # The updater must write a NEW file and leave the input untouched.
    with open(fp8_only_config) as f:
        original_content = f.read()

    output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2])

    assert output_path != fp8_only_config
    with open(fp8_only_config) as f:
        assert f.read() == original_content


def test_preserves_other_config_fields(fp8_only_config):
    # Only the regex is rewritten; unrelated keys survive round-tripping.
    output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1])
    with open(output_path) as f:
        result = yaml.safe_load(f)
    assert result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True


def test_missing_section_is_skipped(fp8_only_config):
    # Requesting FP4 layers against a config without an FP4 section is a no-op
    # for that section, while the present FP8 section is still updated.
    output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4])
    with open(output_path) as f:
        result = yaml.safe_load(f)
    assert "example_fp4_tensor_stat_collection" not in result
    regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
    # 1-indexed [3,4] -> 0-indexed [2,3]
    assert re.search(regex, "fold.miniformer.blocks.2.triangular.pi")


def test_with_real_fp8_config():
    """Test with the actual fp8_debugging_stats.yaml file."""
    config_path = Path(__file__).parent.parent / "fp8_debugging_stats.yaml"
    if not config_path.exists():
        pytest.skip("fp8_debugging_stats.yaml not found")

    output_path = update_quant_stats_config(config_file=str(config_path), fp4_layers=None, fp8_layers=[1, 4, 8])
    with open(output_path) as f:
        result = yaml.safe_load(f)

    fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
    # 1-indexed [1,4,8] -> 0-indexed [0,3,7]
    assert re.search(fp8_regex, "fold.miniformer.blocks.0.transition.fc1")
    assert re.search(fp8_regex, "fold.miniformer.blocks.3.triangular.go")
    assert re.search(fp8_regex, "fold.miniformer.blocks.7.triangular.pi")
    assert not re.search(fp8_regex, "fold.miniformer.blocks.1.transition.fc1")


# -- BufferedQuantLogger --


class TestBufferedQuantLogger:
    # NOTE(review): these tests poke the private `_underflow_buffer` attribute;
    # they will need updating if the logger's internal storage changes.

    def test_captures_underflow_stats(self):
        # Only "*_gradient_underflows%" metrics are buffered, keyed by metric
        # name, as (iteration, value) pairs.
        logger = BufferedQuantLogger()
        logger.log_scalar("model.fold.miniformer.blocks.0.transition.fc1_gradient_underflows%", 0.5, 100)
        logger.log_scalar("model.fold.miniformer.blocks.0.transition.fc2_gradient_underflows%", 1.2, 100)
        assert len(logger._underflow_buffer) == 2
        assert logger._underflow_buffer["model.fold.miniformer.blocks.0.transition.fc1_gradient_underflows%"] == [
            (100, 0.5)
        ]

    def test_ignores_non_underflow_stats(self):
        # Scale/MSE stats are not underflow metrics and must not be buffered.
        logger = BufferedQuantLogger()
        logger.log_scalar("model.fold.miniformer.blocks.0.transition.fc1_activation_scale_inv_min", 0.01, 100)
        logger.log_scalar("model.fold.miniformer.blocks.0.transition.fc1_weight_mse", 0.001, 100)
        assert len(logger._underflow_buffer) == 0

    def test_accumulates_across_iterations(self):
        # Repeated logs of the same metric append rather than overwrite.
        logger = BufferedQuantLogger()
        metric = "model.fold.miniformer.blocks.1.transition.fc1_gradient_underflows%"
        logger.log_scalar(metric, 0.5, 100)
        logger.log_scalar(metric, 0.3, 200)
        logger.log_scalar(metric, 0.1, 300)
        assert len(logger._underflow_buffer[metric]) == 3

    def test_generate_heatmap_empty_returns_none(self):
        # Nothing buffered -> no figure to draw.
        logger = BufferedQuantLogger()
        assert logger.generate_heatmap() is None

    def test_generate_heatmap_returns_figure(self):
        import matplotlib.figure

        logger = BufferedQuantLogger()
        # Populate with synthetic MiniFold metrics
        for block in range(3):
            for sublayer in ["fc1", "fc2"]:
                metric = f"model.fold.miniformer.blocks.{block}.transition.{sublayer}_gradient_underflows%"
                for step in range(0, 50, 10):
                    logger.log_scalar(metric, float(block * 0.5 + step * 0.01), step)

        fig = logger.generate_heatmap()
        assert fig is not None
        assert isinstance(fig, matplotlib.figure.Figure)
        import matplotlib.pyplot as plt

        # Close to avoid accumulating open figures across tests.
        plt.close(fig)

    def test_generate_heatmap_correct_labels(self):
        import matplotlib.pyplot as plt

        # Y-axis labels are "B<block> <sublayer>" derived from the metric names.
        logger = BufferedQuantLogger()
        logger.log_scalar("model.fold.miniformer.blocks.0.transition.fc1_gradient_underflows%", 0.5, 0)
        logger.log_scalar("model.fold.miniformer.blocks.0.transition.fc2_gradient_underflows%", 0.3, 0)
        logger.log_scalar("model.fold.miniformer.blocks.1.triangular.pi_gradient_underflows%", 1.0, 0)

        fig = logger.generate_heatmap()
        assert fig is not None
        ax = fig.axes[0]
        y_labels = [t.get_text() for t in ax.get_yticklabels()]
        assert "B0 fc1" in y_labels
        assert "B0 fc2" in y_labels
        assert "B1 pi" in y_labels
        plt.close(fig)

    def test_minifold_layer_name_parsing(self):
        """Verify regex extracts block/module/sublayer from metric names."""
        from quantization import _MINIFOLD_UNDERFLOW_PATTERN

        # Groups: (1) block index, (2) module kind, (3) sublayer name.
        match = _MINIFOLD_UNDERFLOW_PATTERN.search("model.fold.miniformer.blocks.5.triangular.gi_gradient_underflows%")
        assert match is not None
        assert match.group(1) == "5"
        assert match.group(2) == "triangular"
        assert match.group(3) == "gi"

        match = _MINIFOLD_UNDERFLOW_PATTERN.search(
            "model.fold.miniformer.blocks.47.transition.fc2_gradient_underflows%"
        )
        assert match is not None
        assert match.group(1) == "47"
        assert match.group(2) == "transition"
        assert match.group(3) == "fc2"

    def test_no_match_for_non_minifold_pattern(self):
        from quantization import _MINIFOLD_UNDERFLOW_PATTERN

        # Metric names from other model hierarchies must not be parsed.
        match = _MINIFOLD_UNDERFLOW_PATTERN.search("model.encoder.layers.3.self_attention.proj_gradient_underflows%")
        assert match is None

# --- end of file; next in patch: tests/test_te_equivalence.py ---
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Numerical equivalence tests between original MiniFold modules and TE versions. + +Each test: +1. Creates original and TE modules with matching dimensions +2. Copies weights from original -> TE +3. Runs both on the same random input (fixed seed) +4. Asserts outputs match within tolerance + +Run with: pytest tests/test_te_equivalence.py -v +""" + +import sys +from pathlib import Path + +import pytest +import torch + + +# Add recipe root to path for TE modules +sys.path.insert(0, str(Path(__file__).parent.parent)) + +# Add minifold root to path for original modules +sys.path.insert(0, "/workspaces/minifold") + +# Original modules +from minifold.model.heads import PerResidueLDDTCaPredictor +from minifold.model.miniformer import Block, MiniFormer, TransitionUpdate, TriangularUpdate +from minifold.model.model import FoldingTrunk, PairToSequence, SequenceToPair +from minifold.model.structure import MLP, AngleResnet, AngleResnetBlock, Attention + +# TE modules (recipe-local) +from heads_te import PerResidueLDDTCaPredictorTE +from miniformer_te import BlockTE, MiniFormerTE, TransitionUpdateTE, TriangularUpdateTE +from model_te import FoldingTrunkTE, PairToSequenceTE, SequenceToPairTE +from structure_te import MLPTE, AngleResnetBlockTE, AngleResnetTE, AttentionTE + +# Weight copy (recipe-local) +from weight_copy import ( + copy_angle_resnet_block_to_te, + copy_angle_resnet_to_te, 
+ copy_attention_to_te, + copy_block_to_te, + copy_folding_trunk_to_te, + copy_miniformer_to_te, + copy_mlp_to_te, + copy_pair_to_seq_to_te, + copy_plddt_to_te, + copy_seq_to_pair_to_te, + copy_transition_update_from_te, + copy_transition_update_to_te, + copy_triangular_update_to_te, +) + + +DEVICE = "cuda" +ATOL = 1e-5 +RTOL = 1e-5 +SEED = 42 +INPUT_SEED = 123 +DIM = 128 +HIDDEN = 512 +B = 2 +N = 16 + + +@pytest.fixture(autouse=True) +def set_seed(): + """Set seed before each test for reproducibility.""" + torch.manual_seed(SEED) + torch.cuda.manual_seed_all(SEED) + + +# =========================================================================== +# TransitionUpdate +# =========================================================================== + + +class TestTransitionUpdate: + def test_equivalence(self): + orig = TransitionUpdate(dim=DIM, hidden=HIDDEN, kernels=False).to(DEVICE) + te_mod = TransitionUpdateTE(dim=DIM, hidden=HIDDEN).to(DEVICE) + copy_transition_update_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, N, DIM, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x) + out_te = te_mod(x) + + assert out_orig.shape == out_te.shape + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + def test_weight_roundtrip(self): + orig = TransitionUpdate(dim=DIM, hidden=HIDDEN, kernels=False).to(DEVICE) + te_mod = TransitionUpdateTE(dim=DIM, hidden=HIDDEN).to(DEVICE) + orig_copy = TransitionUpdate(dim=DIM, hidden=HIDDEN, kernels=False).to(DEVICE) + + copy_transition_update_to_te(orig, te_mod) + copy_transition_update_from_te(te_mod, orig_copy) + + for (n1, p1), (n2, p2) in zip(orig.named_parameters(), orig_copy.named_parameters()): + assert torch.equal(p1, p2), f"Mismatch in {n1}" + + def test_gradient_flow(self): + te_mod = TransitionUpdateTE(dim=DIM, hidden=HIDDEN).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE, requires_grad=True) + out = 
te_mod(x) + out.sum().backward() + for name, param in te_mod.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + +# =========================================================================== +# TriangularUpdate +# =========================================================================== + + +class TestTriangularUpdate: + def test_equivalence(self): + orig = TriangularUpdate(dim=DIM, kernels=False).to(DEVICE) + te_mod = TriangularUpdateTE(dim=DIM).to(DEVICE) + copy_triangular_update_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x, mask) + out_te = te_mod(x, mask) + + assert out_orig.shape == out_te.shape + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + def test_with_mask(self): + orig = TriangularUpdate(dim=DIM, kernels=False).to(DEVICE) + te_mod = TriangularUpdateTE(dim=DIM).to(DEVICE) + copy_triangular_update_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + mask[:, :, N // 2 :] = 0 # mask out half + + with torch.no_grad(): + out_orig = orig(x, mask) + out_te = te_mod(x, mask) + + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + def test_gradient_flow(self): + te_mod = TriangularUpdateTE(dim=DIM).to(DEVICE) + x = torch.randn(B, N, N, DIM, device=DEVICE, requires_grad=True) + mask = torch.ones(B, N, N, device=DEVICE) + out = te_mod(x, mask) + out.sum().backward() + for name, param in te_mod.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + +# =========================================================================== +# Block +# =========================================================================== + 
+ +class TestBlock: + def test_equivalence(self): + orig = Block(dim=DIM, kernels=False).to(DEVICE) + te_mod = BlockTE(dim=DIM).to(DEVICE) + copy_block_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x, mask) + out_te = te_mod(x, mask) + + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# =========================================================================== +# MiniFormer +# =========================================================================== + + +class TestMiniFormer: + def test_equivalence(self): + num_blocks = 2 + orig = MiniFormer(dim=DIM, blocks=num_blocks, kernels=False).to(DEVICE) + te_mod = MiniFormerTE(dim=DIM, blocks=num_blocks).to(DEVICE) + copy_miniformer_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, N, DIM, device=DEVICE) + mask = torch.ones(B, N, N, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x, mask) + out_te = te_mod(x, mask) + + assert torch.allclose(out_orig, out_te, atol=1e-4, rtol=1e-4), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# =========================================================================== +# SequenceToPair +# =========================================================================== + + +class TestSequenceToPair: + def test_equivalence(self): + seq_dim = 1024 + inner = 64 + pair_dim = DIM + orig = SequenceToPair(seq_dim, inner, pair_dim).to(DEVICE) + te_mod = SequenceToPairTE(seq_dim, inner, pair_dim).to(DEVICE) + copy_seq_to_pair_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, seq_dim, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x) + out_te = te_mod(x) + + assert out_orig.shape == out_te.shape == (B, N, N, pair_dim) + # Slightly relaxed tolerance: outer product + diff accumulate 
FP32 rounding + assert torch.allclose(out_orig, out_te, atol=5e-4, rtol=1e-4), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# =========================================================================== +# PairToSequence +# =========================================================================== + + +class TestPairToSequence: + def test_equivalence(self): + c_z, c_s = DIM, 1024 + orig = PairToSequence(c_z=c_z, c_s=c_s).to(DEVICE) + te_mod = PairToSequenceTE(c_z=c_z, c_s=c_s).to(DEVICE) + copy_pair_to_seq_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + s_z = torch.randn(B, N, N, c_z, device=DEVICE) + s_s = torch.randn(B, N, c_s, device=DEVICE) + pair_mask = torch.ones(B, N, N, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(s_z, s_s, pair_mask) + out_te = te_mod(s_z, s_s, pair_mask) + + assert out_orig.shape == out_te.shape == (B, N, c_s) + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# =========================================================================== +# FoldingTrunk +# =========================================================================== + + +class TestFoldingTrunk: + def test_equivalence(self): + c_s, c_z = 1024, DIM + num_blocks = 2 + orig = FoldingTrunk( + c_s=c_s, + c_z=c_z, + bins=32, + disto_bins=64, + num_layers=num_blocks, + kernels=False, + ).to(DEVICE) + te_mod = FoldingTrunkTE( + c_s=c_s, + c_z=c_z, + bins=32, + disto_bins=64, + num_layers=num_blocks, + ).to(DEVICE) + copy_folding_trunk_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + s_s = torch.randn(B, N, c_s, device=DEVICE) + s_z = torch.randn(B, N, N, c_z, device=DEVICE) + mask = torch.ones(B, N, device=DEVICE) + + orig.eval() + te_mod.eval() + + with torch.no_grad(): + preds_orig, sz_orig = orig(s_s, s_z, mask, num_recycling=0) + preds_te, sz_te = te_mod(s_s, s_z, mask, num_recycling=0) + + assert preds_orig.shape == preds_te.shape + assert 
sz_orig.shape == sz_te.shape + assert torch.allclose(preds_orig, preds_te, atol=1e-4, rtol=1e-4), ( + f"Preds max diff: {(preds_orig - preds_te).abs().max().item()}" + ) + assert torch.allclose(sz_orig, sz_te, atol=1e-4, rtol=1e-4), ( + f"s_z max diff: {(sz_orig - sz_te).abs().max().item()}" + ) + + +# =========================================================================== +# Attention (StructureModule) +# =========================================================================== + + +class TestAttention: + def test_equivalence(self): + dim, num_heads, head_width = 1024, 16, 64 + orig = Attention(dim, num_heads, head_width).to(DEVICE) + te_mod = AttentionTE(dim, num_heads, head_width).to(DEVICE) + copy_attention_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, dim, device=DEVICE) + bias = torch.randn(B, num_heads, N, N, device=DEVICE) + mask = torch.ones(B, N, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x, bias, mask) + out_te = te_mod(x, bias, mask) + + assert out_orig.shape == out_te.shape + # Slightly relaxed: te.Linear kernel ordering differs from nn.Linear + assert torch.allclose(out_orig, out_te, atol=5e-5, rtol=5e-5), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# =========================================================================== +# MLP (StructureModule) +# =========================================================================== + + +class TestMLP: + def test_equivalence(self): + in_dim, out_dim = 1024, 1024 + orig = MLP(in_dim, out_dim).to(DEVICE) + te_mod = MLPTE(in_dim, out_dim).to(DEVICE) + copy_mlp_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, in_dim, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x) + out_te = te_mod(x) + + assert out_orig.shape == out_te.shape + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# 
=========================================================================== +# AngleResnetBlock +# =========================================================================== + + +class TestAngleResnetBlock: + def test_equivalence(self): + dim = DIM + orig = AngleResnetBlock(dim).to(DEVICE) + te_mod = AngleResnetBlockTE(dim).to(DEVICE) + copy_angle_resnet_block_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, dim, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x) + out_te = te_mod(x) + + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) + + +# =========================================================================== +# AngleResnet +# =========================================================================== + + +class TestAngleResnet: + def test_equivalence(self): + c_in, c_hidden = 1024, DIM + orig = AngleResnet(c_in, c_hidden, no_blocks=2, no_angles=7, epsilon=1e-5).to(DEVICE) + te_mod = AngleResnetTE(c_in, c_hidden, no_blocks=2, no_angles=7, epsilon=1e-5).to(DEVICE) + copy_angle_resnet_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + s = torch.randn(B, N, c_in, device=DEVICE) + s_init = torch.randn(B, N, c_in, device=DEVICE) + + with torch.no_grad(): + unnorm_orig, norm_orig = orig(s, s_init) + unnorm_te, norm_te = te_mod(s, s_init) + + assert torch.allclose(unnorm_orig, unnorm_te, atol=ATOL, rtol=RTOL), ( + f"Unnorm max diff: {(unnorm_orig - unnorm_te).abs().max().item()}" + ) + assert torch.allclose(norm_orig, norm_te, atol=ATOL, rtol=RTOL), ( + f"Norm max diff: {(norm_orig - norm_te).abs().max().item()}" + ) + + +# =========================================================================== +# PerResidueLDDTCaPredictor +# =========================================================================== + + +class TestPerResidueLDDTCaPredictor: + def test_equivalence(self): + no_bins, c_in, c_hidden = 50, 1024, DIM + orig = 
PerResidueLDDTCaPredictor(no_bins, c_in, c_hidden).to(DEVICE) + te_mod = PerResidueLDDTCaPredictorTE(no_bins, c_in, c_hidden).to(DEVICE) + copy_plddt_to_te(orig, te_mod) + + torch.manual_seed(INPUT_SEED) + x = torch.randn(B, N, c_in, device=DEVICE) + + with torch.no_grad(): + out_orig = orig(x) + out_te = te_mod(x) + + assert out_orig.shape == out_te.shape == (B, N, no_bins) + assert torch.allclose(out_orig, out_te, atol=ATOL, rtol=RTOL), ( + f"Max diff: {(out_orig - out_te).abs().max().item()}" + ) diff --git a/bionemo-recipes/recipes/esm2_minifold_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_minifold_te/train_fsdp2.py new file mode 100644 index 0000000000..bfd2647f4d --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/train_fsdp2.py @@ -0,0 +1,433 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FSDP2 training script for ESM2-MiniFold TE structure prediction. + +Hydra-based training loop that: +1. Loads a frozen ESM-2 backbone (HuggingFace) +2. Trains the TE folding head on distogram prediction (Stage 1) +3. 
Optionally trains the structure module for full 3D prediction (Stage 2) + +Three parameter groups with separate learning rates: +- Backbone: frozen (lr=0) or fine-tuned (lr=3e-5) +- Folding head: lr=1e-4 +- Structure module: lr=1e-4 + +Usage: + # Single GPU + python train_fsdp2.py --config-name L0_sanity + + # Multi GPU + torchrun --nproc_per_node=8 train_fsdp2.py --config-name defaults +""" + +import logging +import os +from pathlib import Path + +import hydra +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from torch.optim import AdamW + +from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint +from dataset import create_dataloader +from distributed_config import DistributedConfig +from modeling_esm2_minifold_te import ESM2MiniFoldTE +from perf_logger import PerfLogger +from quantization import ( + BufferedQuantLogger, + ComponentPrecisionConfig, + initialize_quant_stats_logging, + resolve_layer_precision, +) +from scheduler import get_linear_schedule_with_warmup + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def compute_distogram_loss(preds, coords, mask, no_bins=64, max_dist=25.0): + """Compute distogram cross-entropy loss. + + Args: + preds: Predicted distogram logits (B, L, L, no_bins). + coords: Ca coordinates (B, L, 3). + mask: Residue mask (B, L). + no_bins: Number of distance bins. + max_dist: Maximum distance in Angstroms. + + Returns: + Scalar loss tensor. 
+ """ + # Compute pairwise Ca distances + dists = torch.cdist(coords, coords) + + # Bin distances into one-hot labels + boundaries = torch.linspace(2, max_dist, no_bins - 1, device=dists.device) + labels = F.one_hot( + (dists.unsqueeze(-1) > boundaries).sum(dim=-1), + no_bins, + ).to(preds.dtype) + + # Cross-entropy loss + errors = -torch.sum(labels * F.log_softmax(preds, dim=-1), dim=-1) + + # Square mask (exclude self-distances and padding) + square_mask = mask[:, None, :] * mask[:, :, None] + eye = torch.eye(mask.shape[1], device=mask.device).unsqueeze(0) + square_mask = square_mask * (1 - eye) + + # FP16-friendly mean + denom = 1e-5 + square_mask.sum(dim=(-1, -2)) + mean = (errors * square_mask).sum(dim=-1) / denom[..., None] + mean = mean.sum(dim=-1) + return mean.mean() + + +def compute_distogram_metrics(preds, coords, mask, no_bins=64, max_dist=25.0, contact_threshold=8.0): + """Compute structure prediction quality metrics from distogram predictions. + + Args: + preds: Predicted distogram logits (B, L, L, no_bins). + coords: Ca coordinates (B, L, 3). + mask: Residue mask (B, L). + no_bins: Number of distance bins. + max_dist: Maximum distance in Angstroms. + contact_threshold: Distance threshold for contact prediction (Angstroms). + + Returns: + Dict with: distogram_acc, contact_precision, contact_recall, + lddt_from_distogram, mean_distance_error. 
+ """ + with torch.no_grad(): + # True pairwise distances + true_dists = torch.cdist(coords, coords) + + # Bin boundaries and centers + boundaries = torch.linspace(2, max_dist, no_bins - 1, device=preds.device) + bin_centers = torch.cat( + [ + torch.tensor([1.0], device=preds.device), + (boundaries[:-1] + boundaries[1:]) / 2, + torch.tensor([max_dist + 2.0], device=preds.device), + ] + ) + + # True bin indices + true_bins = (true_dists.unsqueeze(-1) > boundaries).sum(dim=-1) + + # Predicted bin indices and probabilities + pred_bins = preds.argmax(dim=-1) + pred_probs = F.softmax(preds, dim=-1) + + # Expected predicted distance from distogram + pred_dists = (pred_probs * bin_centers).sum(dim=-1) + + # Valid pair mask (exclude self and padding) + square_mask = mask[:, None, :] * mask[:, :, None] + eye = torch.eye(mask.shape[1], device=mask.device).unsqueeze(0) + pair_mask = square_mask * (1 - eye) + n_pairs = pair_mask.sum().clamp(min=1) + + # 1. Distogram accuracy + correct = (pred_bins == true_bins).float() * pair_mask + distogram_acc = correct.sum() / n_pairs + + # 2. Contact precision and recall at threshold + true_contacts = (true_dists < contact_threshold).float() * pair_mask + pred_contacts = (pred_dists < contact_threshold).float() * pair_mask + + tp = (true_contacts * pred_contacts).sum() + contact_precision = tp / pred_contacts.sum().clamp(min=1) + contact_recall = tp / true_contacts.sum().clamp(min=1) + + # 3. lDDT from distogram expected distances + # Standard lDDT: fraction of pairwise distances within thresholds + dist_error = torch.abs(pred_dists - true_dists) + lddt_score = ( + (dist_error < 0.5).float() + + (dist_error < 1.0).float() + + (dist_error < 2.0).float() + + (dist_error < 4.0).float() + ) * 0.25 + + # Only score pairs within 15Å cutoff (standard lDDT) + lddt_mask = pair_mask * (true_dists < 15.0).float() + lddt_from_distogram = (lddt_score * lddt_mask).sum() / lddt_mask.sum().clamp(min=1) + + # 4. 
Mean distance error (on valid pairs within 15Å) + mean_dist_error = (dist_error * lddt_mask).sum() / lddt_mask.sum().clamp(min=1) + + return { + "distogram_acc": distogram_acc, + "contact_precision_8A": contact_precision, + "contact_recall_8A": contact_recall, + "lddt_from_distogram": lddt_from_distogram, + "mean_distance_error": mean_dist_error, + } + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") +def main(args: DictConfig) -> float | None: + """Train ESM2-MiniFold TE with FSDP2. + + Returns: + float: The final loss value. + """ + os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "1" + logging.getLogger("httpx").setLevel(logging.WARNING) + + # Initialize distributed + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist_config.world_size,), + mesh_dim_names=("dp",), + ) + + # Resolve per-block quantization precision + block_precision = resolve_layer_precision( + num_layers=args.model.num_blocks, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + + fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + from transformer_engine.common.recipe import Format + + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + from transformer_engine.common.recipe import Format + + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + 
fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + # Component-level precision overrides + component_precision = ComponentPrecisionConfig(**OmegaConf.to_container(args.component_precision, resolve=True)) + + # Quant stats logging + quant_logger = None + if args.quant_stats_config.enabled: + if dist_config.is_main_process(): + quant_logger = BufferedQuantLogger() + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=block_precision, + statistics_logger=quant_logger, + component_precision=component_precision, + ) + + # Create model + params_dtype = torch.float32 + model = ESM2MiniFoldTE( + esm_model_name=args.esm_model_name, + c_s=args.model.c_s, + c_z=args.model.c_z, + num_blocks=args.model.num_blocks, + no_bins=args.model.no_bins, + use_structure_module=args.model.use_structure_module, + params_dtype=params_dtype, + block_precision=block_precision, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + component_precision=component_precision, + ).to(device) + + logger.info("Model created: %d parameters", sum(p.numel() for p in model.parameters())) + logger.info("Trainable: %d parameters", sum(p.numel() for p in model.parameters() if p.requires_grad)) + + # FSDP2: shard MiniFormer blocks individually for memory efficiency + if args.use_fp32_master_weights: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward + reduce_dtype=torch.float32, # Gradient reductions in FP32 + output_dtype=torch.bfloat16, # Forward output dtype + cast_forward_inputs=False, + ) + else: + mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + for block in model.fold.miniformer.blocks: + fully_shard(block, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + + # Assign layer names for quant stats 
debug API + if args.quant_stats_config.enabled: + import nvdlfw_inspect.api as debug_api + + debug_api.infer_and_assign_layer_names(model) + + # Optimizer with parameter groups + param_groups = [ + { + "params": list(model.get_folding_head_params()), + "lr": args.optimizer.folding_lr, + "name": "folding_head", + }, + ] + if args.model.use_structure_module: + param_groups.append( + { + "params": list(model.get_structure_module_params()), + "lr": args.optimizer.struct_lr, + "name": "structure_module", + } + ) + + optimizer = AdamW( + param_groups, + betas=tuple(args.optimizer.betas), + eps=args.optimizer.eps, + weight_decay=args.optimizer.weight_decay, + fused=True, + ) + scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + # Create dataloader + train_dataloader, sampler = create_dataloader(dist_config, **args.dataset) + + if dist_config.is_main_process(): + logger.info("Block precision: %s", block_precision) + + # Resume from checkpoint + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None + if args.checkpoint.resume_from_checkpoint and ckpt_path: + model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, + dataloader=train_dataloader, + ) + else: + start_step = 0 + epoch = 0 + + perf_logger = PerfLogger(dist_config, args, quant_logger=quant_logger) + + # Training loop + step = start_step + while step < args.num_train_steps: + for batch in train_dataloader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + # Forward pass (BF16 handled by FSDP2 MixedPrecisionPolicy) + r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0)) + + # Compute distogram loss + disto_loss = compute_distogram_loss( + preds=r_dict["preds"], + coords=batch["coords"], + mask=batch["mask"], + 
no_bins=args.model.no_bins, + ) + + total_loss = disto_loss + + # Optional structure module loss (Stage 2) + if args.model.use_structure_module and "sm" in r_dict: + from loss import AlphaFoldLoss + + loss_of, _ = AlphaFoldLoss(r_dict, batch.get("batch_of", {})) + total_loss = 0.8 * disto_loss + 0.2 * loss_of + + # Backward pass + total_loss.backward() + + # Gradient clipping + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + + # Optimizer step + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + # Compute structure quality metrics (no grad, cheap) + structure_metrics = compute_distogram_metrics( + preds=r_dict["preds"].float(), + coords=batch["coords"], + mask=batch["mask"], + no_bins=args.model.no_bins, + ) + + # Count unpadded tokens across all GPUs + unpadded_tokens = batch["mask"].sum().item() * dist_config.world_size + + # Logging + perf_logger.log_step( + step=step, + loss=total_loss, + disto_loss=disto_loss, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + structure_metrics=structure_metrics, + unpadded_tokens=unpadded_tokens, + ) + + # Checkpointing + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + save_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + step=step, + epoch=epoch, + dist_config=dist_config, + max_checkpoints=args.checkpoint.max_checkpoints, + ) + + step += 1 + if step >= args.num_train_steps: + break + + epoch += 1 + sampler.set_epoch(epoch) + + # Save final model + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_fsdp2( + model=model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + perf_logger.finish() + torch.distributed.destroy_process_group() + + return perf_logger.min_loss + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_minifold_te/weight_copy.py 
b/bionemo-recipes/recipes/esm2_minifold_te/weight_copy.py new file mode 100644 index 0000000000..bdf7e5b6f6 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_minifold_te/weight_copy.py @@ -0,0 +1,354 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Weight copy utilities between original MiniFold modules and their TE counterparts. + +Both nn.Linear and te.Linear store weights as (out_features, in_features), +so direct copy works without transposition. The raw nn.Parameter tensors used +in TransitionUpdate/TriangularUpdate also use (out, in) layout via F.linear. 
+""" + +import torch + + +def _copy_param(src, dst): + """Copy parameter data from src to dst, with shape validation.""" + assert src.shape == dst.shape, f"Shape mismatch: src {src.shape} vs dst {dst.shape}" + with torch.no_grad(): + dst.copy_(src) + + +# --------------------------------------------------------------------------- +# TransitionUpdate <-> TransitionUpdateTE +# --------------------------------------------------------------------------- + + +def copy_transition_update_to_te(orig, te_mod): + """Copy weights from TransitionUpdate to TransitionUpdateTE.""" + _copy_param(orig.wn, te_mod.norm.weight) + _copy_param(orig.bn, te_mod.norm.bias) + _copy_param(orig.w1, te_mod.fc1.weight) + _copy_param(orig.b1, te_mod.fc1.bias) + _copy_param(orig.w2, te_mod.fc2.weight) + _copy_param(orig.b2, te_mod.fc2.bias) + + +def copy_transition_update_from_te(te_mod, orig): + """Copy weights from TransitionUpdateTE to TransitionUpdate.""" + _copy_param(te_mod.norm.weight, orig.wn) + _copy_param(te_mod.norm.bias, orig.bn) + _copy_param(te_mod.fc1.weight, orig.w1) + _copy_param(te_mod.fc1.bias, orig.b1) + _copy_param(te_mod.fc2.weight, orig.w2) + _copy_param(te_mod.fc2.bias, orig.b2) + + +# --------------------------------------------------------------------------- +# TriangularUpdate <-> TriangularUpdateTE +# --------------------------------------------------------------------------- + + +def copy_triangular_update_to_te(orig, te_mod): + """Copy weights from TriangularUpdate to TriangularUpdateTE.""" + _copy_param(orig.ni_w, te_mod.input_norm.weight) + _copy_param(orig.ni_b, te_mod.input_norm.bias) + _copy_param(orig.pi_w, te_mod.pi.weight) + _copy_param(orig.pi_b, te_mod.pi.bias) + _copy_param(orig.gi_w, te_mod.gi.weight) + _copy_param(orig.gi_b, te_mod.gi.bias) + + _copy_param(orig.no_w, te_mod.output_norm.weight) + _copy_param(orig.no_b, te_mod.output_norm.bias) + _copy_param(orig.po_w, te_mod.po.weight) + _copy_param(orig.po_b, te_mod.po.bias) + _copy_param(orig.go_w, 
te_mod.go.weight) + _copy_param(orig.go_b, te_mod.go.bias) + + +def copy_triangular_update_from_te(te_mod, orig): + """Copy weights from TriangularUpdateTE to TriangularUpdate.""" + _copy_param(te_mod.input_norm.weight, orig.ni_w) + _copy_param(te_mod.input_norm.bias, orig.ni_b) + _copy_param(te_mod.pi.weight, orig.pi_w) + _copy_param(te_mod.pi.bias, orig.pi_b) + _copy_param(te_mod.gi.weight, orig.gi_w) + _copy_param(te_mod.gi.bias, orig.gi_b) + + _copy_param(te_mod.output_norm.weight, orig.no_w) + _copy_param(te_mod.output_norm.bias, orig.no_b) + _copy_param(te_mod.po.weight, orig.po_w) + _copy_param(te_mod.po.bias, orig.po_b) + _copy_param(te_mod.go.weight, orig.go_w) + _copy_param(te_mod.go.bias, orig.go_b) + + +# --------------------------------------------------------------------------- +# Block <-> BlockTE +# --------------------------------------------------------------------------- + + +def copy_block_to_te(orig, te_mod): + """Copy weights from Block to BlockTE.""" + copy_triangular_update_to_te(orig.triangular, te_mod.triangular) + copy_transition_update_to_te(orig.transition, te_mod.transition) + + +def copy_block_from_te(te_mod, orig): + """Copy weights from BlockTE to Block.""" + copy_triangular_update_from_te(te_mod.triangular, orig.triangular) + copy_transition_update_from_te(te_mod.transition, orig.transition) + + +# --------------------------------------------------------------------------- +# MiniFormer <-> MiniFormerTE +# --------------------------------------------------------------------------- + + +def copy_miniformer_to_te(orig, te_mod): + """Copy weights from MiniFormer to MiniFormerTE.""" + assert len(orig.blocks) == len(te_mod.blocks) + for orig_block, te_block in zip(orig.blocks, te_mod.blocks): + copy_block_to_te(orig_block, te_block) + + +def copy_miniformer_from_te(te_mod, orig): + """Copy weights from MiniFormerTE to MiniFormer.""" + assert len(orig.blocks) == len(te_mod.blocks) + for orig_block, te_block in zip(orig.blocks, 
te_mod.blocks): + copy_block_from_te(te_block, orig_block) + + +# --------------------------------------------------------------------------- +# nn.Linear <-> te.Linear (generic helper) +# --------------------------------------------------------------------------- + + +def _copy_linear_to_te(orig_linear, te_linear): + """Copy from nn.Linear to te.Linear.""" + _copy_param(orig_linear.weight, te_linear.weight) + if orig_linear.bias is not None and te_linear.bias is not None: + _copy_param(orig_linear.bias, te_linear.bias) + + +def _copy_linear_from_te(te_linear, orig_linear): + """Copy from te.Linear to nn.Linear.""" + _copy_param(te_linear.weight, orig_linear.weight) + if orig_linear.bias is not None and te_linear.bias is not None: + _copy_param(te_linear.bias, orig_linear.bias) + + +def _copy_layernorm_to_te(orig_ln, te_ln): + """Copy from nn.LayerNorm to te.LayerNorm.""" + _copy_param(orig_ln.weight, te_ln.weight) + _copy_param(orig_ln.bias, te_ln.bias) + + +def _copy_layernorm_from_te(te_ln, orig_ln): + """Copy from te.LayerNorm to nn.LayerNorm.""" + _copy_param(te_ln.weight, orig_ln.weight) + _copy_param(te_ln.bias, orig_ln.bias) + + +# --------------------------------------------------------------------------- +# SequenceToPair <-> SequenceToPairTE +# --------------------------------------------------------------------------- + + +def copy_seq_to_pair_to_te(orig, te_mod): + _copy_layernorm_to_te(orig.layernorm, te_mod.layernorm) + _copy_linear_to_te(orig.proj, te_mod.proj) + _copy_linear_to_te(orig.o_proj, te_mod.o_proj) + + +def copy_seq_to_pair_from_te(te_mod, orig): + _copy_layernorm_from_te(te_mod.layernorm, orig.layernorm) + _copy_linear_from_te(te_mod.proj, orig.proj) + _copy_linear_from_te(te_mod.o_proj, orig.o_proj) + + +# --------------------------------------------------------------------------- +# PairToSequence <-> PairToSequenceTE +# --------------------------------------------------------------------------- + + +def copy_pair_to_seq_to_te(orig, 
te_mod): + # s_z_mlp: Sequential(LayerNorm, Linear, ReLU, Linear) -> separate TE modules + _copy_layernorm_to_te(orig.s_z_mlp[0], te_mod.s_z_norm) + _copy_linear_to_te(orig.s_z_mlp[1], te_mod.s_z_fc1) + _copy_linear_to_te(orig.s_z_mlp[3], te_mod.s_z_fc2) + # combiner: Sequential(Linear) -> te.Linear + _copy_linear_to_te(orig.combiner[0], te_mod.combiner) + + +def copy_pair_to_seq_from_te(te_mod, orig): + _copy_layernorm_from_te(te_mod.s_z_norm, orig.s_z_mlp[0]) + _copy_linear_from_te(te_mod.s_z_fc1, orig.s_z_mlp[1]) + _copy_linear_from_te(te_mod.s_z_fc2, orig.s_z_mlp[3]) + _copy_linear_from_te(te_mod.combiner, orig.combiner[0]) + + +# --------------------------------------------------------------------------- +# FoldingTrunk <-> FoldingTrunkTE +# --------------------------------------------------------------------------- + + +def copy_folding_trunk_to_te(orig, te_mod): + # Positional embedding + _copy_param(orig.positional_embedding.embedding.weight, te_mod.positional_embedding.embedding.weight) + copy_seq_to_pair_to_te(orig.seq_to_pair, te_mod.seq_to_pair) + _copy_linear_to_te(orig.projection, te_mod.projection) + _copy_linear_to_te(orig.recycle, te_mod.recycle) + copy_miniformer_to_te(orig.miniformer, te_mod.miniformer) + # fc_out: Sequential(Linear, ReLU, Linear) -> fc_out_1, fc_out_2 + _copy_linear_to_te(orig.fc_out[0], te_mod.fc_out_1) + _copy_linear_to_te(orig.fc_out[2], te_mod.fc_out_2) + + +def copy_folding_trunk_from_te(te_mod, orig): + _copy_param(te_mod.positional_embedding.embedding.weight, orig.positional_embedding.embedding.weight) + copy_seq_to_pair_from_te(te_mod.seq_to_pair, orig.seq_to_pair) + _copy_linear_from_te(te_mod.projection, orig.projection) + _copy_linear_from_te(te_mod.recycle, orig.recycle) + copy_miniformer_from_te(te_mod.miniformer, orig.miniformer) + _copy_linear_from_te(te_mod.fc_out_1, orig.fc_out[0]) + _copy_linear_from_te(te_mod.fc_out_2, orig.fc_out[2]) + + +# 
--------------------------------------------------------------------------- +# Attention <-> AttentionTE +# --------------------------------------------------------------------------- + + +def copy_attention_to_te(orig, te_mod): + _copy_layernorm_to_te(orig.layer_norm, te_mod.layer_norm) + _copy_linear_to_te(orig.proj, te_mod.proj) + _copy_linear_to_te(orig.o_proj, te_mod.o_proj) + _copy_linear_to_te(orig.g_proj, te_mod.g_proj) + + +def copy_attention_from_te(te_mod, orig): + _copy_layernorm_from_te(te_mod.layer_norm, orig.layer_norm) + _copy_linear_from_te(te_mod.proj, orig.proj) + _copy_linear_from_te(te_mod.o_proj, orig.o_proj) + _copy_linear_from_te(te_mod.g_proj, orig.g_proj) + + +# --------------------------------------------------------------------------- +# MLP <-> MLPTE +# --------------------------------------------------------------------------- + + +def copy_mlp_to_te(orig, te_mod): + # orig.mlp: Sequential(LayerNorm, Linear, ReLU, Linear) + _copy_layernorm_to_te(orig.mlp[0], te_mod.norm) + _copy_linear_to_te(orig.mlp[1], te_mod.fc1) + _copy_linear_to_te(orig.mlp[3], te_mod.fc2) + + +def copy_mlp_from_te(te_mod, orig): + _copy_layernorm_from_te(te_mod.norm, orig.mlp[0]) + _copy_linear_from_te(te_mod.fc1, orig.mlp[1]) + _copy_linear_from_te(te_mod.fc2, orig.mlp[3]) + + +# --------------------------------------------------------------------------- +# AngleResnetBlock <-> AngleResnetBlockTE +# --------------------------------------------------------------------------- + + +def copy_angle_resnet_block_to_te(orig, te_mod): + # orig.mlp: Sequential(ReLU, Linear, ReLU, Linear) + _copy_linear_to_te(orig.mlp[1], te_mod.fc1) + _copy_linear_to_te(orig.mlp[3], te_mod.fc2) + + +def copy_angle_resnet_block_from_te(te_mod, orig): + _copy_linear_from_te(te_mod.fc1, orig.mlp[1]) + _copy_linear_from_te(te_mod.fc2, orig.mlp[3]) + + +# --------------------------------------------------------------------------- +# AngleResnet <-> AngleResnetTE +# 
--------------------------------------------------------------------------- + + +def copy_angle_resnet_to_te(orig, te_mod): + _copy_linear_to_te(orig.linear_in, te_mod.linear_in) + _copy_linear_to_te(orig.linear_initial, te_mod.linear_initial) + _copy_linear_to_te(orig.linear_out, te_mod.linear_out) + for orig_layer, te_layer in zip(orig.layers, te_mod.layers): + copy_angle_resnet_block_to_te(orig_layer, te_layer) + + +def copy_angle_resnet_from_te(te_mod, orig): + _copy_linear_from_te(te_mod.linear_in, orig.linear_in) + _copy_linear_from_te(te_mod.linear_initial, orig.linear_initial) + _copy_linear_from_te(te_mod.linear_out, orig.linear_out) + for orig_layer, te_layer in zip(orig.layers, te_mod.layers): + copy_angle_resnet_block_from_te(te_layer, orig_layer) + + +# --------------------------------------------------------------------------- +# StructureModule <-> StructureModuleTE +# --------------------------------------------------------------------------- + + +def copy_structure_module_to_te(orig, te_mod): + _copy_layernorm_to_te(orig.layer_norm_s, te_mod.layer_norm_s) + _copy_layernorm_to_te(orig.layer_norm_z, te_mod.layer_norm_z) + _copy_linear_to_te(orig.linear_in, te_mod.linear_in) + _copy_linear_to_te(orig.linear_b, te_mod.linear_b) + _copy_linear_to_te(orig.bb_update, te_mod.bb_update) + + for orig_attn, te_attn in zip(orig.attn, te_mod.attn): + copy_attention_to_te(orig_attn, te_attn) + for orig_trans, te_trans in zip(orig.transitions, te_mod.transitions): + copy_mlp_to_te(orig_trans, te_trans) + + copy_angle_resnet_to_te(orig.angle_resnet, te_mod.angle_resnet) + + +def copy_structure_module_from_te(te_mod, orig): + _copy_layernorm_from_te(te_mod.layer_norm_s, orig.layer_norm_s) + _copy_layernorm_from_te(te_mod.layer_norm_z, orig.layer_norm_z) + _copy_linear_from_te(te_mod.linear_in, orig.linear_in) + _copy_linear_from_te(te_mod.linear_b, orig.linear_b) + _copy_linear_from_te(te_mod.bb_update, orig.bb_update) + + for orig_attn, te_attn in zip(orig.attn, 
te_mod.attn): + copy_attention_from_te(te_attn, orig_attn) + for orig_trans, te_trans in zip(orig.transitions, te_mod.transitions): + copy_mlp_from_te(te_trans, orig_trans) + + copy_angle_resnet_from_te(te_mod.angle_resnet, orig.angle_resnet) + + +# --------------------------------------------------------------------------- +# PerResidueLDDTCaPredictor <-> PerResidueLDDTCaPredictorTE +# --------------------------------------------------------------------------- + + +def copy_plddt_to_te(orig, te_mod): + _copy_layernorm_to_te(orig.layer_norm, te_mod.layer_norm) + _copy_linear_to_te(orig.linear_1, te_mod.linear_1) + _copy_linear_to_te(orig.linear_2, te_mod.linear_2) + _copy_linear_to_te(orig.linear_3, te_mod.linear_3) + + +def copy_plddt_from_te(te_mod, orig): + _copy_layernorm_from_te(te_mod.layer_norm, orig.layer_norm) + _copy_linear_from_te(te_mod.linear_1, orig.linear_1) + _copy_linear_from_te(te_mod.linear_2, orig.linear_2) + _copy_linear_from_te(te_mod.linear_3, orig.linear_3)