Skip to content
Open
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 139 additions & 9 deletions modelopt/onnx/quantization/autotune/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@
import importlib.util
import os
import re
import shlex
import shutil
import subprocess # nosec B404
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse

import numpy as np
import torch
Expand Down Expand Up @@ -145,6 +147,13 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None:
self.logger.warning(f"Failed to save logs to {file}: {e}")


safe_pattern = (
r"\[\d{2}/\d{2}/\d{4}-\d{2}:\d{2}:\d{2}\]\s+\[I\]\s+"
r"Average over \d+ runs - GPU latency:\s*([\d.]+)\s*ms"
)
std_pattern = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms"


class TrtExecBenchmark(Benchmark):
"""TensorRT benchmark using trtexec command-line tool.

Expand Down Expand Up @@ -183,7 +192,6 @@ def __init__(
self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx")
self.logger.debug(f"Created temporary engine directory: {self.temp_dir}")
self.logger.debug(f"Temporary model path: {self.temp_model_path}")
self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms"

self._base_cmd = [
self.trtexec_path,
Expand All @@ -204,9 +212,66 @@ def __init__(
self.logger.debug(f"Added plugin library: {plugin_path}")

trtexec_args = self.trtexec_args or []
has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)

if has_remote_config:
self.has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)
self.remote_ip: str | None = None
self.remote_port: int = 22
self.remote_user: str = "root"
self.remote_password: str = ""
self.remote_engine_path: str = "trtexec_benchmark_model.trt"
self.remote_bin_path: str = "trtexec"
self.remote_timeout_sec = 300

if self.has_remote_config:
remote_config = [arg for arg in trtexec_args if "--remoteAutoTuningConfig" in arg]
if len(remote_config) != 1:
raise ValueError("Exactly one --remoteAutoTuningConfig argument is required")
# Parse --remoteAutoTuningConfig argument, which may be given as:
# ('--remoteAutoTuningConfig=ssh://user:pass@host:port?...') or
# ('--remoteAutoTuningConfig', 'ssh://user:pass@host:port?...')
#
# The logic: find the arg starting with '--remoteAutoTuningConfig'
# If formatted as '--remoteAutoTuningConfig=...', split off the '='
# Otherwise, grab the next argument.
config_arg_value: str | None = None
for i, arg in enumerate(trtexec_args):
if arg.startswith("--remoteAutoTuningConfig"):
if arg == "--remoteAutoTuningConfig":
# Value should be the next argument
if i + 1 < len(trtexec_args):
config_arg_value = trtexec_args[i + 1]
else:
raise ValueError("Missing value for --remoteAutoTuningConfig")
elif arg.startswith("--remoteAutoTuningConfig="):
config_arg_value = arg.split("=", 1)[1]
else:
raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {arg}")
break
if not config_arg_value:
raise ValueError("Could not parse --remoteAutoTuningConfig argument")
remote_config_str: str = config_arg_value

if not remote_config_str.startswith("ssh://"):
raise ValueError("Only 'ssh://' remote autotuning config URLs are supported")
parsed = urlparse(remote_config_str)
# parsed.username, parsed.password, parsed.hostname, parsed.port, parsed.query
self.remote_user = parsed.username
self.remote_password = parsed.password
self.remote_ip = parsed.hostname
self.remote_port = parsed.port
if self.remote_port is None:
self.remote_port = 22
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Parse query options into a dict
self.remote_options = {
k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed.query).items()
}
required_params = ["remote_exec_path", "remote_lib_path"]
missing = [p for p in required_params if p not in self.remote_options]
if missing:
raise ValueError(
f"Missing required query parameters in --remoteAutoTuningConfig: {missing}"
)
self.remote_bin_path = os.path.dirname(str(self.remote_options["remote_exec_path"]))
Comment thread
coderabbitai[bot] marked this conversation as resolved.
self.remote_lib_path = str(self.remote_options["remote_lib_path"])
try:
_check_for_trtexec(min_version="10.15")
self.logger.debug("TensorRT Python API version >= 10.15 detected")
Expand All @@ -215,6 +280,7 @@ def __init__(
"Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments."
)
self.trtexec_args.append("--safe")
self.is_safe = True
if "--skipInference" not in trtexec_args:
self.logger.warning(
"Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments."
Expand All @@ -228,6 +294,7 @@ def __init__(
trtexec_args = [
arg for arg in trtexec_args if "--remoteAutoTuningConfig" not in arg
]
self.is_safe = "--safe" in trtexec_args
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
self._base_cmd.extend(trtexec_args)

self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}")
Expand Down Expand Up @@ -269,7 +336,9 @@ def run(

cmd = [*self._base_cmd, f"--onnx={model_path}"]
self.logger.debug(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True) # nosec B603
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=self.remote_timeout_sec
) # nosec B603
self._write_log_file(
log_file,
"\n".join(
Expand All @@ -292,10 +361,71 @@ def run(
self.logger.error(f"trtexec failed with return code {result.returncode}")
self.logger.error(f"stderr: {result.stderr}")
return float("inf")

if not (match := re.search(self.latency_pattern, result.stdout, re.IGNORECASE)):
self.logger.warning("Could not parse median latency from trtexec output")
self.logger.debug(f"trtexec stdout:\n{result.stdout}")
latency_pattern = std_pattern
if self.has_remote_config and self.is_safe:
ssh_pass = []
if self.remote_password:
ssh_pass.append("sshpass")
ssh_pass.append("-p")
ssh_pass.append(self.remote_password)
# need to push the model to the device and use trtexec_safe to run
scp_cmd = [
"scp",
f"-P{self.remote_port}",
self.engine_path,
f"{self.remote_user}@{self.remote_ip}:{shlex.quote(self.remote_engine_path)}",
]
scp_cmd = ssh_pass + scp_cmd
result = subprocess.run(
scp_cmd, capture_output=True, text=True, timeout=self.remote_timeout_sec
) # nosec B603
if result.returncode != 0:
self.logger.error(f"Failed to push engine to remote device: {result.stderr}")
return float("inf")
ld_path = f"LD_LIBRARY_PATH={shlex.quote(self.remote_lib_path)}:$LD_LIBRARY_PATH"
trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec_safe')}"
trtexec_safe_cmd = [
"ssh",
"-p",
f"{self.remote_port}",
f"{self.remote_user}@{self.remote_ip}",
f"{ld_path} {shlex.quote(trt_path)} --loadEngine={shlex.quote(self.remote_engine_path)}",
]
trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd
result = subprocess.run(
trtexec_safe_cmd,
capture_output=True,
text=True,
timeout=self.remote_timeout_sec,
) # nosec B603
latency_pattern = safe_pattern
if result.returncode != 0:
# fallback and try trtexec with "--safe"
trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec')}"
trtexec_safe_cmd = [
"ssh",
"-p",
f"{self.remote_port}",
f"{self.remote_user}@{self.remote_ip}",
f"{ld_path} {shlex.quote(trt_path)} --safe --loadEngine={shlex.quote(self.remote_engine_path)}",
]
trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd

result = subprocess.run(
trtexec_safe_cmd,
capture_output=True,
text=True,
timeout=self.remote_timeout_sec,
) # nosec B603
latency_pattern = std_pattern
if result.returncode != 0:
self.logger.error(
f"Failed to run trtexec_safe or trtexec with '--safe'\n {result.stdout}"
)
return float("inf")
if not (match := re.search(latency_pattern, result.stdout, re.IGNORECASE)):
self.logger.warning(f"trtexec stdout:\n{result.stdout}")
self.logger.error("Could not parse median latency from trtexec output")
return float("inf")
latency = float(match.group(1))
self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms")
Expand Down