diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index ae6027cb..b534fd5d 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -25,6 +25,7 @@ dependencies: - pydantic>=2 - ray-all - graphviz + - openmmtools # Testing - pytest>=2.1 @@ -39,6 +40,5 @@ dependencies: - pytorch2jax - git+https://github.com/ArnNag/sake.git@nanometer - flax - - torch - pytest-xdist diff --git a/devtools/conda-envs/test_env_mac.yaml b/devtools/conda-envs/test_env_mac.yaml index 65b322d4..f96b4271 100644 --- a/devtools/conda-envs/test_env_mac.yaml +++ b/devtools/conda-envs/test_env_mac.yaml @@ -25,8 +25,8 @@ dependencies: - flax - pydantic>=2.0 - graphviz - - - + - openmmtools + # Testing - pytest>=2.1 - pytest-cov diff --git a/docs/index.rst b/docs/index.rst index 90cc4fd1..00638fc2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,6 +21,7 @@ The best way to get started is to read the :doc:`getting_started` guide, which o inference for_developer tuning + profiling api diff --git a/docs/profiling.rst b/docs/profiling.rst new file mode 100644 index 00000000..ec97459a --- /dev/null +++ b/docs/profiling.rst @@ -0,0 +1,7 @@ +Profiling +================ + +Profiling: Overview +------------------------------------------ + +It is common to profile models to identify bottlenecks and optimize performance. *Modelforge* provides a simple interface to profile models using the `torch.profiler` module. The profiler can be used to profile the forward pass, backward pass, or both, and can be used to profile the model on a single batch or multiple batches. \ No newline at end of file diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index 4b3347a2..98503a52 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -7,9 +7,7 @@ using a neural network model. """ -from __future__ import annotations - -from typing import TYPE_CHECKING, Dict, Tuple +from typing import Dict, Tuple, List import torch from loguru import logger as log diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 127b888a..1bd118aa 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -11,7 +11,7 @@ from torch.nn import Module from modelforge.potential.neighbors import PairlistData -from modelforge.dataset.dataset import DatasetParameters, NNPInput, NNPInputTuple +from modelforge.dataset.dataset import DatasetParameters, NNPInputTuple from modelforge.potential.parameters import ( AimNet2Parameters, ANI2xParameters, @@ -593,6 +593,12 @@ def generate_potential( neighborlist_strategy=inference_neighborlist_strategy, verlet_neighborlist_skin=verlet_neighborlist_skin, ) + # Disable gradients for model parameters + for param in model.parameters(): + param.requires_grad = False + # Set model to eval + model.eval() + if simulation_environment == "JAX": return PyTorch2JAXConverter().convert_to_jax_model(model) else: diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index df5f207c..b4965ad9 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -298,9 +298,9 @@ def forward( # featurize pairwise distances using radial basis functions (RBF) f_ij = self.radial_symmetry_function_module(d_ij) - f_ij_cut = self.cutoff_module(d_ij) + # Apply the filter network and cutoff function - filters = torch.mul(self.filter_net(f_ij), f_ij_cut) + filters = torch.mul(self.filter_net(f_ij), self.cutoff_module(d_ij)) # depending on whether we share filters or not filters have different # shape at dim=1 (dim=0 is always the number of atom pairs) if we share diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index 5b246d85..0fd174c8 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -227,7 +227,7 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: g = self.attention_mask(data["f_ij"]) # calculate the updated embedding for atom j embedding_atom_j = self.activation_function( - self.interaction_j(data["atomic_embedding"][idx_j]) + self.interaction_j(data["atomic_embedding"])[idx_j] ) updated_embedding_atom_j = torch.mul( g, embedding_atom_j diff --git a/modelforge/potential/representation.py b/modelforge/potential/representation.py index 07e00d2e..1a21ed87 100644 --- a/modelforge/potential/representation.py +++ b/modelforge/potential/representation.py @@ -151,37 +151,71 @@ def forward(self, r_ij: torch.Tensor) -> torch.Tensor: return sub_aev def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: - """Compute the angular subAEV terms of the center atom given neighbor pairs. + """ + Compute the angular subAEV terms of the center atom given neighbor + pairs. This correspond to equation (4) in the ANI paper. This function just compute the terms. The sum in the equation is not computed. - The input tensor have shape (conformations, atoms, N), where N - is the number of neighbor atom pairs within the cutoff radius and - output tensor should have shape - (conformations, atoms, ``self.angular_sublength()``) + + Parameters + ---------- + vectors12: torch.Tensor + Pairwise distance vectors. Shape: [2, n_pairs, 3] + + Returns + ------- + torch.Tensor + Angular subAEV terms. Shape: [n_pairs, ShfZ_size * ShfA_size] """ - vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - distances12 = vectors12.norm(2, dim=-5) + # vectors12: (2, n_pairs, 3) + distances12 = vectors12.norm(p=2, dim=-1) # Shape: (2, n_pairs) + distances_sum = distances12.sum(dim=0) / 2 # Shape: (n_pairs,) + fcj12 = self.cosine_cutoff(distances12) # Shape: (2, n_pairs) + fcj12_prod = fcj12.prod(dim=0) # Shape: (n_pairs,) - # 0.95 is multiplied to the cos values to prevent acos from - # returning NaN. + # cos_angles: (n_pairs,) cos_angles = 0.95 * torch.nn.functional.cosine_similarity( - vectors12[0], vectors12[1], dim=-5 + vectors12[0], vectors12[1], dim=-1 ) - angles = torch.acos(cos_angles) - fcj12 = self.cosine_cutoff(distances12) - factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta + angles = torch.acos(cos_angles) # Shape: (n_pairs,) + + # Prepare shifts for broadcasting + angles = angles.unsqueeze(-1) # Shape: (n_pairs, 1) + distances_sum = distances_sum.unsqueeze(-1) # Shape: (n_pairs, 1) + + # Compute factor1 + delta_angles = angles - self.ShfZ.view(1, -1) # Shape: (n_pairs, ShfZ_size) + factor1 = ( + (1 + torch.cos(delta_angles)) / 2 + ) ** self.Zeta # Shape: (n_pairs, ShfZ_size) + + # Compute factor2 + delta_distances = distances_sum - self.ShfA.view( + 1, -1 + ) # Shape: (n_pairs, ShfA_size) factor2 = torch.exp( - -self.EtaA * (distances12.sum(0) / 2 - self.ShfA) ** 2 - ).unsqueeze(-1) - factor2 = factor2.squeeze(4).squeeze(3) - ret = 2 * factor1 * factor2 * fcj12.prod(0) - # At this point, ret now have shape - # (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants. - # We then should flat the last 4 dimensions to view the subAEV as one - # dimension vector - return ret.flatten(start_dim=-4) + -self.EtaA * delta_distances**2 + ) # Shape: (n_pairs, ShfA_size) + + # Compute the outer product of factor1 and factor2 efficiently + # fcj12_prod: (n_pairs, 1, 1) + fcj12_prod = fcj12_prod.unsqueeze(-1).unsqueeze(-1) # Shape: (n_pairs, 1, 1) + + # factor1: (n_pairs, ShfZ_size, 1) + factor1 = factor1.unsqueeze(-1) + # factor2: (n_pairs, 1, ShfA_size) + factor2 = factor2.unsqueeze(-2) + + # Compute ret: (n_pairs, ShfZ_size, ShfA_size) + ret = 2 * fcj12_prod * factor1 * factor2 + + # Flatten the last two dimensions to get the final subAEV + # ret: (n_pairs, ShfZ_size * ShfA_size) + ret = ret.reshape(distances12.size(dim=1), -1) + + return ret import math diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 5e969d56..6c25a4fd 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -143,16 +143,16 @@ def compute_properties( # Compute the atomic representation representation = self.schnet_representation_module(data, pairlist_output) atomic_embedding = representation["atomic_embedding"] - + f_ij = representation["f_ij"] + f_cutoff = representation["f_cutoff"] # Apply interaction modules to update the atomic embedding for interaction in self.interaction_modules: - v = interaction( + atomic_embedding = atomic_embedding + interaction( atomic_embedding, pairlist_output, - representation["f_ij"], - representation["f_cutoff"], + f_ij, + f_cutoff, ) - atomic_embedding = atomic_embedding + v # Update atomic features return { "per_atom_scalar_representation": atomic_embedding, @@ -293,14 +293,13 @@ def forward( # Generate interaction filters based on radial basis functions W_ij = self.filter_network(f_ij.squeeze(1)) - W_ij = W_ij * f_ij_cutoff + W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters] # Perform continuous-filter convolution x_j = atomic_embedding[idx_j] x_ij = x_j * W_ij # Element-wise multiplication - out = torch.zeros_like(atomic_embedding) - out.scatter_add_( + out = torch.zeros_like(atomic_embedding).scatter_add_( 0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij ) # Aggregate per-atom pair to per-atom diff --git a/modelforge/tests/test_profiling.py b/modelforge/tests/test_profiling.py new file mode 100644 index 00000000..3bf4a3db --- /dev/null +++ b/modelforge/tests/test_profiling.py @@ -0,0 +1,56 @@ +import torch +import pytest + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_profiling_function(): + from modelforge.tests.helper_functions import setup_potential_for_test + import torch + from modelforge.utils.profiling import ( + start_record_memory_history, + export_memory_snapshot, + stop_record_memory_history, + setup_waterbox_testsystem, + ) + + # define the potential, device and precision + potential_name = "tensornet" + precision = torch.float32 + device = "cuda" + + # setup the input and model + nnp_input = setup_waterbox_testsystem(2.5, device=device, precision=precision) + model = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=True, + simulation_environment="PyTorch", + ).to(device, precision) + # Disable gradients for model parameters + for param in model.parameters(): + param.requires_grad = False + # Set model to eval + model.eval() + + # this is the function that will be profiled + def loop_to_record(): + for _ in range(5): + # perform the forward pass through each of the models + r = model(nnp_input)["per_molecule_energy"] + # Compute the gradient (forces) from the predicted energies + grad = torch.autograd.grad( + r, + nnp_input.positions, + grad_outputs=torch.ones_like(r), + create_graph=False, + retain_graph=False, + )[0] + + # Start recording memory snapshot history + start_record_memory_history() + loop_to_record() + # Create the memory snapshot file + export_memory_snapshot() + # Stop recording memory snapshot history + stop_record_memory_history() diff --git a/modelforge/utils/io.py b/modelforge/utils/io.py index 8f0eab42..21b84e49 100644 --- a/modelforge/utils/io.py +++ b/modelforge/utils/io.py @@ -160,6 +160,18 @@ """ +MESSAGES[ + "openmmtools" +] = """ + +A batteries-included toolkit for the GPU-accelerated OpenMM molecular simulation engine. + +OpenMMTools can be installed via conda: + + conda install conda-forge::openmmtools + +""" + def import_(module: str): """Import a module or print a descriptive message and raise an ImportError diff --git a/modelforge/utils/profiling.py b/modelforge/utils/profiling.py new file mode 100644 index 00000000..f660eb3f --- /dev/null +++ b/modelforge/utils/profiling.py @@ -0,0 +1,273 @@ +import torch +from loguru import logger as log +import socket +from datetime import datetime +from modelforge.dataset.dataset import NNPInput + +TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S" +MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 + + +def setup_waterbox_testsystem( + edge_size_in_nm: float, + device: torch.device, + precision: torch.dtype, +) -> NNPInput: + from modelforge.utils.io import import_ + + openmmtools = import_("openmmtools") + from simtk import unit + from modelforge.dataset.dataset import NNPInput + + test_system = openmmtools.testsystems.WaterBox( + box_edge=edge_size_in_nm * unit.nanometer + ) + positions = test_system.positions # Positions in nanometers + topology = test_system.topology + + # Extract atomic numbers and residue indices + atomic_numbers = [] + residue_indices = [] + for residue_index, residue in enumerate(topology.residues()): + for atom in residue.atoms(): + atomic_numbers.append(atom.element.atomic_number) + residue_indices.append(residue_index) + num_waters = len(list(topology.residues())) + positions_in_nanometers = positions.value_in_unit(unit.nanometer) + + # Convert to torch tensors and move to GPU + torch_atomic_numbers = torch.tensor(atomic_numbers, dtype=torch.long, device=device) + torch_positions = torch.tensor( + positions_in_nanometers, dtype=torch.float32, device=device, requires_grad=True + ) + torch_atomic_subsystem_indices = torch.zeros_like( + torch_atomic_numbers, dtype=torch.long, device=device + ) + torch_total_charge = torch.zeros(1, dtype=torch.float32, device=device) + + log.info(f"Waterbox system setup with {num_waters} waters") + return NNPInput( + atomic_numbers=torch_atomic_numbers, + positions=torch_positions, + atomic_subsystem_indices=torch_atomic_subsystem_indices, + total_charge=torch_total_charge, + ).to(dtype=precision) + + +from typing import List +import time + + +def measure_performance_for_edge_sizes( + edge_sizes: List[float], + potential_names: List[str], +): + """ + Measures GPU memory utilization and computation time for force calculations + for water boxes of different edge sizes across multiple potentials. + + Parameters + ---------- + edge_sizes : List[float] + A list of edge sizes (in nanometers) for the water boxes. + potential_names : List[str] + A list of potential names to use in the model setup. + + Returns + ------- + List[dict] + A list of dictionaries containing edge size, number of water molecules, + potential name, memory usage in bytes, and computation time in seconds. + """ + results = [] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + precicion = torch.float32 + for potential_name in potential_names: + for edge_size in edge_sizes: + + nnp_input = setup_waterbox_testsystem( + edge_size, + device, + precicion, + ) + + # Import your model setup function + from modelforge.tests.helper_functions import setup_potential_for_test + + # Setup model + model = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=False, + simulation_environment="PyTorch", + ) + + model.to(device) + model.to(precicion) + total_params = sum(p.numel() for p in model.parameters()) + + # Measure GPU memory usage and computation time + torch.cuda.reset_peak_memory_stats(device=device) + torch.cuda.synchronize() + + # Run forward pass and time it + start_time = time.perf_counter() + try: + output = model(nnp_input.as_namedtuple())["per_molecule_energy"] + except: + print("Out of memory error during forward pass") + continue + + try: + F_training = -torch.autograd.grad( + output.sum(), + nnp_input.positions, + create_graph=False, + retain_graph=False, + )[0] + except: + print("Out of memory error during backward pass") + continue + torch.cuda.synchronize() + end_time = time.perf_counter() + + max_memory_allocated = torch.cuda.max_memory_allocated(device=device) + computation_time = end_time - start_time + + results.append( + { + "potential_name": f"{potential_name}: {total_params:.1e} params", + "edge_size_nm": edge_size, + "num_waters": num_waters, + "memory_usage_bytes": max_memory_allocated, + "computation_time_s": computation_time, + } + ) + + # Clean up + del ( + nnp_input, + output, + model, + ) + try: + del F_training + except: + pass + torch.cuda.empty_cache() + time.sleep(1) # Sleep for a second to allow GPU memory to be freed + + return results + + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + + +def plot_computation_time(results): + """ + Plots computation time against the number of water molecules for multiple potentials. + + Parameters + ---------- + results : List[dict] + A list of dictionaries containing edge size, number of water molecules, + potential name, memory usage in bytes, and computation time in seconds. + """ + # Create a DataFrame for plotting + df = pd.DataFrame(results) + df["computation_time_ms"] = ( + df["computation_time_s"] * 1000 + ) # Convert seconds to milliseconds + + # Plot using seaborn + sns.set(style="whitegrid") + plt.figure(figsize=(10, 6)) + sns.lineplot( + data=df, + x="num_waters", + y="computation_time_ms", + hue="potential_name", + units="potential_name", + estimator=None, # Do not aggregate data + marker="o", + linewidth=2, + markersize=8, + ) + plt.title("Computation Time vs Number of Water Molecules for Different Potentials") + plt.xlabel("Number of Water Molecules") + plt.ylabel("Computation Time (ms)") + plt.xticks(sorted(df["num_waters"].unique())) + plt.legend(title="Potential Name") + plt.tight_layout() + plt.show() + + +def plot_gpu_memory_usage(results): + """ + Plots GPU memory usage against the number of water molecules for multiple potentials. + + Parameters + ---------- + results : List[dict] + A list of dictionaries containing edge size, number of water molecules, + potential name, and memory usage in bytes. + """ + # Create a DataFrame for plotting + df = pd.DataFrame(results) + df["memory_usage_mb"] = df["memory_usage_bytes"] / 1e6 # Convert bytes to megabytes + + # Plot using seaborn + sns.set(style="whitegrid") + plt.figure(figsize=(10, 6)) + sns.lineplot( + data=df, + x="num_waters", + y="memory_usage_mb", + units="potential_name", + estimator=None, # Do not aggregate data + hue="potential_name", + marker="o", + linewidth=2, + markersize=8, + ) + plt.title( + "Backward pass: GPU Memory Usage vs Number of Water Molecules for Different Potentials" + ) + plt.xlabel("Number of Water Molecules") + plt.ylabel("GPU Memory Usage (MB)") + plt.xticks(sorted(df["num_waters"].unique())) + plt.legend(title="Potential Name") + plt.tight_layout() + plt.show() + + +def start_record_memory_history() -> None: + if not torch.cuda.is_available(): + log.info("CUDA unavailable. Not recording memory history") + return + + log.info("Starting snapshot record_memory_history") + torch.cuda.memory._record_memory_history( + max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT + ) + + +def export_memory_snapshot() -> None: + if not torch.cuda.is_available(): + log.info("CUDA unavailable. Not exporting memory snapshot") + return + + # Prefix for file names. + host_name = socket.gethostname() + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = f"{host_name}_{timestamp}" + + try: + log.info(f"Saving snapshot to local file: {file_prefix}.pickle") + torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle") + except Exception as e: + log.error(f"Failed to capture memory snapshot {e}") + return diff --git a/notebooks/profiling/_memory_viz.py b/notebooks/profiling/_memory_viz.py new file mode 100644 index 00000000..49599362 --- /dev/null +++ b/notebooks/profiling/_memory_viz.py @@ -0,0 +1,739 @@ +# mypy: allow-untyped-defs +import pickle +import sys +import os +import io +import subprocess +import json +from functools import lru_cache +from typing import Any +from itertools import groupby +import base64 +import warnings +import operator + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + + +def _frame_fmt(f, full_filename=False): + i = f["line"] + fname = f["filename"] + if not full_filename: + fname = fname.split("/")[-1] + func = f["name"] + return f"{fname}:{i}:{func}" + + +@cache +def _frame_filter(name, filename): + omit_functions = [ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [ + _frame_fmt(f, full_filename) + for f in frames + if _frame_filter(f["name"], f["filename"]) + ] + + +def _block_extra_legacy(b): + if "history" in b: + frames = b["history"][0].get("frames", []) + real_size = b["history"][0]["real_size"] + else: + real_size = b.get("requested_size", b["size"]) + frames = [] + return frames, real_size + + +def _block_extra(b): + if "frames" not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b["frames"], b["requested_size"] + + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + if flamegraph_script is None: + flamegraph_script = f"/tmp/{os.getuid()}_flamegraph.pl" + if not os.path.exists(flamegraph_script): + import urllib.request + + print(f"Downloading flamegraph.pl to: {flamegraph_script}") + urllib.request.urlretrieve( + "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl", + flamegraph_script, + ) + subprocess.check_call(["chmod", "+x", flamegraph_script]) + args = [flamegraph_script, "--countname", "bytes"] + p = subprocess.Popen( + args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding="utf-8" + ) + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result + + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ";".join(_frames_fmt(frames, reverse=True)) + + for b in blocks: + if "history" not in b: + frames, accounted_for_size = _block_extra(b) + f.write( + f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n' + ) + else: + accounted_for_size = 0 + for h in b["history"]: + sz = h["real_size"] + accounted_for_size += sz + if "frames" in h: + frames = h["frames"] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b["size"] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') + + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg["address"], seg["total_size"]) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f"only_before = {[a for a, _ in (before_segs - after_segs)]}") + print(f"only_after = {[a for a, _ in (after_segs - before_segs)]}") + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f"only_before;{_seg_info(seg)}", seg["blocks"]) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f"only_after;{_seg_info(seg)}", seg["blocks"]) + + return format_flamegraph(f.getvalue()) + + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + + +def calc_active(seg): + return sum(b["size"] for b in seg["blocks"] if b["state"] == "active_allocated") + + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = "" + if total != 0: + pct = (free_internal / total) * 100 + suffix = f" ({pct:.1f}% internal)" + return f"{Bytes(total)}{suffix}" + + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ + +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. + Args: + data: snapshot dictionary created from _snapshot() + """ + segments = [] + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted( + data["segments"], key=lambda x: (x["total_size"], calc_active(x)) + ): + total_reserved += seg["total_size"] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg["blocks"]: + active = b["state"] == "active_allocated" + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b["size"] - allocated_size + else: + seg_free_external += b["size"] + + boffset += b["size"] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg["total_size"] - 1) // PAGE_SIZE + 1 + occupied = [" " for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = start_ + size + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord("a" if active else "A") + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != " ": + occupied[j] = "0123456789*"[int(frac[j] * 10)] + else: + occupied[j] = m + stream = "" if seg["stream"] == 0 else f', stream_{seg["stream"]}' + body = "".join(occupied) + assert ( + seg_free_external + seg_free_internal + seg_allocated == seg["total_size"] + ) + stream = f' stream_{seg["stream"]}' if seg["stream"] != 0 else "" + if seg["total_size"] >= PAGE_SIZE: + out.write( + f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f"{_report_free(seg_free_external, seg_free_internal)} free{stream}\n" + ) + out.write(f'segments: {len(data["segments"])}\n') + out.write(f"total_reserved: {Bytes(total_reserved)}\n") + out.write(f"total_allocated: {Bytes(total_allocated)}\n") + internal_external = ( + f" ({Bytes(free_internal)} internal + {Bytes(free_external)} external)" + if free_internal + else "" + ) + out.write(f"total_free: {_report_free(free_external, free_internal)}\n") + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals: list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names: list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data["segments"]): + saddr = seg["address"] + size = seg["allocated_size"] + if addr >= saddr and addr < saddr + size: + return f"seg_{i}", saddr + return None, None + + count = 0 + out.write(f"{len(entries)} entries\n") + + total_reserved = 0 + for seg in data["segments"]: + total_reserved += seg["total_size"] + + for count, e in enumerate(entries): + if e["action"] == "alloc": + addr, size = e["addr"], e["size"] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - seg_addr + out.write(f"{n} = {seg_name}[{offset}:{Bytes(size)}]\n") + allocation_addr_to_name[addr] = (n, size, count) + count += size + elif e["action"] == "free_requested": + addr, size = e["addr"], e["size"] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"del {name} # {Bytes(size)}\n") + elif e["action"] == "free_completed": + addr, size = e["addr"], e["size"] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"# free completed for {name} {Bytes(size)}\n") + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e["action"] == "segment_alloc": + addr, size = e["addr"], e["size"] + name = _name() + out.write(f"{name} = cudaMalloc({addr}, {Bytes(size)})\n") + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e["action"] == "segment_free": + addr, size = e["addr"], e["size"] + name = segment_addr_to_name.get(addr, addr) + out.write(f"cudaFree({name}) # {Bytes(size)}\n") + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e["action"] == "oom": + size = e["size"] + free = e["device_free"] + out.write( + f"raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n" + ) + else: + out.write(f"{e}\n") + out.write(f"TOTAL MEM: {Bytes(count)}") + + for i, d in enumerate(data["device_traces"]): + if d: + out.write(f"Device {i} ----------------\n") + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn( + "device argument is deprecated, plots now contain all device", + FutureWarning, + stacklevel=3, + ) + buffer = pickle.dumps(data) + buffer += b"\x00" * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode("utf-8") + + json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}]) + return _memory_viz_template.replace("$VIZ_KIND", repr(viz_kind)).replace( + "$SNAPSHOT", json_format + ) + + +def trace_plot(data, device=None, plot_segments=False): + """Generate a visualization over time of the memory usage recorded by the trace as an html file. + + Args: + data: Memory snapshot as generated from torch.cuda.memory._snapshot() + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations. + Defaults to False. + + Returns: + str: HTML of visualization + """ + return _format_viz( + data, + ( + "Active Memory Timeline" + if not plot_segments + else "Active Cached Memory Timeline" + ), + device, + ) + + +def _profile_to_snapshot(profile): + import torch + from torch.profiler._memory_profiler import Action, TensorKey + from torch._C._profiler import _EventType + + memory_profile = profile._memory_profile() + + allocation_stacks = {} + for event in memory_profile._op_tree.sorted_nodes: + if event.tag == _EventType.Allocation: + parent = event.parent + python_parents = [] + while parent: + if parent.tag in (_EventType.PyCall, _EventType.PyCCall): + python_parents.append(parent) + parent = parent.parent + key = TensorKey.from_allocation(event.extra_fields) + + # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor) + # key will be None. I should add some way to identify these, I just haven't yet. + if key and event.extra_fields.alloc_size > 0: + allocation_stacks[key] = python_parents + + device_count = torch.cuda.device_count() + snapshot = { + "device_traces": [[] for _ in range(device_count + 1)], + "segments": [ + { + "device": device, + "address": None, + "total_size": 0, + "stream": 0, + "blocks": [], + } + for device in range(device_count + 1) + ], + } + + def to_device(device): + if device.type == "cuda": + return device.index + else: + return device_count + + def allocate(size, tensor_key, version, during_trace=True): + device = to_device(tensor_key.device) + addr = tensor_key.storage.ptr + + seg = snapshot["segments"][device] # type: ignore[index] + if seg["address"] is None or seg["address"] > addr: + seg["address"] = addr + seg["total_size"] = max( + seg["total_size"], addr + size + ) # record max addr for now, we will make it the size later + category = memory_profile._categories.get(tensor_key, version) + category = category.name.lower() if category is not None else "unknown" + stack = allocation_stacks.get(tensor_key, ()) + stack = [{"filename": "none", "line": 0, "name": p.name} for p in stack] + r = { + "action": "alloc", + "addr": addr, + "size": size, + "stream": 0, + "frames": stack, + "category": category, + } + if during_trace: + snapshot["device_traces"][device].append(r) # type: ignore[index] + return r + + def free(alloc, device): + for e in ("free_requested", "free_completed"): + snapshot["device_traces"][device].append( + { + "action": e, # type: ignore[index] + "addr": alloc["addr"], + "size": alloc["size"], + "stream": 0, + "frames": alloc["frames"], + } + ) + + kv_to_elem = {} + + # create the device trace + for time, action, (tensor_key, version), size in memory_profile.timeline: + if not isinstance(tensor_key, TensorKey): + continue + if action == Action.CREATE: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version) + elif action == Action.DESTROY: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + elif action == Action.INCREMENT_VERSION: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + kv_to_elem[(tensor_key, version + 1)] = allocate( + size, tensor_key, version + 1 + ) + elif action == Action.PREEXISTING: + kv_to_elem[(tensor_key, version)] = allocate( + size, tensor_key, version, during_trace=False + ) + + # create the final snapshot state + blocks_at_end = [ + (to_device(tensor_key.device), event["addr"], event["size"], event["frames"]) + for (tensor_key, version), event in kv_to_elem.items() + ] + for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)): + seg = snapshot["segments"][device] # type: ignore[index] + last_addr = seg["address"] + for _, addr, size, frames in blocks: + if last_addr < addr: + seg["blocks"].append({"size": addr - last_addr, "state": "inactive"}) + seg["blocks"].append( + { + "size": size, + "state": "active_allocated", + "requested_size": size, + "frames": frames, + } + ) + last_addr = addr + size + if last_addr < seg["total_size"]: + seg["blocks"].append( + {"size": seg["total_size"] - last_addr, "state": "inactive"} + ) + + snapshot["segments"] = [seg for seg in snapshot["segments"] if seg["blocks"]] # type: ignore[attr-defined] + for seg in snapshot["segments"]: # type: ignore[attr-defined, name-defined, no-redef] + seg["total_size"] -= seg["address"] + if not seg["blocks"]: + seg["blocks"].append({"size": seg["total_size"], "state": "inactive"}) + + return snapshot + + +def profile_plot(profile, device=None): + """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file. + + Args: + profile: profile as generated by `torch.profiler.profile(profile_memory=True)` + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, "Active Memory Timeline", device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, "Allocator State History", device) + + +if __name__ == "__main__": + import os.path + + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find cuda/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = "torch.cuda.memory._snapshot()" + pickled = f"pickled memory statistics from {fn_name}" + parser = argparse.ArgumentParser( + description=f"Visualize memory dumps produced by {fn_name}" + ) + + subparsers = parser.add_subparsers(dest="action") + + def _output(p): + p.add_argument( + "-o", + "--output", + default="output.svg", + help="flamegraph svg (default: output.svg)", + ) + + description = "Prints overall allocation statistics and a visualization of how the allocators segments are currently filled." + stats_a = subparsers.add_parser("stats", description=description) + stats_a.add_argument("input", help=pickled) + + description = "Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style." + trace_a = subparsers.add_parser("trace", description=description) + trace_a.add_argument("input", help=pickled) + + description = "Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)" + segments_a = subparsers.add_parser("segments", description=description) + segments_a.add_argument("input", help=pickled) + _output(segments_a) + + description = ( + "Generate a flamegraph the program locations contributing to CUDA memory usage." + ) + memory_a = subparsers.add_parser("memory", description=description) + memory_a.add_argument("input", help=pickled) + _output(memory_a) + + description = ( + "Generate a flamegraph that shows segments (aka blocks) that have been added " + "or removed between two different memorys snapshots." + ) + compare_a = subparsers.add_parser("compare", description=description) + compare_a.add_argument("before", help=pickled) + compare_a.add_argument("after", help=pickled) + _output(compare_a) + + plots = ( + ( + "trace_plot", + "Generate a visualization over time of the memory usage recorded by the trace as an html file.", + ), + ( + "segment_plot", + "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.", + ), + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument("input", help=pickled) + help = "visualize trace from this device (default: chooses the only device with trace info or errors)" + trace_plot_a.add_argument("-d", "--device", type=int, default=None, help=help) + help = "path to save the visualization(default: output.html)" + trace_plot_a.add_argument("-o", "--output", default="output.html", help=help) + if cmd == "trace_plot": + help = "visualize change to segments rather than individual allocations" + trace_plot_a.add_argument( + "-s", "--segments", action="store_true", help=help + ) + + args = parser.parse_args() + + def _read(name): + if name == "-": + f = sys.stdin.buffer + else: + f = open(name, "rb") + data = pickle.load(f) + if isinstance(data, list): # segments only... + data = {"segments": data, "traces": []} + return data + + def _write(name, data): + with open(name, "w") as f: + f.write(data) + + if args.action == "segments": + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == "memory": + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == "stats": + data = _read(args.input) + print(segsum(data)) + elif args.action == "trace": + data = _read(args.input) + print(trace(data)) + elif args.action == "compare": + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == "trace_plot": + data = _read(args.input) + _write( + args.output, + trace_plot(data, device=args.device, plot_segments=args.segments), + ) + elif args.action == "segment_plot": + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/notebooks/profiling/plot_GPU_memory_of_each_network.ipynb b/notebooks/profiling/plot_GPU_memory_of_each_network.ipynb new file mode 100644 index 00000000..be8813fb --- /dev/null +++ b/notebooks/profiling/plot_GPU_memory_of_each_network.ipynb @@ -0,0 +1,62 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from modelforge.utils.profiling import measure_performance_for_edge_sizes, plot_computation_time, plot_gpu_memory_usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example usage:\n", + "edge_sizes = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5] # Edge sizes in nanometers\n", + "potential_names = ['painn', 'schnet', 'physnet', 'ani2x', 'aimnet2', 'sake'] \n", + "\n", + "results = measure_performance_for_edge_sizes(\n", + " edge_sizes=edge_sizes,\n", + " potential_names=potential_names,\n", + ")\n", + "\n", + "# Print the results\n", + "for result in results:\n", + " print(f\"Potential: {result['potential_name']}, \"\n", + " f\"Edge Size: {result['edge_size_nm']} nm, \"\n", + " f\"Number of Waters: {result['num_waters']}, \"\n", + " f\"Memory Usage: {result['memory_usage_bytes']/1e6:.2f} MB, \"\n", + " f\"Computation Time: {result['computation_time_s']*1000:.2f} ms\")\n", + "\n", + "# Plot the computation time\n", + "plot_computation_time(results)\n", + "plot_gpu_memory_usage(results)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "modelforge", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/profiling/profile_GPU_memory_of_a_single_network.ipynb b/notebooks/profiling/profile_GPU_memory_of_a_single_network.ipynb new file mode 100644 index 00000000..b4024c70 --- /dev/null +++ b/notebooks/profiling/profile_GPU_memory_of_a_single_network.ipynb @@ -0,0 +1,101 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from modelforge.tests.helper_functions import setup_potential_for_test\n", + "import torch\n", + "from modelforge.utils.profiling import start_record_memory_history, export_memory_snapshot, stop_record_memory_history, setup_waterbox_testsystem" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --------------------------------------------------- #\n", + "# This script demonstrates how to record memory usage #\n", + "# --------------------------------------------------- #\n", + "# define the potential, device and precision\n", + "potential_name = 'SchNet'\n", + "precision = torch.float32\n", + "device = 'cuda'\n", + "\n", + "# setup the input and model\n", + "nnp_input = setup_waterbox_testsystem(2.5, device=device, precision=precision)\n", + "model = setup_potential_for_test(\n", + " potential_name,\n", + " \"inference\",\n", + " potential_seed=42,\n", + " use_training_mode_neighborlist=True,\n", + " simulation_environment='PyTorch',\n", + ").to(device, precision)\n", + "\n", + "# this is the function that will be profiled\n", + "def loop_to_record():\n", + " for _ in range(5):\n", + " # perform the forward pass through each of the models\n", + " r = model(nnp_input)[\"per_molecule_energy\"]\n", + " # Compute the gradient (forces) from the predicted energies\n", + " grad = torch.autograd.grad(\n", + " r,\n", + " nnp_input.positions,\n", + " grad_outputs=torch.ones_like(r),\n", + " create_graph=False,\n", + " retain_graph=False,\n", + " )[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start recording memory snapshot history\n", + "start_record_memory_history()\n", + "loop_to_record()\n", + "# Create the memory snapshot file\n", + "export_memory_snapshot()\n", + "# Stop recording memory snapshot history\n", + "stop_record_memory_history()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# the memory snapshot is a pickle file, to visualize this \n", + "# create a snapshot.html file using the following command\n", + "!python _memory_viz.py trace_plot a7srv5.pch.univie.ac.at_Oct_10_16_45_19.pickle -o snapshot.html" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "modelforge", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}