diff --git a/modelforge/tests/test_profiling.py b/modelforge/tests/test_profiling.py new file mode 100644 index 00000000..cba6bea7 --- /dev/null +++ b/modelforge/tests/test_profiling.py @@ -0,0 +1,56 @@ +import torch +import pytest + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_profiling_function(): + from modelforge.tests.helper_functions import setup_potential_for_test + import torch + from modelforge.utils.profiling import ( + start_record_memory_history, + export_memory_snapshot, + stop_record_memory_history, + setup_waterbox_testsystem, + ) + + # define the potential, device and precision + potential_name = "AimNet2" + precision = torch.float32 + device = "cuda" + + # setup the input and model + nnp_input = setup_waterbox_testsystem(2.5, device=device, precision=precision) + model = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=True, + simulation_environment="PyTorch", + ).to(device, precision) + # Disable gradients for model parameters + for param in model.parameters(): + param.requires_grad = False + # Set model to eval + model.eval() + + # this is the function that will be profiled + def loop_to_record(): + for _ in range(5): + # perform the forward pass through each of the models + r = model(nnp_input)["per_system_energy"] + # Compute the gradient (forces) from the predicted energies + grad = torch.autograd.grad( + r, + nnp_input.positions, + grad_outputs=torch.ones_like(r), + create_graph=False, + retain_graph=False, + )[0] + + # Start recording memory snapshot history + start_record_memory_history() + loop_to_record() + # Create the memory snapshot file + export_memory_snapshot() + # Stop recording memory snapshot history + stop_record_memory_history() diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 1d71eb7e..5013f8ea 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -1263,9 +1263,7 @@ def _log_time(self): epoch_time = time.time() - self.epoch_start_time if isinstance(self.logger, pL.loggers.WandbLogger): # Log epoch duration to W&B - self.logger.experiment.log( - {"epoch_time": epoch_time, "epoch": self.current_epoch} - ) + self.log("train/epoch_time", epoch_time) else: log.warning("Weights & Biases logger not found; epoch time not logged.") diff --git a/modelforge/utils/io.py b/modelforge/utils/io.py index 8f0eab42..489370ce 100644 --- a/modelforge/utils/io.py +++ b/modelforge/utils/io.py @@ -159,6 +159,14 @@ conda install conda-forge::wandb """ +MESSAGES[ + "openmmtools" +] = """ +A batteries-included toolkit for the GPU-accelerated OpenMM molecular simulation engine. +OpenMMTools can be installed via conda: + conda install conda-forge::openmmtools + +""" def import_(module: str): diff --git a/modelforge/utils/profiling.py b/modelforge/utils/profiling.py new file mode 100644 index 00000000..2a9916f5 --- /dev/null +++ b/modelforge/utils/profiling.py @@ -0,0 +1,279 @@ +import torch +from loguru import logger as log +import socket +from datetime import datetime +from modelforge.dataset.dataset import NNPInput + +TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S" +MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 + + +def setup_waterbox_testsystem( + edge_size_in_nm: float, + device: torch.device, + precision: torch.dtype, +) -> NNPInput: + from modelforge.utils.io import import_ + + openmmtools = import_("openmmtools") + from simtk import unit + from modelforge.dataset.dataset import NNPInput + + test_system = openmmtools.testsystems.WaterBox( + box_edge=edge_size_in_nm * unit.nanometer + ) + positions = test_system.positions # Positions in nanometers + topology = test_system.topology + + # Extract atomic numbers and residue indices + atomic_numbers = [] + residue_indices = [] + for residue_index, residue in enumerate(topology.residues()): + for atom in residue.atoms(): + atomic_numbers.append(atom.element.atomic_number) + residue_indices.append(residue_index) + num_waters = len(list(topology.residues())) + positions_in_nanometers = positions.value_in_unit(unit.nanometer) + + # Convert to torch tensors and move to GPU + torch_atomic_numbers = torch.tensor(atomic_numbers, dtype=torch.long, device=device) + torch_positions = torch.tensor( + positions_in_nanometers, dtype=torch.float32, device=device, requires_grad=True + ) + torch_atomic_subsystem_indices = torch.zeros_like( + torch_atomic_numbers, dtype=torch.long, device=device + ) + torch_total_charge = torch.zeros((1, 1), dtype=torch.float32, device=device) + + log.info(f"Waterbox system setup with {num_waters} waters") + return NNPInput( + atomic_numbers=torch_atomic_numbers, + positions=torch_positions, + atomic_subsystem_indices=torch_atomic_subsystem_indices, + per_system_total_charge=torch_total_charge, + ).to_dtype(dtype=precision) + + +from typing import List +import time + + +def measure_performance_for_edge_sizes( + edge_sizes: List[float], + potential_names: List[str], +): + """ + Measures GPU memory utilization and computation time for force calculations + for water boxes of different edge sizes across multiple potentials. + Parameters + ---------- + edge_sizes : List[float] + A list of edge sizes (in nanometers) for the water boxes. + potential_names : List[str] + A list of potential names to use in the model setup. + Returns + ------- + List[dict] + A list of dictionaries containing edge size, number of water molecules, + potential name, memory usage in bytes, and computation time in seconds. + """ + results = [] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + precicion = torch.float32 + for potential_name in potential_names: + for edge_size in edge_sizes: + + nnp_input = setup_waterbox_testsystem( + edge_size, + device, + precicion, + ) + + # Import your model setup function + from modelforge.tests.helper_functions import setup_potential_for_test + + # Setup model + model = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=False, + simulation_environment="PyTorch", + ) + + model.to(device) + model.to(precicion) + total_params = sum(p.numel() for p in model.parameters()) + + # Measure GPU memory usage and computation time + torch.cuda.reset_peak_memory_stats(device=device) + torch.cuda.synchronize() + + # Run forward pass and time it + start_time = time.perf_counter() + try: + output = model(nnp_input.as_namedtuple())["per_molecule_energy"] + except: + print("Out of memory error during forward pass") + continue + + try: + F_training = -torch.autograd.grad( + output.sum(), + nnp_input.positions, + create_graph=False, + retain_graph=False, + )[0] + except: + print("Out of memory error during backward pass") + continue + torch.cuda.synchronize() + end_time = time.perf_counter() + + max_memory_allocated = torch.cuda.max_memory_allocated(device=device) + computation_time = end_time - start_time + + results.append( + { + "potential_name": f"{potential_name}: {total_params:.1e} params", + "edge_size_nm": edge_size, + "num_waters": num_waters, + "memory_usage_bytes": max_memory_allocated, + "computation_time_s": computation_time, + } + ) + + # Clean up + del ( + nnp_input, + output, + model, + ) + try: + del F_training + except: + pass + torch.cuda.empty_cache() + time.sleep(1) # Sleep for a second to allow GPU memory to be freed + + return results + + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + + +def plot_computation_time(results): + """ + Plots computation time against the number of water molecules for multiple potentials. + Parameters + ---------- + results : List[dict] + A list of dictionaries containing edge size, number of water molecules, + potential name, memory usage in bytes, and computation time in seconds. + """ + # Create a DataFrame for plotting + df = pd.DataFrame(results) + df["computation_time_ms"] = ( + df["computation_time_s"] * 1000 + ) # Convert seconds to milliseconds + + # Plot using seaborn + sns.set(style="whitegrid") + plt.figure(figsize=(10, 6)) + sns.lineplot( + data=df, + x="num_waters", + y="computation_time_ms", + hue="potential_name", + units="potential_name", + estimator=None, # Do not aggregate data + marker="o", + linewidth=2, + markersize=8, + ) + plt.title("Computation Time vs Number of Water Molecules for Different Potentials") + plt.xlabel("Number of Water Molecules") + plt.ylabel("Computation Time (ms)") + plt.xticks(sorted(df["num_waters"].unique())) + plt.legend(title="Potential Name") + plt.tight_layout() + plt.show() + + +def plot_gpu_memory_usage(results): + """ + Plots GPU memory usage against the number of water molecules for multiple potentials. + Parameters + ---------- + results : List[dict] + A list of dictionaries containing edge size, number of water molecules, + potential name, and memory usage in bytes. + """ + # Create a DataFrame for plotting + df = pd.DataFrame(results) + df["memory_usage_mb"] = df["memory_usage_bytes"] / 1e6 # Convert bytes to megabytes + + # Plot using seaborn + sns.set(style="whitegrid") + plt.figure(figsize=(10, 6)) + sns.lineplot( + data=df, + x="num_waters", + y="memory_usage_mb", + units="potential_name", + estimator=None, # Do not aggregate data + hue="potential_name", + marker="o", + linewidth=2, + markersize=8, + ) + plt.title( + "Backward pass: GPU Memory Usage vs Number of Water Molecules for Different Potentials" + ) + plt.xlabel("Number of Water Molecules") + plt.ylabel("GPU Memory Usage (MB)") + plt.xticks(sorted(df["num_waters"].unique())) + plt.legend(title="Potential Name") + plt.tight_layout() + plt.show() + + +def start_record_memory_history() -> None: + if not torch.cuda.is_available(): + log.info("CUDA unavailable. Not recording memory history") + return + + log.info("Starting snapshot record_memory_history") + torch.cuda.memory._record_memory_history( + max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT + ) + + +def stop_record_memory_history() -> None: + if not torch.cuda.is_available(): + log.info("CUDA unavailable. Not stopping memory history") + return + + log.info("Stopping snapshot record_memory_history") + torch.cuda.memory._record_memory_history(enabled=None) + + +def export_memory_snapshot() -> None: + if not torch.cuda.is_available(): + log.info("CUDA unavailable. Not exporting memory snapshot") + return + + # Prefix for file names. + host_name = socket.gethostname() + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = f"{host_name}_{timestamp}" + + try: + log.info(f"Saving snapshot to local file: {file_prefix}.pickle") + torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle") + except Exception as e: + log.error(f"Failed to capture memory snapshot {e}") + return + return f"{file_prefix}.pickle" diff --git a/notebooks/profiling/_memory_viz.py b/notebooks/profiling/_memory_viz.py new file mode 100644 index 00000000..a961fed6 --- /dev/null +++ b/notebooks/profiling/_memory_viz.py @@ -0,0 +1,732 @@ +import pickle +import sys +import os +import io +import subprocess +import json +from functools import lru_cache +from typing import Any +from itertools import groupby +import base64 +import warnings +import operator + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + + +def _frame_fmt(f, full_filename=False): + i = f["line"] + fname = f["filename"] + if not full_filename: + fname = fname.split("/")[-1] + func = f["name"] + return f"{fname}:{i}:{func}" + + +@cache +def _frame_filter(name, filename): + omit_functions = [ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [ + _frame_fmt(f, full_filename) + for f in frames + if _frame_filter(f["name"], f["filename"]) + ] + + +def _block_extra_legacy(b): + if "history" in b: + frames = b["history"][0].get("frames", []) + real_size = b["history"][0]["real_size"] + else: + real_size = b.get("requested_size", b["size"]) + frames = [] + return frames, real_size + + +def _block_extra(b): + if "frames" not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b["frames"], b["requested_size"] + + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + if flamegraph_script is None: + flamegraph_script = f"/tmp/{os.getuid()}_flamegraph.pl" + if not os.path.exists(flamegraph_script): + import urllib.request + + print(f"Downloading flamegraph.pl to: {flamegraph_script}") + urllib.request.urlretrieve( + "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl", + flamegraph_script, + ) + subprocess.check_call(["chmod", "+x", flamegraph_script]) + args = [flamegraph_script, "--countname", "bytes"] + p = subprocess.Popen( + args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding="utf-8" + ) + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result + + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ";".join(_frames_fmt(frames, reverse=True)) + + for b in blocks: + if "history" not in b: + frames, accounted_for_size = _block_extra(b) + f.write( + f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n' + ) + else: + accounted_for_size = 0 + for h in b["history"]: + sz = h["real_size"] + accounted_for_size += sz + if "frames" in h: + frames = h["frames"] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b["size"] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') + + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot["segments"]: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg["blocks"]) + return format_flamegraph(f.getvalue()) + + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg["address"], seg["total_size"]) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f"only_before = {[a for a, _ in (before_segs - after_segs)]}") + print(f"only_after = {[a for a, _ in (after_segs - before_segs)]}") + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f"only_before;{_seg_info(seg)}", seg["blocks"]) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f"only_after;{_seg_info(seg)}", seg["blocks"]) + + return format_flamegraph(f.getvalue()) + + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + + +def calc_active(seg): + return sum(b["size"] for b in seg["blocks"] if b["state"] == "active_allocated") + + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = "" + if total != 0: + pct = (free_internal / total) * 100 + suffix = f" ({pct:.1f}% internal)" + return f"{Bytes(total)}{suffix}" + + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. + Args: + data: snapshot dictionary created from _snapshot() + """ + segments = [] + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted( + data["segments"], key=lambda x: (x["total_size"], calc_active(x)) + ): + total_reserved += seg["total_size"] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg["blocks"]: + active = b["state"] == "active_allocated" + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b["size"] - allocated_size + else: + seg_free_external += b["size"] + + boffset += b["size"] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg["total_size"] - 1) // PAGE_SIZE + 1 + occupied = [" " for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = start_ + size + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord("a" if active else "A") + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != " ": + occupied[j] = "0123456789*"[int(frac[j] * 10)] + else: + occupied[j] = m + stream = "" if seg["stream"] == 0 else f', stream_{seg["stream"]}' + body = "".join(occupied) + assert ( + seg_free_external + seg_free_internal + seg_allocated == seg["total_size"] + ) + stream = f' stream_{seg["stream"]}' if seg["stream"] != 0 else "" + if seg["total_size"] >= PAGE_SIZE: + out.write( + f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f"{_report_free(seg_free_external, seg_free_internal)} free{stream}\n" + ) + out.write(f'segments: {len(data["segments"])}\n') + out.write(f"total_reserved: {Bytes(total_reserved)}\n") + out.write(f"total_allocated: {Bytes(total_allocated)}\n") + internal_external = ( + f" ({Bytes(free_internal)} internal + {Bytes(free_external)} external)" + if free_internal + else "" + ) + out.write(f"total_free: {_report_free(free_external, free_internal)}\n") + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals: list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names: list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data["segments"]): + saddr = seg["address"] + size = seg["allocated_size"] + if addr >= saddr and addr < saddr + size: + return f"seg_{i}", saddr + return None, None + + count = 0 + out.write(f"{len(entries)} entries\n") + + total_reserved = 0 + for seg in data["segments"]: + total_reserved += seg["total_size"] + + for count, e in enumerate(entries): + if e["action"] == "alloc": + addr, size = e["addr"], e["size"] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - seg_addr + out.write(f"{n} = {seg_name}[{offset}:{Bytes(size)}]\n") + allocation_addr_to_name[addr] = (n, size, count) + count += size + elif e["action"] == "free_requested": + addr, size = e["addr"], e["size"] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"del {name} # {Bytes(size)}\n") + elif e["action"] == "free_completed": + addr, size = e["addr"], e["size"] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f"# free completed for {name} {Bytes(size)}\n") + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e["action"] == "segment_alloc": + addr, size = e["addr"], e["size"] + name = _name() + out.write(f"{name} = cudaMalloc({addr}, {Bytes(size)})\n") + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e["action"] == "segment_free": + addr, size = e["addr"], e["size"] + name = segment_addr_to_name.get(addr, addr) + out.write(f"cudaFree({name}) # {Bytes(size)}\n") + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e["action"] == "oom": + size = e["size"] + free = e["device_free"] + out.write( + f"raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n" + ) + else: + out.write(f"{e}\n") + out.write(f"TOTAL MEM: {Bytes(count)}") + + for i, d in enumerate(data["device_traces"]): + if d: + out.write(f"Device {i} ----------------\n") + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn( + "device argument is deprecated, plots now contain all device", + FutureWarning, + stacklevel=3, + ) + buffer = pickle.dumps(data) + buffer += b"\x00" * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode("utf-8") + + json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}]) + return _memory_viz_template.replace("$VIZ_KIND", repr(viz_kind)).replace( + "$SNAPSHOT", json_format + ) + + +def trace_plot(data, device=None, plot_segments=False): + """Generate a visualization over time of the memory usage recorded by the trace as an html file. + Args: + data: Memory snapshot as generated from torch.cuda.memory._snapshot() + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations. + Defaults to False. + Returns: + str: HTML of visualization + """ + return _format_viz( + data, + ( + "Active Memory Timeline" + if not plot_segments + else "Active Cached Memory Timeline" + ), + device, + ) + + +def _profile_to_snapshot(profile): + import torch + from torch.profiler._memory_profiler import Action, TensorKey + from torch._C._profiler import _EventType + + memory_profile = profile._memory_profile() + + allocation_stacks = {} + for event in memory_profile._op_tree.sorted_nodes: + if event.tag == _EventType.Allocation: + parent = event.parent + python_parents = [] + while parent: + if parent.tag in (_EventType.PyCall, _EventType.PyCCall): + python_parents.append(parent) + parent = parent.parent + key = TensorKey.from_allocation(event.extra_fields) + + # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor) + # key will be None. I should add some way to identify these, I just haven't yet. + if key and event.extra_fields.alloc_size > 0: + allocation_stacks[key] = python_parents + + device_count = torch.cuda.device_count() + snapshot = { + "device_traces": [[] for _ in range(device_count + 1)], + "segments": [ + { + "device": device, + "address": None, + "total_size": 0, + "stream": 0, + "blocks": [], + } + for device in range(device_count + 1) + ], + } + + def to_device(device): + if device.type == "cuda": + return device.index + else: + return device_count + + def allocate(size, tensor_key, version, during_trace=True): + device = to_device(tensor_key.device) + addr = tensor_key.storage.ptr + + seg = snapshot["segments"][device] # type: ignore[index] + if seg["address"] is None or seg["address"] > addr: + seg["address"] = addr + seg["total_size"] = max( + seg["total_size"], addr + size + ) # record max addr for now, we will make it the size later + category = memory_profile._categories.get(tensor_key, version) + category = category.name.lower() if category is not None else "unknown" + stack = allocation_stacks.get(tensor_key, ()) + stack = [{"filename": "none", "line": 0, "name": p.name} for p in stack] + r = { + "action": "alloc", + "addr": addr, + "size": size, + "stream": 0, + "frames": stack, + "category": category, + } + if during_trace: + snapshot["device_traces"][device].append(r) # type: ignore[index] + return r + + def free(alloc, device): + for e in ("free_requested", "free_completed"): + snapshot["device_traces"][device].append( + { + "action": e, # type: ignore[index] + "addr": alloc["addr"], + "size": alloc["size"], + "stream": 0, + "frames": alloc["frames"], + } + ) + + kv_to_elem = {} + + # create the device trace + for time, action, (tensor_key, version), size in memory_profile.timeline: + if not isinstance(tensor_key, TensorKey): + continue + if action == Action.CREATE: + kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version) + elif action == Action.DESTROY: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + elif action == Action.INCREMENT_VERSION: + free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device)) + kv_to_elem[(tensor_key, version + 1)] = allocate( + size, tensor_key, version + 1 + ) + elif action == Action.PREEXISTING: + kv_to_elem[(tensor_key, version)] = allocate( + size, tensor_key, version, during_trace=False + ) + + # create the final snapshot state + blocks_at_end = [ + (to_device(tensor_key.device), event["addr"], event["size"], event["frames"]) + for (tensor_key, version), event in kv_to_elem.items() + ] + for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)): + seg = snapshot["segments"][device] # type: ignore[index] + last_addr = seg["address"] + for _, addr, size, frames in blocks: + if last_addr < addr: + seg["blocks"].append({"size": addr - last_addr, "state": "inactive"}) + seg["blocks"].append( + { + "size": size, + "state": "active_allocated", + "requested_size": size, + "frames": frames, + } + ) + last_addr = addr + size + if last_addr < seg["total_size"]: + seg["blocks"].append( + {"size": seg["total_size"] - last_addr, "state": "inactive"} + ) + + snapshot["segments"] = [seg for seg in snapshot["segments"] if seg["blocks"]] # type: ignore[attr-defined] + for seg in snapshot["segments"]: # type: ignore[attr-defined, name-defined, no-redef] + seg["total_size"] -= seg["address"] + if not seg["blocks"]: + seg["blocks"].append({"size": seg["total_size"], "state": "inactive"}) + + return snapshot + + +def profile_plot(profile, device=None): + """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file. + Args: + profile: profile as generated by `torch.profiler.profile(profile_memory=True)` + device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations. + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, "Active Memory Timeline", device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, "Allocator State History", device) + + +if __name__ == "__main__": + import os.path + + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find cuda/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = "torch.cuda.memory._snapshot()" + pickled = f"pickled memory statistics from {fn_name}" + parser = argparse.ArgumentParser( + description=f"Visualize memory dumps produced by {fn_name}" + ) + + subparsers = parser.add_subparsers(dest="action") + + def _output(p): + p.add_argument( + "-o", + "--output", + default="output.svg", + help="flamegraph svg (default: output.svg)", + ) + + description = "Prints overall allocation statistics and a visualization of how the allocators segments are currently filled." + stats_a = subparsers.add_parser("stats", description=description) + stats_a.add_argument("input", help=pickled) + + description = "Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style." + trace_a = subparsers.add_parser("trace", description=description) + trace_a.add_argument("input", help=pickled) + + description = "Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)" + segments_a = subparsers.add_parser("segments", description=description) + segments_a.add_argument("input", help=pickled) + _output(segments_a) + + description = ( + "Generate a flamegraph the program locations contributing to CUDA memory usage." + ) + memory_a = subparsers.add_parser("memory", description=description) + memory_a.add_argument("input", help=pickled) + _output(memory_a) + + description = ( + "Generate a flamegraph that shows segments (aka blocks) that have been added " + "or removed between two different memorys snapshots." + ) + compare_a = subparsers.add_parser("compare", description=description) + compare_a.add_argument("before", help=pickled) + compare_a.add_argument("after", help=pickled) + _output(compare_a) + + plots = ( + ( + "trace_plot", + "Generate a visualization over time of the memory usage recorded by the trace as an html file.", + ), + ( + "segment_plot", + "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.", + ), + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument("input", help=pickled) + help = "visualize trace from this device (default: chooses the only device with trace info or errors)" + trace_plot_a.add_argument("-d", "--device", type=int, default=None, help=help) + help = "path to save the visualization(default: output.html)" + trace_plot_a.add_argument("-o", "--output", default="output.html", help=help) + if cmd == "trace_plot": + help = "visualize change to segments rather than individual allocations" + trace_plot_a.add_argument( + "-s", "--segments", action="store_true", help=help + ) + + args = parser.parse_args() + + def _read(name): + if name == "-": + f = sys.stdin.buffer + else: + f = open(name, "rb") + data = pickle.load(f) + if isinstance(data, list): # segments only... + data = {"segments": data, "traces": []} + return data + + def _write(name, data): + with open(name, "w") as f: + f.write(data) + + if args.action == "segments": + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == "memory": + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == "stats": + data = _read(args.input) + print(segsum(data)) + elif args.action == "trace": + data = _read(args.input) + print(trace(data)) + elif args.action == "compare": + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == "trace_plot": + data = _read(args.input) + _write( + args.output, + trace_plot(data, device=args.device, plot_segments=args.segments), + ) + elif args.action == "segment_plot": + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/notebooks/profiling/profile_GPU_memory_of_a_single_network.ipynb b/notebooks/profiling/profile_GPU_memory_of_a_single_network.ipynb new file mode 100644 index 00000000..d4c4c8b4 --- /dev/null +++ b/notebooks/profiling/profile_GPU_memory_of_a_single_network.ipynb @@ -0,0 +1,112 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from modelforge.tests.helper_functions import setup_potential_for_test\n", + "import torch\n", + "from modelforge.utils.profiling import start_record_memory_history, export_memory_snapshot, stop_record_memory_history, setup_waterbox_testsystem\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --------------------------------------------------- #\n", + "# This script demonstrates how to record memory usage #\n", + "# --------------------------------------------------- #\n", + "# define the potential, device and precision\n", + "potential_name = 'AimNet2'\n", + "precision = torch.float32\n", + "device = 'cuda'\n", + "\n", + "# setup the input and model\n", + "nnp_input = setup_waterbox_testsystem(2.5, device=device, precision=precision)\n", + "model = setup_potential_for_test(\n", + " potential_name,\n", + " \"training\",\n", + " potential_seed=42,\n", + " use_training_mode_neighborlist=True,\n", + " simulation_environment='PyTorch',\n", + ")['trainer']\n", + "\n", + "# this is the function that will be profiled\n", + "def loop_to_record():\n", + " for _ in range(5):\n", + " # perform the forward pass through each of the models\n", + " r = model(nnp_input)[\"per_system_energy\"]\n", + " # Compute the gradient (forces) from the predicted energies\n", + " grad = torch.autograd.grad(\n", + " r,\n", + " nnp_input.positions,\n", + " grad_outputs=torch.ones_like(r),\n", + " create_graph=False,\n", + " retain_graph=False,\n", + " )[0]\n", + "\n", + "def loop_to_record():\n", + " model.train_potential()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start recording memory snapshot history\n", + "start_record_memory_history()\n", + "loop_to_record()\n", + "# Create the memory snapshot file\n", + "file_name = export_memory_snapshot()\n", + "print(file_name)\n", + "# Stop recording memory snapshot history\n", + "stop_record_memory_history()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# the memory snapshot is a pickle file, to visualize this \n", + "# create a snapshot.html file using the following command\n", + "!python _memory_viz.py trace_plot a7srv5.pch.univie.ac.at_Nov_09_21_18_29.pickle -o snapshot.html" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "modelforge", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}