From ada07c9cb2e389d099209d952b59adf9ee621535 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 30 Dec 2025 17:57:32 +0530 Subject: [PATCH] Add Windows/clang-cl support for AMD HIP backend - Use LoadLibrary/GetProcAddress on Windows instead of dlopen/dlsym - Use rocm_sdk.find_libraries() to locate amdhip64 - Add platform-specific macros for dynamic library loading - Escape Windows paths for C string embedding - Treat clang-cl as MSVC-compatible compiler in build.py - Fix NamedTemporaryFile handling on Windows in compiler.py --- python/triton/runtime/build.py | 16 ++- third_party/amd/backend/compiler.py | 39 ++++- third_party/amd/backend/driver.c | 33 ++++- third_party/amd/backend/driver.py | 211 ++++++++++++++++++++-------- 4 files changed, 232 insertions(+), 67 deletions(-) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 2a851c773469..66a27e3ec780 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -51,6 +51,11 @@ def is_msvc(cc): return cc == "cl" or cc == "cl.exe" +def is_clang_cl(cc): + cc = os.path.basename(cc).lower() + return cc == "clang-cl" or cc == "clang-cl.exe" + + def is_clang(cc): cc = os.path.basename(cc).lower() return cc == "clang" or cc == "clang.exe" @@ -58,9 +63,14 @@ def is_clang(cc): def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str], ccflags: list[str]) -> list[str]: - if is_msvc(cc): + if is_msvc(cc) or is_clang_cl(cc): out_base = os.path.splitext(out)[0] - cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/std:c11", "/wd4819"] + cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"] + # clang-cl doesn't support /std:c11, use -std=c11 instead + if is_clang_cl(cc): + cc_cmd += ["-std=c11"] + else: + cc_cmd += ["/std:c11"] cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None] cc_cmd += [f"/Fo{out_base + '.obj'}"] cc_cmd += ["/link"] @@ -110,7 +120,7 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di if sysconfig.get_config_var("Py_GIL_DISABLED"): version += "t" libraries = libraries + [f"python{version}"] - if is_msvc(cc): + if is_msvc(cc) or is_clang_cl(cc): _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk() include_dirs = include_dirs + msvc_winsdk_inc_dirs library_dirs = library_dirs + msvc_winsdk_lib_dirs diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 550bd561bcb2..90521bc1021a 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -5,12 +5,17 @@ from typing import Any, Dict, Tuple from types import ModuleType import hashlib +import os +import platform import tempfile import re import functools import warnings from pathlib import Path +def _is_windows(): + return platform.system() == 'Windows' + def get_min_dot_size(target: GPUTarget): # We fallback to use FMA and cast arguments if certain configurations is @@ -437,13 +442,35 @@ def make_hsaco(src, metadata, options): if knobs.compilation.enable_asan: target_features = '+xnack' hsaco = amd.assemble_amdgcn(src, options.arch, target_features) - with tempfile.NamedTemporaryFile() as tmp_out: - with tempfile.NamedTemporaryFile() as tmp_in: - with open(tmp_in.name, "wb") as fd_in: - fd_in.write(hsaco) + # On Windows, NamedTemporaryFile cannot be reopened while open, so we + # use delete=False and manually clean up. + if _is_windows(): + tmp_in = tempfile.NamedTemporaryFile(delete=False, suffix='.o') + tmp_out = tempfile.NamedTemporaryFile(delete=False, suffix='.hsaco') + try: + tmp_in.write(hsaco) + tmp_in.close() + tmp_out.close() amd.link_hsaco(tmp_in.name, tmp_out.name) - with open(tmp_out.name, "rb") as fd_out: - ret = fd_out.read() + with open(tmp_out.name, "rb") as fd_out: + ret = fd_out.read() + finally: + try: + os.unlink(tmp_in.name) + except OSError: + pass + try: + os.unlink(tmp_out.name) + except OSError: + pass + else: + with tempfile.NamedTemporaryFile() as tmp_out: + with tempfile.NamedTemporaryFile() as tmp_in: + with open(tmp_in.name, "wb") as fd_in: + fd_in.write(hsaco) + amd.link_hsaco(tmp_in.name, tmp_out.name) + with open(tmp_out.name, "rb") as fd_out: + ret = fd_out.read() return ret def add_stages(self, stages, options, language): diff --git a/third_party/amd/backend/driver.c b/third_party/amd/backend/driver.c index c84727d6da3c..b83c5be28418 100644 --- a/third_party/amd/backend/driver.c +++ b/third_party/amd/backend/driver.c @@ -3,11 +3,42 @@ #include #define PY_SSIZE_T_CLEAN #include -#include #include #include #include +#ifdef _WIN32 +#include +// Windows compatibility layer for dlopen/dlsym/dlclose/dlerror +#define RTLD_NOW 0 +#define RTLD_LAZY 0 +#define RTLD_LOCAL 0 +static char dlerror_buf[512]; +static inline void *dlopen(const char *filename, int flags) { + (void)flags; + HMODULE h = LoadLibraryA(filename); + if (!h) { + snprintf(dlerror_buf, sizeof(dlerror_buf), "LoadLibrary failed with error %lu", GetLastError()); + } + return (void *)h; +} +static inline void *dlsym(void *handle, const char *symbol) { + void *p = (void *)GetProcAddress((HMODULE)handle, symbol); + if (!p) { + snprintf(dlerror_buf, sizeof(dlerror_buf), "GetProcAddress failed for %s with error %lu", symbol, GetLastError()); + } + return p; +} +static inline int dlclose(void *handle) { + return FreeLibrary((HMODULE)handle) ? 0 : -1; +} +static inline const char *dlerror(void) { + return dlerror_buf[0] ? dlerror_buf : NULL; +} +#else +#include +#endif + // The list of paths to search for the HIP runtime library. The caller Python // code should substitute the search path placeholder. static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"}; diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 3418ce26f10a..095b4042a594 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -13,9 +13,48 @@ dirname = os.path.dirname(os.path.realpath(__file__)) include_dirs = [os.path.join(dirname, "include")] +import platform + +def _is_windows(): + return platform.system() == 'Windows' + +def _get_rocm_sdk_root(): + """Get ROCm SDK root path using rocm-sdk command or environment variables.""" + # Try rocm-sdk path --root first (for Windows ROCm SDK) + try: + result = subprocess.check_output(["rocm-sdk", "path", "--root"], stderr=subprocess.DEVNULL) + root = result.decode().strip() + if root and os.path.isdir(root): + return root + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + # Fall back to environment variables + for env_var in ["ROCM_HOME", "HIP_PATH", "ROCM_PATH"]: + path = os.environ.get(env_var, "") + if path and os.path.isdir(path): + return path + return None + + +def _get_hip_library_from_rocm_sdk(): + """Get the amdhip64 library path using rocm_sdk.find_libraries.""" + try: + import rocm_sdk + paths = rocm_sdk.find_libraries("amdhip64") + if paths: + return str(paths[0]) + except (ImportError, ModuleNotFoundError, FileNotFoundError): + pass + return None + +# Add HIP runtime headers from ROCm SDK if available +_rocm_root = _get_rocm_sdk_root() +if _rocm_root and os.path.isdir(os.path.join(_rocm_root, "include")): + include_dirs.append(os.path.join(_rocm_root, "include")) + def _find_already_mmapped_dylib_on_linux(lib_name): - import platform if platform.system() != 'Linux': return None @@ -64,7 +103,7 @@ def callback(info, size, data): @functools.lru_cache() def _get_path_to_hip_runtime_dylib(): - lib_name = "libamdhip64.so" + lib_name = "amdhip64.dll" if _is_windows() else "libamdhip64.so" # If we are told explicitly what HIP runtime dynamic library to use, obey that. if env_libhip_path := knobs.amd.libhip_path: @@ -72,12 +111,18 @@ def _get_path_to_hip_runtime_dylib(): return env_libhip_path raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}") - # If the shared object is already mmapped to address space, use it. - mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name) - if mmapped_path: - if os.path.exists(mmapped_path): - return mmapped_path - raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}") + # Try rocm_sdk.find_libraries first - this is the preferred method + rocm_sdk_path = _get_hip_library_from_rocm_sdk() + if rocm_sdk_path: + return rocm_sdk_path + + # If the shared object is already mmapped to address space, use it (Linux only). + if not _is_windows(): + mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name) + if mmapped_path: + if os.path.exists(mmapped_path): + return mmapped_path + raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}") paths = [] @@ -101,10 +146,15 @@ def _get_path_to_hip_runtime_dylib(): return path paths.append(path) - # Then try to see if developer provides a HIP runtime dynamic library using LD_LIBARAY_PATH. - env_ld_library_path = os.getenv("LD_LIBRARY_PATH") - if env_ld_library_path: - for d in env_ld_library_path.split(":"): + # Then try to see if developer provides a HIP runtime dynamic library using LD_LIBRARY_PATH (Linux) or PATH (Windows). + if _is_windows(): + env_path = os.getenv("PATH", "") + path_sep = ";" + else: + env_path = os.getenv("LD_LIBRARY_PATH", "") + path_sep = ":" + if env_path: + for d in env_path.split(path_sep): f = os.path.join(d, lib_name) if os.path.exists(f): return f @@ -113,47 +163,57 @@ def _get_path_to_hip_runtime_dylib(): # HIP_PATH should point to HIP SDK root if set env_hip_path = os.getenv("HIP_PATH") if env_hip_path: - hip_lib_path = os.path.join(env_hip_path, "lib", lib_name) + # On Windows, DLLs are in bin; on Linux, .so files are in lib + lib_subdir = "bin" if _is_windows() else "lib" + hip_lib_path = os.path.join(env_hip_path, lib_subdir, lib_name) if os.path.exists(hip_lib_path): return hip_lib_path paths.append(hip_lib_path) - # if available, `hipconfig --path` prints the HIP SDK root + # Try rocm-sdk path --root (Windows ROCm SDK) or hipconfig --path (Linux) + lib_subdir = "bin" if _is_windows() else "lib" try: - hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip() - if hip_root: - hip_lib_path = os.path.join(hip_root, "lib", lib_name) - if os.path.exists(hip_lib_path): - return hip_lib_path - paths.append(hip_lib_path) + if _is_windows(): + rocm_root = subprocess.check_output(["rocm-sdk", "path", "--root"], stderr=subprocess.DEVNULL).decode().strip() + else: + rocm_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip() + if rocm_root: + rocm_lib_path = os.path.join(rocm_root, lib_subdir, lib_name) + if os.path.exists(rocm_lib_path): + return rocm_lib_path + paths.append(rocm_lib_path) except (subprocess.CalledProcessError, FileNotFoundError): - # hipconfig may not be available + # rocm-sdk or hipconfig may not be available pass # ROCm lib dir based on env var - env_rocm_path = os.getenv("ROCM_PATH") + env_rocm_path = os.getenv("ROCM_PATH") or os.getenv("ROCM_HOME") if env_rocm_path: - rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name) + rocm_lib_path = os.path.join(env_rocm_path, lib_subdir, lib_name) if os.path.exists(rocm_lib_path): return rocm_lib_path paths.append(rocm_lib_path) - # Afterwards try to search the loader dynamic library resolution paths. - libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore") - # each line looks like the following: - # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6 - # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so - locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)] - for loc in locs: - if os.path.exists(loc): - return loc - paths.append(loc) - - # As a last resort, guess if we have it in some common installation path. - common_install_path = os.path.join('/opt/rocm/lib/', lib_name) - if os.path.exists(common_install_path): - return common_install_path - paths.append(common_install_path) + # Afterwards try to search the loader dynamic library resolution paths (Linux only). + if not _is_windows(): + try: + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore") + # each line looks like the following: + # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6 + # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so + locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)] + for loc in locs: + if os.path.exists(loc): + return loc + paths.append(loc) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + # As a last resort on Linux, guess if we have it in some common installation path. + common_install_path = os.path.join('/opt/rocm/lib/', lib_name) + if os.path.exists(common_install_path): + return common_install_path + paths.append(common_install_path) raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}") @@ -167,11 +227,13 @@ def __new__(cls): def __init__(self): libhip_path = _get_path_to_hip_runtime_dylib() + # Escape backslashes for C string embedding + libhip_path_escaped = libhip_path.replace("\\", "\\\\") src = Path(os.path.join(dirname, "driver.c")).read_text() # Just do a simple search and replace here instead of templates or format strings. # This way we don't need to escape-quote C code curly brackets and we can replace # exactly once. - src = src.replace('/*py_libhip_search_path*/', libhip_path, 1) + src = src.replace('/*py_libhip_search_path*/', libhip_path_escaped, 1) mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs) self.load_binary = mod.load_binary self.get_device_properties = mod.get_device_properties @@ -318,21 +380,59 @@ def format_of(ty): ] libhip_path = _get_path_to_hip_runtime_dylib() + # Escape backslashes for C string embedding + libhip_path = libhip_path.replace("\\", "\\\\") # generate glue code params = list(range(len(signature))) params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] params.append("&global_scratch") params.append("&profile_scratch") - src = f""" + + # Platform-specific includes and dlopen/dlsym macros + if _is_windows(): + platform_includes = """ #define __HIP_PLATFORM_AMD__ #include #include #include -#include +#include #include -#include +// Windows compatibility layer for dlopen/dlsym/dlclose +static char _dlerror_buf[512]; +static inline void *dlopen(const char *filename, int flags) { + (void)flags; + HMODULE h = LoadLibraryA(filename); + if (!h) { + snprintf(_dlerror_buf, sizeof(_dlerror_buf), "LoadLibrary failed with error %lu", GetLastError()); + } + return (void *)h; +} +static inline void *dlsym(void *handle, const char *symbol) { + void *p = (void *)GetProcAddress((HMODULE)handle, symbol); + if (!p) { + snprintf(_dlerror_buf, sizeof(_dlerror_buf), "GetProcAddress failed for %s with error %lu", symbol, GetLastError()); + } + return p; +} +static inline int dlclose(void *handle) { return FreeLibrary((HMODULE)handle) ? 0 : -1; } +static inline const char *dlerror(void) { return _dlerror_buf[0] ? _dlerror_buf : NULL; } +#define RTLD_LAZY 0 +#define RTLD_LOCAL 0 +#define RTLD_NOLOAD 0 +""" + else: + platform_includes = """ +#define __HIP_PLATFORM_AMD__ +#include +#include +#include +#include +#include +""" + + src = f"""{platform_includes} // The list of paths to search for the HIP runtime library. The caller Python // code should substitute the search path placeholder. static const char *hipLibSearchPaths[] = {{"{libhip_path}"}}; @@ -370,22 +470,19 @@ def format_of(ty): static struct HIPSymbolTable hipSymbolTable; bool initSymbolTable() {{ - // Use the HIP runtime library loaded into the existing process if it exits. - void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD); - - // Otherwise, go through the list of search paths to dlopen the first HIP - // driver library. - if (!lib) {{ - int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]); - for (int i = 0; i < n; ++i) {{ - void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL); - if (handle) {{ - lib = handle; - }} + void *lib = NULL; + + // Go through the list of search paths to open the first HIP driver library. + int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]); + for (int i = 0; i < n; ++i) {{ + void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL); + if (handle) {{ + lib = handle; + break; }} }} if (!lib) {{ - PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so"); + PyErr_SetString(PyExc_RuntimeError, "cannot open HIP runtime library"); return false; }} @@ -399,7 +496,7 @@ def format_of(ty): error = dlerror(); if (error) {{ PyErr_SetString(PyExc_RuntimeError, - "cannot query 'hipGetProcAddress' from libamdhip64.so"); + "cannot query 'hipGetProcAddress' from HIP runtime library"); dlclose(lib); return false; }}