-
Notifications
You must be signed in to change notification settings - Fork 382
Use trtexec_safe on safety platforms when using remoteAutoTuning #1378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
d816d70
19a4961
e45b08e
bb04422
8b779b3
d95b571
14dc293
144df91
be947cf
28e4e8e
3660092
2f75b5d
2512076
110f3d7
8a292fe
547c2ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,13 +30,15 @@ | |
| import importlib.util | ||
| import os | ||
| import re | ||
| import shlex | ||
| import shutil | ||
| import subprocess # nosec B404 | ||
| import tempfile | ||
| import time | ||
| from abc import ABC, abstractmethod | ||
| from pathlib import Path | ||
| from typing import Any | ||
| from urllib.parse import parse_qs, urlparse | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
@@ -145,6 +147,13 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None: | |
| self.logger.warning(f"Failed to save logs to {file}: {e}") | ||
|
|
||
|
|
||
| safe_pattern = ( | ||
| r"\[\d{2}/\d{2}/\d{4}-\d{2}:\d{2}:\d{2}\]\s+\[I\]\s+" | ||
| r"Average over \d+ runs - GPU latency:\s*([\d.]+)\s*ms" | ||
| ) | ||
| std_pattern = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms" | ||
|
|
||
|
|
||
| class TrtExecBenchmark(Benchmark): | ||
| """TensorRT benchmark using trtexec command-line tool. | ||
|
|
||
|
|
@@ -183,7 +192,6 @@ def __init__( | |
| self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx") | ||
| self.logger.debug(f"Created temporary engine directory: {self.temp_dir}") | ||
| self.logger.debug(f"Temporary model path: {self.temp_model_path}") | ||
| self.latency_pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms" | ||
|
|
||
| self._base_cmd = [ | ||
| self.trtexec_path, | ||
|
|
@@ -204,9 +212,65 @@ def __init__( | |
| self.logger.debug(f"Added plugin library: {plugin_path}") | ||
|
|
||
| trtexec_args = self.trtexec_args or [] | ||
| has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args) | ||
|
|
||
| if has_remote_config: | ||
| self.has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args) | ||
| self.remote_ip: str | None = None | ||
| self.remote_port: int = 22 | ||
| self.remote_user: str = "root" | ||
| self.remote_password: str = "" | ||
| self.remote_engine_path: str = "trtexec_benchmark_model.trt" | ||
| self.remote_bin_path: str = "trtexec" | ||
|
|
||
| if self.has_remote_config: | ||
| remote_config = [arg for arg in trtexec_args if "--remoteAutoTuningConfig" in arg] | ||
| if len(remote_config) != 1: | ||
| raise ValueError("Exactly one --remoteAutoTuningConfig argument is required") | ||
| # Parse --remoteAutoTuningConfig argument, which may be given as: | ||
| # ('--remoteAutoTuningConfig=ssh://user:pass@host:port?...') or | ||
| # ('--remoteAutoTuningConfig', 'ssh://user:pass@host:port?...') | ||
| # | ||
| # The logic: find the arg starting with '--remoteAutoTuningConfig' | ||
| # If formatted as '--remoteAutoTuningConfig=...', split off the '=' | ||
| # Otherwise, grab the next argument. | ||
| config_arg_value: str | None = None | ||
| for i, arg in enumerate(trtexec_args): | ||
| if arg.startswith("--remoteAutoTuningConfig"): | ||
| if arg == "--remoteAutoTuningConfig": | ||
| # Value should be the next argument | ||
| if i + 1 < len(trtexec_args): | ||
| config_arg_value = trtexec_args[i + 1] | ||
| else: | ||
| raise ValueError("Missing value for --remoteAutoTuningConfig") | ||
| elif arg.startswith("--remoteAutoTuningConfig="): | ||
| config_arg_value = arg.split("=", 1)[1] | ||
| else: | ||
| raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {arg}") | ||
| break | ||
| if not config_arg_value: | ||
| raise ValueError("Could not parse --remoteAutoTuningConfig argument") | ||
| remote_config_str: str = config_arg_value | ||
|
|
||
| if not remote_config_str.startswith("ssh://"): | ||
| raise ValueError("Only 'ssh://' remote autotuning config URLs are supported") | ||
| parsed = urlparse(remote_config_str) | ||
| # parsed.username, parsed.password, parsed.hostname, parsed.port, parsed.query | ||
| self.remote_user = parsed.username | ||
| self.remote_password = parsed.password | ||
| self.remote_ip = parsed.hostname | ||
| self.remote_port = parsed.port | ||
| if self.remote_port is None: | ||
| self.remote_port = 22 | ||
| # Parse query options into a dict | ||
| self.remote_options = { | ||
| k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed.query).items() | ||
| } | ||
| required_params = ["remote_exec_path", "remote_lib_path"] | ||
| missing = [p for p in required_params if p not in self.remote_options] | ||
| if missing: | ||
| raise ValueError( | ||
| f"Missing required query parameters in --remoteAutoTuningConfig: {missing}" | ||
| ) | ||
| self.remote_bin_path = os.path.dirname(str(self.remote_options["remote_exec_path"])) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| self.remote_lib_path = str(self.remote_options["remote_lib_path"]) | ||
| try: | ||
| _check_for_trtexec(min_version="10.15") | ||
| self.logger.debug("TensorRT Python API version >= 10.15 detected") | ||
|
|
@@ -215,6 +279,7 @@ def __init__( | |
| "Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments." | ||
| ) | ||
| self.trtexec_args.append("--safe") | ||
| self.is_safe = True | ||
| if "--skipInference" not in trtexec_args: | ||
| self.logger.warning( | ||
| "Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments." | ||
|
|
@@ -228,6 +293,7 @@ def __init__( | |
| trtexec_args = [ | ||
| arg for arg in trtexec_args if "--remoteAutoTuningConfig" not in arg | ||
| ] | ||
| self.is_safe = "--safe" in trtexec_args | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
| self._base_cmd.extend(trtexec_args) | ||
|
|
||
| self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}") | ||
|
|
@@ -292,10 +358,59 @@ def run( | |
| self.logger.error(f"trtexec failed with return code {result.returncode}") | ||
| self.logger.error(f"stderr: {result.stderr}") | ||
| return float("inf") | ||
| latency_pattern = std_pattern | ||
| if self.has_remote_config and self.is_safe: | ||
| ssh_pass = [] | ||
| if self.remote_password: | ||
| ssh_pass.append("sshpass") | ||
| ssh_pass.append("-p") | ||
| ssh_pass.append(self.remote_password) | ||
| # need to push the model to the device and use trtexec_safe to run | ||
| scp_cmd = [ | ||
| "scp", | ||
| f"-P{self.remote_port}", | ||
| self.engine_path, | ||
| f"{self.remote_user}@{self.remote_ip}:{shlex.quote(self.remote_engine_path)}", | ||
| ] | ||
| scp_cmd = ssh_pass + scp_cmd | ||
| result = subprocess.run(scp_cmd, capture_output=True, text=True) # nosec B603 | ||
| if result.returncode != 0: | ||
| self.logger.error(f"Failed to push engine to remote device: {result.stderr}") | ||
| return float("inf") | ||
| ld_path = f"LD_LIBRARY_PATH={shlex.quote(self.remote_lib_path)}:$LD_LIBRARY_PATH" | ||
| trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec_safe')}" | ||
| trtexec_safe_cmd = [ | ||
| "ssh", | ||
| "-p", | ||
| f"{self.remote_port}", | ||
| f"{self.remote_user}@{self.remote_ip}", | ||
| f"{ld_path} {shlex.quote(trt_path)} --loadEngine={shlex.quote(self.remote_engine_path)}", | ||
| ] | ||
| trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd | ||
| result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603 | ||
|
Comment on lines
+363
to
+392
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CRITICAL: Multiple security violations and incorrect SSH implementation. Several issues violate the security coding guidelines:
Proposed fix direction if self.has_remote_config and self.is_safe:
# need to push the model to the device and use trtexec_safe to run
+ if self.remote_port is None:
+ self.logger.error("Remote port is required for SCP/SSH")
+ return float("inf")
+ if self.remote_password:
+ self.logger.warning(
+ "Password-based SSH auth is not recommended. Consider using SSH keys."
+ )
scp_cmd = [
- "scp",
- f"-P{self.remote_port}",
+ "sshpass", "-p", self.remote_password,
+ "scp", "-P", str(self.remote_port),
self.engine_path,
f"{self.remote_user}@{self.remote_ip}:{self.remote_engine_path}",
]
- result = subprocess.run(scp_cmd) # nosec B603
+ result = subprocess.run(scp_cmd, capture_output=True, text=True)
# ... similar changes for SSH commandConsider using SSH key-based authentication and passing command arguments separately to avoid shell interpretation on the remote side. As per coding guidelines: "Any use of '# nosec' comments to bypass Bandit security checks is not allowed." 🤖 Prompt for AI Agents
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Partially addressed.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
✅ Fixed:
❌ Still remaining: 1. 2. Command injection risk (line ~376 in current code): f"{ld_path} {trt_path} --loadEngine={self.remote_engine_path}",If any of import shlex
remote_cmd = (
f"LD_LIBRARY_PATH={shlex.quote(str(self.remote_options['remote_lib_path']))}:$LD_LIBRARY_PATH"
f" {shlex.quote(trt_path)}"
f" --loadEngine={shlex.quote(str(self.remote_engine_path))}"
)3. Missing if self.remote_port is None:
self.remote_port = 22 # default SSH portor raise a 🧠 Learnings used
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Partially addressed.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
✅ Now fixed:
❌ Still remaining:
🧠 Learnings used |
||
| latency_pattern = safe_pattern | ||
| if result.returncode != 0: | ||
| # fallback and try trtexec with "--safe" | ||
| trt_path = f"{os.path.join(self.remote_bin_path, 'trtexec')}" | ||
| trtexec_safe_cmd = [ | ||
| "ssh", | ||
| "-p", | ||
| f"{self.remote_port}", | ||
| f"{self.remote_user}@{self.remote_ip}", | ||
| f"{ld_path} {shlex.quote(trt_path)} --safe --loadEngine={shlex.quote(self.remote_engine_path)}", | ||
| ] | ||
| trtexec_safe_cmd = ssh_pass + trtexec_safe_cmd | ||
|
|
||
| if not (match := re.search(self.latency_pattern, result.stdout, re.IGNORECASE)): | ||
| self.logger.warning("Could not parse median latency from trtexec output") | ||
| self.logger.debug(f"trtexec stdout:\n{result.stdout}") | ||
| result = subprocess.run(trtexec_safe_cmd, capture_output=True, text=True) # nosec B603 | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| latency_pattern = std_pattern | ||
| if result.returncode != 0: | ||
| self.logger.error( | ||
| f"Failed to run trtexec_safe or trtexec with '--safe'\n {result.stdout}" | ||
| ) | ||
| return float("inf") | ||
| if not (match := re.search(latency_pattern, result.stdout, re.IGNORECASE)): | ||
| self.logger.warning(f"trtexec stdout:\n{result.stdout}") | ||
| self.logger.error("Could not parse median latency from trtexec output") | ||
| return float("inf") | ||
| latency = float(match.group(1)) | ||
| self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.