diff --git a/Compiling.md b/Compiling.md index 74932a1f4..df1a9a4ae 100644 --- a/Compiling.md +++ b/Compiling.md @@ -151,3 +151,51 @@ As also mentioned in the instructions below but repeated here for visibility, if * Pre-trained neural nets are available at [the main training website](https://katagotraining.org/). * You will probably want to edit `configs/gtp_example.cfg` (see "Tuning for Performance" above). * If using OpenCL, you will want to verify that KataGo is picking up the correct device when you run it (e.g. some systems may have both an Intel CPU OpenCL and GPU OpenCL, if KataGo appears to pick the wrong one, you can correct this by specifying `openclGpuToUse` in `configs/gtp_example.cfg`). + +## ONNX Runtime Backend +The ONNX backend uses [ONNX Runtime](https://onnxruntime.ai/) for neural net inference. It supports both standard `.bin.gz` model files (building the ONNX graph internally from the model weights) and raw `.onnx` model files. On macOS, it can use the CoreML execution provider for hardware-accelerated inference on Apple Silicon. On Windows/Linux, it can use the CUDA or TensorRT execution providers for NVIDIA GPU acceleration. + + * Requirements + * ONNX Runtime built from source. See the [ONNX Runtime build instructions](https://onnxruntime.ai/docs/build/). + * On macOS with CoreML support, build ONNX Runtime with `--use_coreml`: + ``` + python3 tools/ci_build/build.py --build_dir build/MacOS \ + --config RelWithDebInfo --build_shared_lib --parallel \ + --compile_no_warning_as_error --skip_submodule_sync \ + --cmake_generator Ninja --use_coreml + ``` + * On Windows/Linux with CUDA support, build ONNX Runtime with `--use_cuda`: + ``` + python3 tools/ci_build/build.py --build_dir build/Linux \ + --config RelWithDebInfo --build_shared_lib --parallel \ + --compile_no_warning_as_error --skip_submodule_sync \ + --use_cuda --cudnn_home /usr/local/cuda --cuda_home /usr/local/cuda + ``` + * For TensorRT support, build with `--use_tensorrt` (also enables CUDA): + ``` + python3 tools/ci_build/build.py --build_dir build/Linux \ + --config RelWithDebInfo --build_shared_lib --parallel \ + --compile_no_warning_as_error --skip_submodule_sync \ + --use_tensorrt --tensorrt_home /usr/local/TensorRT \ + --use_cuda --cudnn_home /usr/local/cuda --cuda_home /usr/local/cuda + ``` + * zlib, libzip (same as other backends). + * Compile using CMake in the cpp directory: + * `cd KataGo/cpp` + * ``` + cmake . -DUSE_BACKEND=ONNX \ + -DONNXRUNTIME_ROOT=/path/to/onnxruntime \ + -DONNXRUNTIME_BUILD_DIR=/path/to/onnxruntime/build/MacOS/RelWithDebInfo + ``` + * `ONNXRUNTIME_ROOT` - path to the ONNX Runtime source/install root directory. + * `ONNXRUNTIME_BUILD_DIR` - path to the ONNX Runtime build output directory (e.g. `build/MacOS/RelWithDebInfo`). + * CoreML support is automatically enabled when building on Apple platforms. + * `make` + * Done! You should now have a compiled `katago` executable in your working directory. + * Pre-trained neural nets are available at [the main training website](https://katagotraining.org/). + * You will probably want to edit `configs/gtp_example.cfg` (see "Tuning for Performance" above). + * Using raw `.onnx` model files: + * You can pass an `.onnx` file directly as the `-model` argument instead of a `.bin.gz` file. + * Input/output node names and model version are auto-detected. If auto-detection fails, override them via config keys documented in the ONNX section of `configs/gtp_example.cfg`. 
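+   * For example, assuming a raw model file named `mymodel.onnx` (the file name is illustrative): `./katago gtp -config configs/gtp_example.cfg -model mymodel.onnx`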
+ * Selecting the execution provider:
+   * Set `onnxProvider = cpu` (default), `onnxProvider = coreml` (macOS only), `onnxProvider = cuda`, or `onnxProvider = tensorrt` in your config file.
diff --git a/LICENSE b/LICENSE
index 9528bccbe..3e9a0f11f 100644
--- a/LICENSE
+++ b/LICENSE
@@ -5,6 +5,12 @@ and/or files, see the individual readmes and/or license files for each one withi
subdirectories within cpp/external. Additionally, cpp/core/sha2.cpp derives from another piece of
external code and embeds its own license within that file.

+When built with the ONNX backend, this software links dynamically against ONNX Runtime (MIT License,
+https://github.com/microsoft/onnxruntime) and its transitive dependencies including ONNX (MIT License,
+https://github.com/onnx/onnx) and Protocol Buffers (BSD 3-Clause License,
+https://github.com/protocolbuffers/protobuf). These libraries are not distributed with this
+repository; see their respective repositories for license details.
+
Aside from the above, the license for all OTHER content in this repo is as follows:

----------------------------------------

diff --git a/README.md b/README.md
index ce7e87b97..18ab693e3 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
* [GUIs](#guis)
* [Windows and Linux](#windows-and-linux)
* [MacOS](#macos)
- * [OpenCL vs CUDA vs TensorRT vs Eigen](#opencl-vs-cuda-vs-tensorrt-vs-eigen)
+ * [OpenCL vs CUDA vs TensorRT vs Eigen vs ONNX](#opencl-vs-cuda-vs-tensorrt-vs-eigen-vs-onnx)
* [How To Use](#how-to-use)
* [Tuning for Performance](#tuning-for-performance)
* [Common Questions and Issues](#common-questions-and-issues)
@@ -84,8 +84,8 @@ The community also provides KataGo packages for [Homebrew](https://brew.sh) on M
Use `brew install katago`. The latest config files and networks are installed in KataGo's `share` directory. Find them via `brew list --verbose katago`. A basic way to run katago will be `katago gtp -config $(brew list --verbose katago | grep 'gtp.*\.cfg') -model $(brew list --verbose katago | grep .gz | head -1)`. You should choose the Network according to the release notes here and customize the provided example config as with every other way of installing KataGo.

-### OpenCL vs CUDA vs TensorRT vs Eigen
-KataGo has four backends, OpenCL (GPU), CUDA (GPU), TensorRT (GPU), and Eigen (CPU).
+### OpenCL vs CUDA vs TensorRT vs Eigen vs ONNX
+KataGo has five backends, OpenCL (GPU), CUDA (GPU), TensorRT (GPU), Eigen (CPU), and ONNX (CPU/GPU).

The quick summary is:
  * **To easily get something working, try OpenCL if you have any good or decent GPU.**
@@ -93,12 +93,14 @@ The quick summary is:
  * Use Eigen with AVX2 if you don't have a GPU or if your GPU is too old/weak to work with OpenCL, and you just want a plain CPU KataGo.
  * Use Eigen without AVX2 if your CPU is old or on a low-end device that doesn't support AVX2.
  * The CUDA backend can work for NVIDIA GPUs with CUDA+CUDNN installed but is likely worse than TensorRT.
+  * Use ONNX if you want to load raw `.onnx` model files, use CoreML on Apple Silicon, or use CUDA/TensorRT via ONNX Runtime on Windows/Linux.

More in detail:
  * OpenCL is a general GPU backend that should be able to run with any GPUs or accelerators that support [OpenCL](https://en.wikipedia.org/wiki/OpenCL), including NVIDIA GPUs, AMD GPUs, as well as CPU-based OpenCL implementations or things like Intel Integrated Graphics. 
This is the most general GPU version of KataGo and doesn't require a complicated install like CUDA does, so it is most likely to work out of the box as long as you have a fairly modern GPU. **However, it also needs to take some time when run for the very first time to tune itself.** For many systems, this will take 5-30 seconds, but on a few older/slower systems, it may take many minutes or longer. Also, the quality of OpenCL implementations is sometimes inconsistent, particularly for Intel Integrated Graphics and for AMD GPUs that are older than several years, so it might not work for very old machines, as well as for specific buggy newer AMD GPUs; see also [Issues with specific GPUs or GPU drivers](#issues-with-specific-gpus-or-gpu-drivers).
  * CUDA is a GPU backend specific to NVIDIA GPUs (it will not work with AMD or Intel or any other GPUs) and requires installing [CUDA](https://developer.nvidia.com/cuda-zone) and [CUDNN](https://developer.nvidia.com/cudnn) and a modern NVIDIA GPU. On most GPUs, the OpenCL implementation will actually beat NVIDIA's own CUDA/CUDNN at performance. The exception is for top-end NVIDIA GPUs that support FP16 and tensor cores, in which case sometimes one is better and sometimes the other is better.
  * TensorRT is similar to CUDA, but uses NVIDIA's TensorRT framework to run the neural network with more optimized kernels. For modern NVIDIA GPUs, it should work whenever CUDA does and will usually be faster than CUDA or any other backend.
  * Eigen is a *CPU* backend that should work widely *without* needing a GPU or fancy drivers. Use this if you don't have a good GPU or really any GPU at all. It will be quite significantly slower than OpenCL or CUDA, but on a good CPU can still often get 10 to 20 playouts per second if using the smaller (15 or 20) block neural nets. Eigen can also be compiled with AVX2 and FMA support, which can provide a big performance boost for Intel and AMD CPUs from the last few years. However, it will not run at all on older CPUs (and possibly even some recent but low-power modern CPUs) that don't support these fancy vector instructions.
+  * ONNX is a backend that uses [ONNX Runtime](https://onnxruntime.ai/) for inference. It can load both standard `.bin.gz` model files and raw `.onnx` model files directly, as shown in the example below. It supports CPU inference out of the box, CoreML on macOS for Apple Silicon hardware acceleration, and CUDA/TensorRT execution providers for NVIDIA GPUs on Windows/Linux. Requires building ONNX Runtime from source as a prerequisite. See [Compiling KataGo](Compiling.md) for details.

For **any** implementation, it's recommended that you also tune the number of threads used if you care about optimal performance, as it can make a factor of 2-3 difference in the speed. See "Tuning for Performance" below. However, if you mostly just want to get it working, then the default untuned settings should still be reasonable.
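+For example, the ONNX backend accepts either model format through the same command line as the other backends (model file names below are illustrative):
+
+```
+./katago gtp -config configs/gtp_example.cfg -model mymodel.bin.gz
+./katago gtp -config configs/gtp_example.cfg -model mymodel.onnx
+```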
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 8db79ca73..04b992bf4 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -32,7 +32,7 @@ endif()
set(BUILD_DISTRIBUTED 0 CACHE BOOL "Build with http support for contributing to distributed training")
set(USE_BACKEND CACHE STRING "Neural net backend")
string(TOUPPER "${USE_BACKEND}" USE_BACKEND)
-set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN)
+set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN ONNX)
set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc")
set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe")
@@ -145,8 +145,14 @@ elseif(USE_BACKEND STREQUAL "EIGEN")
  set(NEURALNET_BACKEND_SOURCES
    neuralnet/eigenbackend.cpp
    )
+elseif(USE_BACKEND STREQUAL "ONNX")
+  message(STATUS "-DUSE_BACKEND=ONNX, using ONNX Runtime backend (loads .bin.gz natively).")
+  set(NEURALNET_BACKEND_SOURCES
+    neuralnet/onnxbackend.cpp
+    neuralnet/onnxmodelbuilder.cpp
+    )
elseif(USE_BACKEND STREQUAL "")
-  message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN to compile with the respective backend.${ColorReset}")
+  message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN or -DUSE_BACKEND=ONNX to compile with the respective backend.${ColorReset}")
  set(NEURALNET_BACKEND_SOURCES neuralnet/dummybackend.cpp)
else()
  message(FATAL_ERROR "Unrecognized backend: " ${USE_BACKEND})
@@ -449,6 +455,41 @@ elseif(USE_BACKEND STREQUAL "EIGEN")
      endif()
    endif()
  endif()
+elseif(USE_BACKEND STREQUAL "ONNX")
+  target_compile_definitions(katago PRIVATE USE_ONNX_BACKEND)
+  # Honor -DONNXRUNTIME_ROOT and -DONNXRUNTIME_BUILD_DIR (documented in Compiling.md),
+  # falling back to common Homebrew/system install locations.
+  find_path(ONNXRUNTIME_INCLUDE_DIR onnxruntime_cxx_api.h
+    HINTS "${ONNXRUNTIME_ROOT}/include/onnxruntime/core/session"
+    /opt/homebrew/opt/onnxruntime/include/onnxruntime
+    /opt/homebrew/include/onnxruntime /usr/local/include/onnxruntime
+    )
+  if(NOT ONNXRUNTIME_INCLUDE_DIR)
+    message(FATAL_ERROR "Could not find onnxruntime headers. Set -DONNXRUNTIME_ROOT, or install via: brew install onnxruntime")
+  endif()
+  target_include_directories(katago SYSTEM PRIVATE "${ONNXRUNTIME_INCLUDE_DIR}")
+  find_library(ONNXRUNTIME_LIB onnxruntime
+    HINTS "${ONNXRUNTIME_BUILD_DIR}" /opt/homebrew/opt/onnxruntime/lib /opt/homebrew/lib /usr/local/lib
+    )
+  if(NOT ONNXRUNTIME_LIB)
+    message(FATAL_ERROR "Could not find libonnxruntime. Set -DONNXRUNTIME_BUILD_DIR, or install via: brew install onnxruntime")
+  endif()
+  find_path(ONNX_INCLUDE_DIR onnx/onnx-ml.pb.h
+    HINTS /opt/homebrew/opt/onnx/include /opt/homebrew/include /usr/local/include
+    )
+  if(NOT ONNX_INCLUDE_DIR)
+    message(FATAL_ERROR "Could not find onnx headers. Install via: brew install onnx")
+  endif()
+  target_include_directories(katago PRIVATE "${ONNX_INCLUDE_DIR}")
+  target_compile_definitions(katago PRIVATE ONNX_ML)
+  find_library(ONNX_PROTO_LIB onnx_proto
+    HINTS /opt/homebrew/opt/onnx/lib /opt/homebrew/lib /usr/local/lib
+    )
+  if(NOT ONNX_PROTO_LIB)
+    message(FATAL_ERROR "Could not find libonnx_proto. 
Install via: brew install onnx") + endif() + find_package(PkgConfig REQUIRED) + pkg_check_modules(PROTOBUF REQUIRED protobuf) + target_include_directories(katago PRIVATE ${PROTOBUF_INCLUDE_DIRS}) + target_link_directories(katago PRIVATE ${PROTOBUF_LIBRARY_DIRS}) + target_link_libraries(katago ${ONNXRUNTIME_LIB} ${ONNX_PROTO_LIB} ${PROTOBUF_LIBRARIES}) endif() if(USE_BIGGER_BOARDS_EXPENSIVE) diff --git a/cpp/README.md b/cpp/README.md index 1f5d8d21f..73b135420 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -9,13 +9,14 @@ Summary of source folders, in approximate dependency order, from lowest level to * `board.{cpp,h}` - Raw board implementation, without move history. Helper functions for Benson's algorithm and ladder search. * `boardhistory.{cpp,h}` - Datastructure that does include move history - handles superko, passing, game end, final scoring, komi, handicap detection, etc. * `graphhash.{cpp,h}` - History-sensitive hash used for [monte-carlo graph search](https://github.com/lightvector/KataGo/blob/master/docs/GraphSearch.md). -* `neuralnet` - Neural net GPU implementation and interface. Contains OpenCL, CUDA, Eigen, TensorRT backends along with common interfaces and model data structures. +* `neuralnet` - Neural net GPU implementation and interface. Contains OpenCL, CUDA, Eigen, TensorRT, and ONNX backends along with common interfaces and model data structures. * `desc.{cpp,h}` - Data structure holding neural net structure and weights. * `modelversion.{cpp,h}` - Enumerates the various versions of neural net features and models. * `nninputs.{cpp,h}` - Implements the input features for the neural net. * `sgfmetadata.{cpp,h}` - Implements the input features for the [HumanSL neural net](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md#human-sl-analysis-guide), for conditioning on various SGF metadata about human players from training data. * `nninterface.h` - Common interface that is implemented by every low-level neural net backend. - * `{cuda,opencl,eigen,trt,dummy}backend.cpp` - Various backends. + * `{cuda,opencl,eigen,trt,onnx,dummy}backend.cpp` - Various backends. + * `onnxmodelbuilder.{cpp,h}` - Builds ONNX graphs from KataGo model weights for the ONNX backend. * `nneval.{cpp,h}` - Top-level handle to the neural net used by the rest of the engine, implements thread-safe batching of queries. * `search` - The main search engine. * `timecontrols.cpp` - Basic handling of a few possible time controls. diff --git a/cpp/configs/gtp_example.cfg b/cpp/configs/gtp_example.cfg index cfa720bf3..8b1fa1a87 100644 --- a/cpp/configs/gtp_example.cfg +++ b/cpp/configs/gtp_example.cfg @@ -517,6 +517,35 @@ searchFactorWhenWinningThreshold = 0.95 # Default: numSearchThreads # numEigenThreadsPerModel = X +# ------------------------------ +# ONNX backend settings +# ------------------------------ +# These only apply when using the ONNX version of KataGo. + +# Execution provider to use: "cpu" (default), "coreml" (macOS only), +# "cuda" (NVIDIA GPU), or "tensorrt" (NVIDIA GPU, optimized). +# CoreML uses Apple's Neural Engine and GPU for hardware-accelerated inference. +# CUDA and TensorRT require ONNX Runtime built with --use_cuda or --use_tensorrt. +# onnxProvider = cpu + +# Override input/output node names for raw .onnx model files. +# When loading a raw .onnx file, KataGo auto-detects node names by searching +# for "spatial", "global", "meta", "policy", "value", "miscvalue", "ownership" +# in the model's node names. 
Use these settings to override if auto-detection
+# picks the wrong nodes.
+# onnxInputSpatial = input_spatial
+# onnxInputGlobal = input_global
+# onnxInputMeta = input_meta
+# onnxOutputPolicy = out_policy
+# onnxOutputValue = out_value
+# onnxOutputMiscvalue = out_miscvalue
+# onnxOutputOwnership = out_ownership
+
+# Override the auto-detected model version for raw .onnx model files.
+# Model version is normally auto-detected from channel counts. Set this
+# to a specific version number (>= 0) if auto-detection picks the wrong one.
+# onnxModelVersion = 15
+
# ===========================================================================
# Root move selection and biases
# ===========================================================================
diff --git a/cpp/dataio/loadmodel.cpp b/cpp/dataio/loadmodel.cpp
index 81483b170..ffe7cc86f 100644
--- a/cpp/dataio/loadmodel.cpp
+++ b/cpp/dataio/loadmodel.cpp
@@ -20,6 +20,7 @@ std::time_t to_time_t(TP tp)
static const vector<string> ACCEPTABLE_MODEL_SUFFIXES {
  ".bin.gz",
  ".bin",
+  ".onnx",
  "model.txt.gz",
  "model.txt"
};
@@ -27,23 +28,23 @@ static const vector<string> GENERIC_MODEL_NAMES {
  "model.bin.gz",
  "model.bin",
  "model.txt.gz",
-  "model.txt"
+  "model.txt",
  "Model.bin.gz",
  "Model.bin",
  "Model.txt.gz",
-  "Model.txt"
+  "Model.txt",
  "MODEL.bin.gz",
  "MODEL.bin",
  "MODEL.txt.gz",
-  "MODEL.txt"
+  "MODEL.txt",
  "model.ckpt",
-  "Model.ckpt"
+  "Model.ckpt",
  "MODEL.ckpt",
  "model.checkpoint",
-  "Model.checkpoint"
+  "Model.checkpoint",
  "MODEL.checkpoint",
  "model",
-  "Model"
+  "Model",
  "MODEL",
};
diff --git a/cpp/main.cpp b/cpp/main.cpp
index 0fcc36dea..631d256b3 100644
--- a/cpp/main.cpp
+++ b/cpp/main.cpp
@@ -248,6 +248,8 @@ string Version::getKataGoVersionFullInfo() {
  out << "Using OpenCL backend" << endl;
#elif defined(USE_EIGEN_BACKEND)
  out << "Using Eigen(CPU) backend" << endl;
+#elif defined(USE_ONNX_BACKEND)
+  out << "Using ONNX backend" << endl;
#else
  out << "Using dummy backend" << endl;
#endif
diff --git a/cpp/neuralnet/onnxbackend.cpp b/cpp/neuralnet/onnxbackend.cpp
new file mode 100644
index 000000000..fc9b31e07
--- /dev/null
+++ b/cpp/neuralnet/onnxbackend.cpp
@@ -0,0 +1,794 @@
+// ONNX Runtime backend for KataGo.
+// Loads standard .bin.gz model files (builds ONNX graph from ModelDesc) or
+// raw .onnx model files directly, and runs inference via ONNX Runtime with a
+// configurable execution provider (CPU, CoreML, CUDA, TensorRT) selected at
+// runtime via the onnxProvider config key.
+
+#include "../neuralnet/nninterface.h"
+#include "../neuralnet/nneval.h"
+#include "../neuralnet/nninputs.h"
+#include "../neuralnet/modelversion.h"
+#include "../neuralnet/onnxmodelbuilder.h"
+
+#include <onnxruntime_cxx_api.h>
+#ifdef __APPLE__
+#include <coreml_provider_factory.h>
+#endif
+
+#include <array>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <memory>
+
+using namespace std;
+
+//--------------------------------------------------------------
+
+// Auto-detect modelVersion from introspected channel counts.
+//
+// Detection is based on channel-count heuristics for raw .onnx files where the
+// model version is not encoded in the file. The mapping assumes V7 inputs
+// (22 spatial + 19 global channels) and distinguishes versions by the number of
+// score-value and policy output channels:
+// - 4 score-value channels → version 8
+// - 6 score-value channels, 1 policy channel → version 10
+// - 6 score-value channels, 2 policy channels → version 15
+//
+// If the heuristic picks the wrong version, set the `onnxModelVersion` config
+// key to the correct value (>= 0) to override auto-detection.
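+//
+// For example, a raw .onnx net with 22 spatial and 19 global input channels,
+// 2 policy channels, and 6 score-value channels is detected as version 15.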
+static int detectModelVersion(
+  int numInputChannels, int numInputGlobalChannels,
+  int numPolicyChannels, int numScoreValueChannels,
+  int configModelVersion
+) {
+  if(configModelVersion >= 0)
+    return configModelVersion;
+
+  // inputsVersion 7 → models 8-16: 22 spatial + 19 global
+  if(numInputChannels == NNInputs::NUM_FEATURES_SPATIAL_V7 &&
+    numInputGlobalChannels == NNInputs::NUM_FEATURES_GLOBAL_V7) {
+    if(numScoreValueChannels == 6 && numPolicyChannels == 2)
+      return 15;
+    if(numScoreValueChannels == 6 && numPolicyChannels == 1)
+      return 10;
+    if(numScoreValueChannels == 4)
+      return 8;
+    // Default for V7 inputs
+    return 15;
+  }
+  // Older input versions — fall back to a reasonable default
+  return NNModelVersion::defaultModelVersion;
+}
+
+struct LoadedModel {
+  ModelDesc modelDesc;
+  bool isRawOnnx;
+  string rawOnnxBytes;
+
+  // Loads a standard .bin.gz model desc, or reads and introspects a raw .onnx file
+  LoadedModel(const string& fileName, const string& expectedSha256, bool rawOnnx)
+    : isRawOnnx(rawOnnx)
+  {
+    if(!rawOnnx) {
+      ModelDesc::loadFromFileMaybeGZipped(fileName, modelDesc, expectedSha256);
+      return;
+    }
+
+    // Read raw .onnx file bytes
+    {
+      std::ifstream in(fileName, std::ios::binary | std::ios::ate);
+      if(!in.good())
+        throw StringError("ONNX backend: could not open raw ONNX file: " + fileName);
+      std::streamsize size = in.tellg();
+      if(size < 0)
+        throw StringError("ONNX backend: could not determine size of ONNX file: " + fileName);
+      in.seekg(0, std::ios::beg);
+      rawOnnxBytes.resize(size);
+      if(!in.read(rawOnnxBytes.data(), size))
+        throw StringError("ONNX backend: failed to read raw ONNX file: " + fileName);
+    }
+
+    // Create a temporary CPU session to introspect shapes
+    Ort::Env tmpEnv(ORT_LOGGING_LEVEL_WARNING, "KataGoOnnxIntrospect");
+    Ort::SessionOptions tmpOpts;
+    tmpOpts.SetIntraOpNumThreads(1);
+    Ort::Session tmpSession(tmpEnv, rawOnnxBytes.data(), rawOnnxBytes.size(), tmpOpts);
+
+    Ort::AllocatorWithDefaultOptions allocator;
+
+    // Introspect inputs by name first, falling back to shape-based heuristic
+    int numInputChannels = 0;
+    int numInputGlobalChannels = 0;
+    int numInputMetaChannels = 0;
+    size_t numInputs = tmpSession.GetInputCount();
+    for(size_t i = 0; i < numInputs; i++) {
+      Ort::AllocatedStringPtr namePtr = tmpSession.GetInputNameAllocated(i, allocator);
+      string name = namePtr.get();
+      auto typeInfo = tmpSession.GetInputTypeInfo(i);
+      auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
+      auto shape = tensorInfo.GetShape();
+      if(name.find("spatial") != string::npos) {
+        if(shape.size() >= 2)
+          numInputChannels = (int)shape[1];
+      } else if(name.find("global") != string::npos) {
+        if(shape.size() >= 2)
+          numInputGlobalChannels = (int)shape[1];
+      } else if(name.find("meta") != string::npos) {
+        if(shape.size() >= 2)
+          numInputMetaChannels = (int)shape[1];
+      } else if(shape.size() == 4) {
+        // Shape-based fallback: [N, C, H, W] — spatial input
+        numInputChannels = (int)shape[1];
+      } else if(shape.size() == 2) {
+        // Shape-based fallback: [N, C] — first 2D is global, second is meta
+        if(numInputGlobalChannels == 0)
+          numInputGlobalChannels = (int)shape[1];
+        else
+          numInputMetaChannels = (int)shape[1];
+      } else {
+        cerr << "ONNX backend warning: unrecognized input tensor '" << name
+          << "' with " << shape.size() << "D shape, ignoring" << "\n";
+      }
+    }
+
+    // Introspect outputs
+    int numPolicyChannels = 0;
+    int numValueChannels = 0;
+    int numScoreValueChannels = 0;
+    int numOwnershipChannels = 0;
+    size_t numOutputs = tmpSession.GetOutputCount();
+    for(size_t i = 0; i < numOutputs; i++) {
+ Ort::AllocatedStringPtr namePtr = tmpSession.GetOutputNameAllocated(i, allocator); + string name = namePtr.get(); + auto typeInfo = tmpSession.GetOutputTypeInfo(i); + auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); + auto shape = tensorInfo.GetShape(); + + if(name.find("policy") != string::npos) { + // Policy: [N, C, H*W+1] → dim 1 is policy channels + if(shape.size() >= 2) + numPolicyChannels = (int)shape[1]; + } else if(name.find("miscvalue") != string::npos) { + // MiscValue: [N, numScoreValueChannels] — check before "value" since "miscvalue" contains "value" + if(shape.size() >= 2) + numScoreValueChannels = (int)shape[1]; + } else if(name.find("value") != string::npos) { + // Value: [N, 3] + if(shape.size() >= 2) + numValueChannels = (int)shape[1]; + } else if(name.find("ownership") != string::npos) { + // Ownership: [N, 1, H, W] + if(shape.size() >= 2) + numOwnershipChannels = (int)shape[1]; + } + } + + // Populate ModelDesc metadata (weights are in the ONNX graph, not in modelDesc) + modelDesc.numInputChannels = numInputChannels; + modelDesc.numInputGlobalChannels = numInputGlobalChannels; + modelDesc.numInputMetaChannels = numInputMetaChannels; + modelDesc.numPolicyChannels = numPolicyChannels; + modelDesc.numValueChannels = numValueChannels; + modelDesc.numScoreValueChannels = numScoreValueChannels; + modelDesc.numOwnershipChannels = numOwnershipChannels; + + // Extract filename stem as model name + { + size_t lastSlash = fileName.find_last_of("/\\"); + string basename = (lastSlash != string::npos) ? fileName.substr(lastSlash + 1) : fileName; + size_t dotPos = basename.find('.'); + modelDesc.name = (dotPos != string::npos) ? basename.substr(0, dotPos) : basename; + } + + // Model version: auto-detect with possible config override (applied later) + modelDesc.modelVersion = detectModelVersion( + numInputChannels, numInputGlobalChannels, + numPolicyChannels, numScoreValueChannels, + -1 // No config override at load time; applied in createComputeHandle if needed + ); + + // postProcessParams gets default values from its constructor (already set) + } + + LoadedModel() = delete; + LoadedModel(const LoadedModel&) = delete; + LoadedModel& operator=(const LoadedModel&) = delete; +}; + +LoadedModel* NeuralNet::loadModelFile(const string& file, const string& expectedSha256) { + bool isRawOnnx = Global::isSuffix(file, ".onnx"); + return new LoadedModel(file, expectedSha256, isRawOnnx); +} + +void NeuralNet::freeLoadedModel(LoadedModel* loadedModel) { + delete loadedModel; +} + +const ModelDesc& NeuralNet::getModelDesc(const LoadedModel* loadedModel) { + return loadedModel->modelDesc; +} + +//-------------------------------------------------------------- + +struct ComputeContext { + Ort::Env env; + int nnXLen; + int nnYLen; + string providerName; + + // Configurable input/output node names + string inputSpatialName; + string inputGlobalName; + string inputMetaName; + string outputPolicyName; + string outputValueName; + string outputMiscvalueName; + string outputOwnershipName; + + // Config override for model version (-1 means auto-detect) + int configModelVersion; + + ComputeContext(int xLen, int yLen, const string& provider) + : env(ORT_LOGGING_LEVEL_WARNING, "KataGoOnnx"), + nnXLen(xLen), + nnYLen(yLen), + providerName(provider), + inputSpatialName("input_spatial"), + inputGlobalName("input_global"), + inputMetaName("input_meta"), + outputPolicyName("out_policy"), + outputValueName("out_value"), + outputMiscvalueName("out_miscvalue"), + outputOwnershipName("out_ownership"), + 
configModelVersion(-1)
+  {}
+};
+
+//--------------------------------------------------------------
+
+struct ComputeHandle {
+  ComputeContext* context;
+  std::unique_ptr<Ort::Session> session;
+  int modelVersion;
+  int numInputChannels;
+  int numInputGlobalChannels;
+  int numPolicyChannels;
+  int numValueChannels;
+  int numScoreValueChannels;
+  int numOwnershipChannels;
+  int numInputMetaChannels;
+  int policyResultLen; // H*W+1
+
+  // Input/output names (stored for session->Run)
+  vector<string> inputNames;
+  vector<string> outputNames;
+  vector<const char*> inputNamePtrs;
+  vector<const char*> outputNamePtrs;
+
+  ComputeHandle(ComputeContext* ctx, const LoadedModel& loadedModel, Logger* logger)
+    : context(ctx),
+      modelVersion(loadedModel.modelDesc.modelVersion),
+      numInputChannels(loadedModel.modelDesc.numInputChannels),
+      numInputGlobalChannels(loadedModel.modelDesc.numInputGlobalChannels),
+      numPolicyChannels(loadedModel.modelDesc.numPolicyChannels),
+      numValueChannels(loadedModel.modelDesc.numValueChannels),
+      numScoreValueChannels(loadedModel.modelDesc.numScoreValueChannels),
+      numOwnershipChannels(loadedModel.modelDesc.numOwnershipChannels),
+      numInputMetaChannels(loadedModel.modelDesc.numInputMetaChannels),
+      policyResultLen(ctx->nnXLen * ctx->nnYLen + 1)
+  {
+    // Apply config model version override if set
+    if(ctx->configModelVersion >= 0)
+      modelVersion = ctx->configModelVersion;
+
+    const char* onnxData;
+    size_t onnxSize;
+    string builtOnnxBytes;
+    if(loadedModel.isRawOnnx) {
+      if(logger != NULL)
+        logger->write("ONNX backend: using raw ONNX model (" +
+          Global::uint64ToString(loadedModel.rawOnnxBytes.size()) + " bytes)");
+      onnxData = loadedModel.rawOnnxBytes.data();
+      onnxSize = loadedModel.rawOnnxBytes.size();
+    } else {
+      if(logger != NULL)
+        logger->write("ONNX backend: building ONNX graph from model weights...");
+      builtOnnxBytes = OnnxModelBuilder::buildOnnxModel(loadedModel.modelDesc, ctx->nnXLen, ctx->nnYLen);
+      if(logger != NULL)
+        logger->write("ONNX backend: ONNX graph built (" + Global::uint64ToString(builtOnnxBytes.size()) + " bytes)");
+      onnxData = builtOnnxBytes.data();
+      onnxSize = builtOnnxBytes.size();
+    }
+
+    if(logger != NULL)
+      logger->write("ONNX backend: creating session...");
+
+    Ort::SessionOptions sessionOpts;
+    sessionOpts.SetIntraOpNumThreads(1);
+
+    // Select execution provider based on providerName
+    const string& provider = ctx->providerName;
+    if(provider == "coreml") {
+#ifdef __APPLE__
+      uint32_t coremlFlags = COREML_FLAG_CREATE_MLPROGRAM;
+      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOpts, coremlFlags));
+      if(logger != NULL)
+        logger->write("ONNX backend: CoreML execution provider enabled (MLProgram mode)");
+#else
+      throw StringError("ONNX backend: CoreML is only available on Apple platforms");
+#endif
+    } else if(provider == "cuda") {
+      OrtCUDAProviderOptions cudaOpts;
+      cudaOpts.device_id = 0;
+      sessionOpts.AppendExecutionProvider_CUDA(cudaOpts);
+      if(logger != NULL)
+        logger->write("ONNX backend: CUDA execution provider enabled");
+    } else if(provider == "tensorrt") {
+      OrtTensorRTProviderOptions trtOpts;
+      trtOpts.device_id = 0;
+      sessionOpts.AppendExecutionProvider_TensorRT(trtOpts);
+      if(logger != NULL)
+        logger->write("ONNX backend: TensorRT execution provider enabled");
+    } else if(provider == "cpu" || provider.empty()) {
+      if(logger != NULL)
+        logger->write("ONNX backend: using CPU execution provider");
+    } else {
+      throw StringError("ONNX backend: unknown onnxProvider '" + provider + "', expected 'cpu', 'coreml', 'cuda', or 'tensorrt'");
+    }
+
+    // Create 
session from in-memory bytes
+    session = std::make_unique<Ort::Session>(ctx->env, onnxData, onnxSize, sessionOpts);
+
+    // Query and store input names
+    Ort::AllocatorWithDefaultOptions allocator;
+    size_t numInputs = session->GetInputCount();
+    for(size_t i = 0; i < numInputs; i++) {
+      Ort::AllocatedStringPtr name = session->GetInputNameAllocated(i, allocator);
+      inputNames.push_back(name.get());
+    }
+    for(auto& n : inputNames)
+      inputNamePtrs.push_back(n.c_str());
+
+    // Query and store output names
+    size_t numOutputs = session->GetOutputCount();
+    for(size_t i = 0; i < numOutputs; i++) {
+      Ort::AllocatedStringPtr name = session->GetOutputNameAllocated(i, allocator);
+      outputNames.push_back(name.get());
+    }
+    for(auto& n : outputNames)
+      outputNamePtrs.push_back(n.c_str());
+
+    if(logger != NULL)
+      logger->write("ONNX backend: session created, inputs=" + Global::uint64ToString(numInputs) +
+        " outputs=" + Global::uint64ToString(numOutputs));
+  }
+
+  ComputeHandle() = delete;
+  ComputeHandle(const ComputeHandle&) = delete;
+  ComputeHandle& operator=(const ComputeHandle&) = delete;
+};
+
+//--------------------------------------------------------------
+
+struct InputBuffers {
+  int maxBatchSize;
+
+  size_t singleInputElts;
+  size_t singleInputGlobalElts;
+  size_t singleInputMetaElts;
+
+  vector<float> spatialInput;
+  vector<float> globalInput;
+  vector<float> metaInput;
+
+  InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int nnXLen, int nnYLen) {
+    const ModelDesc& m = loadedModel->modelDesc;
+    maxBatchSize = maxBatchSz;
+    singleInputElts = (size_t)m.numInputChannels * nnXLen * nnYLen;
+    singleInputGlobalElts = (size_t)m.numInputGlobalChannels;
+    singleInputMetaElts = (size_t)m.numInputMetaChannels;
+    spatialInput.resize(singleInputElts * maxBatchSize, 0.0f);
+    globalInput.resize(singleInputGlobalElts * maxBatchSize, 0.0f);
+    if(m.numInputMetaChannels > 0)
+      metaInput.resize(singleInputMetaElts * maxBatchSize, 0.0f);
+  }
+
+  ~InputBuffers() {}
+
+  InputBuffers() = delete;
+  InputBuffers(const InputBuffers&) = delete;
+  InputBuffers& operator=(const InputBuffers&) = delete;
+};
+
+InputBuffers* NeuralNet::createInputBuffers(const LoadedModel* loadedModel, int maxBatchSize, int nnXLen, int nnYLen) {
+  return new InputBuffers(loadedModel, maxBatchSize, nnXLen, nnYLen);
+}
+void NeuralNet::freeInputBuffers(InputBuffers* inputBuffers) {
+  delete inputBuffers;
+}
+
+//--------------------------------------------------------------
+
+void NeuralNet::globalInitialize() {
+}
+
+void NeuralNet::globalCleanup() {
+}
+
+//--------------------------------------------------------------
+
+ComputeContext* NeuralNet::createComputeContext(
+  const std::vector<int>& gpuIdxs,
+  Logger* logger,
+  int nnXLen,
+  int nnYLen,
+  const string& backendExtraParam,
+  const string& homeDataDirOverride,
+  bool openCLReTunePerBoardSize,
+  enabled_t useFP16Mode,
+  enabled_t useNHWCMode,
+  const LoadedModel* loadedModel
+) {
+  (void)gpuIdxs;
+  (void)homeDataDirOverride;
+  (void)openCLReTunePerBoardSize;
+  (void)useFP16Mode;
+  (void)useNHWCMode;
+  (void)loadedModel;
+
+  // Parse backendExtraParam as "key=value;key=value;..."
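+  // Example (values illustrative): "provider=coreml;modelVersion=15".
+  // A bare token with no '=' is also accepted as a legacy provider name, e.g. "coreml".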
+  string providerName = "cpu";
+  map<string, string> params;
+  if(!backendExtraParam.empty()) {
+    vector<string> parts = Global::split(backendExtraParam, ';');
+    for(const string& part : parts) {
+      size_t eq = part.find('=');
+      if(eq != string::npos) {
+        string key = Global::trim(part.substr(0, eq));
+        string val = Global::trim(part.substr(eq + 1));
+        params[key] = val;
+      } else {
+        // Legacy: bare string is provider name
+        string trimmed = Global::trim(part);
+        if(!trimmed.empty())
+          providerName = trimmed;
+      }
+    }
+    if(params.count("provider"))
+      providerName = params["provider"];
+  }
+
+  if(logger != NULL)
+    logger->write("ONNX backend: creating compute context for " +
+      Global::intToString(nnXLen) + "x" + Global::intToString(nnYLen) +
+      " with provider '" + providerName + "'");
+
+  ComputeContext* ctx = new ComputeContext(nnXLen, nnYLen, providerName);
+
+  // Apply configured node names
+  if(params.count("inputSpatial")) ctx->inputSpatialName = params["inputSpatial"];
+  if(params.count("inputGlobal")) ctx->inputGlobalName = params["inputGlobal"];
+  if(params.count("inputMeta")) ctx->inputMetaName = params["inputMeta"];
+  if(params.count("outputPolicy")) ctx->outputPolicyName = params["outputPolicy"];
+  if(params.count("outputValue")) ctx->outputValueName = params["outputValue"];
+  if(params.count("outputMiscvalue")) ctx->outputMiscvalueName = params["outputMiscvalue"];
+  if(params.count("outputOwnership")) ctx->outputOwnershipName = params["outputOwnership"];
+  if(params.count("modelVersion")) {
+    int v = Global::stringToInt(params["modelVersion"]);
+    if(v >= 0)
+      ctx->configModelVersion = v;
+  }
+
+  return ctx;
+}
+
+void NeuralNet::freeComputeContext(ComputeContext* computeContext) {
+  delete computeContext;
+}
+
+//--------------------------------------------------------------
+
+ComputeHandle* NeuralNet::createComputeHandle(
+  ComputeContext* context,
+  const LoadedModel* loadedModel,
+  Logger* logger,
+  int maxBatchSize,
+  bool requireExactNNLen,
+  bool inputsUseNHWC,
+  int gpuIdxForThisThread,
+  int serverThreadIdx
+) {
+  (void)maxBatchSize;
+  (void)requireExactNNLen;
+  (void)gpuIdxForThisThread;
+
+  if(inputsUseNHWC)
+    throw StringError("ONNX backend: inputsUseNHWC = true not supported, must use NCHW");
+
+  if(logger != NULL) {
+    logger->write("ONNX backend thread " + Global::intToString(serverThreadIdx) +
+      ": Model version " + Global::intToString(loadedModel->modelDesc.modelVersion));
+    logger->write("ONNX backend thread " + Global::intToString(serverThreadIdx) +
+      ": Model name: " + loadedModel->modelDesc.name);
+  }
+
+  return new ComputeHandle(context, *loadedModel, logger);
+}
+
+void NeuralNet::freeComputeHandle(ComputeHandle* computeHandle) {
+  delete computeHandle;
+}
+
+bool NeuralNet::isUsingFP16(const ComputeHandle* handle) {
+  (void)handle;
+  return false;
+}
+
+//--------------------------------------------------------------
+
+// Helper to find the index of a name in a vector, checking multiple alternatives.
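+// Returns -1 if none of the target names is present.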
+static int findNameIndex(const vector<string>& names, const vector<string>& targets) {
+  for(size_t i = 0; i < names.size(); i++) {
+    for(const auto& t : targets) {
+      if(names[i] == t)
+        return (int)i;
+    }
+  }
+  return -1;
+}
+
+void NeuralNet::getOutput(
+  ComputeHandle* computeHandle,
+  InputBuffers* inputBuffers,
+  int numBatchEltsFilled,
+  NNResultBuf** inputBufs,
+  vector<NNOutput*>& outputs
+) {
+  assert(numBatchEltsFilled <= inputBuffers->maxBatchSize);
+  assert(numBatchEltsFilled > 0);
+  const int batchSize = numBatchEltsFilled;
+  const int nnXLen = computeHandle->context->nnXLen;
+  const int nnYLen = computeHandle->context->nnYLen;
+  const int numSpatialFeatures = computeHandle->numInputChannels;
+  const int numGlobalFeatures = computeHandle->numInputGlobalChannels;
+  const int numPolicyChannels = computeHandle->numPolicyChannels;
+
+  // Fill input buffers
+  for(int nIdx = 0; nIdx < batchSize; nIdx++) {
+    float* rowSpatialInput = inputBuffers->spatialInput.data() + (inputBuffers->singleInputElts * nIdx);
+    float* rowGlobalInput = inputBuffers->globalInput.data() + (inputBuffers->singleInputGlobalElts * nIdx);
+
+    const float* rowGlobal = inputBufs[nIdx]->rowGlobalBuf.data();
+    const float* rowSpatial = inputBufs[nIdx]->rowSpatialBuf.data();
+    std::copy(rowGlobal, rowGlobal + numGlobalFeatures, rowGlobalInput);
+    SymmetryHelpers::copyInputsWithSymmetry(rowSpatial, rowSpatialInput, 1, nnYLen, nnXLen, numSpatialFeatures, false, inputBufs[nIdx]->symmetry);
+
+    if(computeHandle->numInputMetaChannels > 0) {
+      float* rowMetaInput = inputBuffers->metaInput.data() + (inputBuffers->singleInputMetaElts * nIdx);
+      const float* rowMeta = inputBufs[nIdx]->rowMetaBuf.data();
+      std::copy(rowMeta, rowMeta + computeHandle->numInputMetaChannels, rowMetaInput);
+    }
+  }
+
+  // Create ONNX tensors
+  Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+
+  std::array<int64_t, 4> spatialShape = {batchSize, numSpatialFeatures, nnYLen, nnXLen};
+  Ort::Value spatialTensor = Ort::Value::CreateTensor<float>(
+    memInfo, inputBuffers->spatialInput.data(), inputBuffers->singleInputElts * batchSize,
+    spatialShape.data(), spatialShape.size()
+  );
+
+  std::array<int64_t, 2> globalShape = {batchSize, numGlobalFeatures};
+  Ort::Value globalTensor = Ort::Value::CreateTensor<float>(
+    memInfo, inputBuffers->globalInput.data(), inputBuffers->singleInputGlobalElts * batchSize,
+    globalShape.data(), globalShape.size()
+  );
+
+  // Match input ordering using configured node names
+  const ComputeContext* ctx = computeHandle->context;
+  int spatialIdx = findNameIndex(computeHandle->inputNames, {ctx->inputSpatialName});
+  int globalIdx = findNameIndex(computeHandle->inputNames, {ctx->inputGlobalName});
+  if(spatialIdx < 0 || globalIdx < 0)
+    throw StringError("ONNX backend: could not find expected input names");
+
+  int metaIdx = -1;
+  Ort::Value metaTensor(nullptr);
+  if(computeHandle->numInputMetaChannels > 0) {
+    metaIdx = findNameIndex(computeHandle->inputNames, {ctx->inputMetaName});
+    if(metaIdx < 0)
+      throw StringError("ONNX backend: model has metadata channels but could not find input_meta");
+    std::array<int64_t, 2> metaShape = {batchSize, computeHandle->numInputMetaChannels};
+    metaTensor = Ort::Value::CreateTensor<float>(
+      memInfo, inputBuffers->metaInput.data(), inputBuffers->singleInputMetaElts * batchSize,
+      metaShape.data(), metaShape.size()
+    );
+  }
+
+  vector<Ort::Value> inputTensors;
+  inputTensors.reserve(computeHandle->inputNames.size());
+  for(size_t i = 0; i < computeHandle->inputNames.size(); i++) {
+    if((int)i == spatialIdx)
+      
inputTensors.push_back(std::move(spatialTensor));
+    else if((int)i == globalIdx)
+      inputTensors.push_back(std::move(globalTensor));
+    else if((int)i == metaIdx)
+      inputTensors.push_back(std::move(metaTensor));
+    else {
+      throw StringError("ONNX backend: unexpected input node '" + computeHandle->inputNames[i] +
+        "' — only spatial, global, and meta inputs are supported");
+    }
+  }
+
+  // Run inference
+  auto outputTensors = computeHandle->session->Run(
+    Ort::RunOptions{nullptr},
+    computeHandle->inputNamePtrs.data(),
+    inputTensors.data(),
+    inputTensors.size(),
+    computeHandle->outputNamePtrs.data(),
+    computeHandle->outputNamePtrs.size()
+  );
+
+  // Find output indices using configured node names
+  int policyOutputIdx = findNameIndex(computeHandle->outputNames, {ctx->outputPolicyName});
+  int valueOutputIdx = findNameIndex(computeHandle->outputNames, {ctx->outputValueName});
+  int miscvalueOutputIdx = findNameIndex(computeHandle->outputNames, {ctx->outputMiscvalueName});
+  int ownershipOutputIdx = findNameIndex(computeHandle->outputNames, {ctx->outputOwnershipName});
+
+  if(policyOutputIdx < 0)
+    throw StringError("ONNX backend: could not find policy output node '" + ctx->outputPolicyName + "'");
+  if(valueOutputIdx < 0)
+    throw StringError("ONNX backend: could not find value output node '" + ctx->outputValueName + "'");
+  if(miscvalueOutputIdx < 0)
+    throw StringError("ONNX backend: could not find miscvalue output node '" + ctx->outputMiscvalueName + "'");
+  if(ownershipOutputIdx < 0)
+    throw StringError("ONNX backend: could not find ownership output node '" + ctx->outputOwnershipName + "'");
+
+  const float* policyData = outputTensors[policyOutputIdx].GetTensorData<float>();
+  const float* valueData = outputTensors[valueOutputIdx].GetTensorData<float>();
+  const float* miscvalueData = outputTensors[miscvalueOutputIdx].GetTensorData<float>();
+  const float* ownershipData = outputTensors[ownershipOutputIdx].GetTensorData<float>();
+
+  assert(policyData != nullptr);
+  assert(valueData != nullptr);
+  assert(miscvalueData != nullptr);
+  assert(ownershipData != nullptr);
+  assert((int)outputs.size() == batchSize);
+
+  const int policyResultLen = computeHandle->policyResultLen;
+  const int spatialPolicyLen = nnXLen * nnYLen;
+  float policyProbsTmp[NNPos::MAX_NN_POLICY_SIZE];
+
+  for(int row = 0; row < batchSize; row++) {
+    NNOutput* output = outputs[row];
+    assert(output->nnXLen == nnXLen);
+    assert(output->nnYLen == nnYLen);
+    float policyOptimism = (float)inputBufs[row]->policyOptimism;
+
+    // Policy: [N, C, H*W+1]
+    {
+      const float* policyRowBase = policyData + row * numPolicyChannels * policyResultLen;
+      float* policyProbs = output->policyProbs;
+
+      if(numPolicyChannels >= 2) {
+        const float* ch0 = policyRowBase;
+        const float* ch1 = policyRowBase + policyResultLen;
+        for(int i = 0; i < spatialPolicyLen; i++) {
+          float p = ch0[i];
+          float pOpt = ch1[i];
+          policyProbsTmp[i] = p + (pOpt - p) * policyOptimism;
+        }
+        SymmetryHelpers::copyOutputsWithSymmetry(policyProbsTmp, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry);
+        policyProbs[spatialPolicyLen] = ch0[spatialPolicyLen] + (ch1[spatialPolicyLen] - ch0[spatialPolicyLen]) * policyOptimism;
+      } else {
+        assert(numPolicyChannels == 1);
+        const float* ch0 = policyRowBase;
+        SymmetryHelpers::copyOutputsWithSymmetry(ch0, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry);
+        policyProbs[spatialPolicyLen] = ch0[spatialPolicyLen];
+      }
+    }
+
+    // Value: [N, 3]
+    {
+      int numVC = computeHandle->numValueChannels;
+      assert(numVC == 3);
+      output->whiteWinProb = 
valueData[row * numVC];
+      output->whiteLossProb = valueData[row * numVC + 1];
+      output->whiteNoResultProb = valueData[row * numVC + 2];
+    }
+
+    // MiscValue: [N, numScoreValueChannels] — version-dependent interpretation
+    {
+      int numScoreValueChannels = computeHandle->numScoreValueChannels;
+      if(computeHandle->modelVersion >= 9) {
+        assert(numScoreValueChannels >= 6);
+        output->whiteScoreMean = miscvalueData[row * numScoreValueChannels];
+        output->whiteScoreMeanSq = miscvalueData[row * numScoreValueChannels + 1];
+        output->whiteLead = miscvalueData[row * numScoreValueChannels + 2];
+        output->varTimeLeft = miscvalueData[row * numScoreValueChannels + 3];
+        output->shorttermWinlossError = miscvalueData[row * numScoreValueChannels + 4];
+        output->shorttermScoreError = miscvalueData[row * numScoreValueChannels + 5];
+      }
+      else if(computeHandle->modelVersion >= 8) {
+        assert(numScoreValueChannels >= 4);
+        output->whiteScoreMean = miscvalueData[row * numScoreValueChannels];
+        output->whiteScoreMeanSq = miscvalueData[row * numScoreValueChannels + 1];
+        output->whiteLead = miscvalueData[row * numScoreValueChannels + 2];
+        output->varTimeLeft = miscvalueData[row * numScoreValueChannels + 3];
+        output->shorttermWinlossError = 0;
+        output->shorttermScoreError = 0;
+      }
+      else if(computeHandle->modelVersion >= 4) {
+        assert(numScoreValueChannels >= 2);
+        output->whiteScoreMean = miscvalueData[row * numScoreValueChannels];
+        output->whiteScoreMeanSq = miscvalueData[row * numScoreValueChannels + 1];
+        output->whiteLead = output->whiteScoreMean;
+        output->varTimeLeft = 0;
+        output->shorttermWinlossError = 0;
+        output->shorttermScoreError = 0;
+      }
+      else if(computeHandle->modelVersion >= 3) {
+        assert(numScoreValueChannels >= 1);
+        output->whiteScoreMean = miscvalueData[row * numScoreValueChannels];
+        output->whiteScoreMeanSq = output->whiteScoreMean * output->whiteScoreMean;
+        output->whiteLead = output->whiteScoreMean;
+        output->varTimeLeft = 0;
+        output->shorttermWinlossError = 0;
+        output->shorttermScoreError = 0;
+      }
+      else {
+        ASSERT_UNREACHABLE;
+      }
+    }
+
+    // Ownership: [N, 1, H, W]
+    if(output->whiteOwnerMap != NULL) {
+      assert(computeHandle->numOwnershipChannels == 1);
+      const float* ownershipRowBuf = ownershipData + row * nnXLen * nnYLen;
+      SymmetryHelpers::copyOutputsWithSymmetry(ownershipRowBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry);
+    }
+  }
+}
+
+void NeuralNet::printDevices() {
+}
+
+//--------------------------------------------------------------
+// FOR TESTING — all return false (not implemented for this backend)
+
+bool NeuralNet::testEvaluateConv(
+  const ConvLayerDesc* desc, int batchSize, int nnXLen, int nnYLen,
+  bool useFP16, bool useNHWC, const std::vector<float>& inputBuffer, std::vector<float>& outputBuffer
+) {
+  (void)desc; (void)batchSize; (void)nnXLen; (void)nnYLen;
+  (void)useFP16; (void)useNHWC; (void)inputBuffer; (void)outputBuffer;
+  return false;
+}
+
+bool NeuralNet::testEvaluateBatchNorm(
+  const BatchNormLayerDesc* desc, int batchSize, int nnXLen, int nnYLen,
+  bool useFP16, bool useNHWC, const std::vector<float>& inputBuffer,
+  const std::vector<float>& maskBuffer, std::vector<float>& outputBuffer
+) {
+  (void)desc; (void)batchSize; (void)nnXLen; (void)nnYLen;
+  (void)useFP16; (void)useNHWC; (void)inputBuffer; (void)maskBuffer; (void)outputBuffer;
+  return false;
+}
+
+bool NeuralNet::testEvaluateResidualBlock(
+  const ResidualBlockDesc* desc, int batchSize, int nnXLen, int nnYLen,
+  bool useFP16, bool useNHWC, const std::vector<float>& inputBuffer,
+  const std::vector<float>& 
maskBuffer, std::vector<float>& outputBuffer
+) {
+  (void)desc; (void)batchSize; (void)nnXLen; (void)nnYLen;
+  (void)useFP16; (void)useNHWC; (void)inputBuffer; (void)maskBuffer; (void)outputBuffer;
+  return false;
+}
+
+bool NeuralNet::testEvaluateGlobalPoolingResidualBlock(
+  const GlobalPoolingResidualBlockDesc* desc, int batchSize, int nnXLen, int nnYLen,
+  bool useFP16, bool useNHWC, const std::vector<float>& inputBuffer,
+  const std::vector<float>& maskBuffer, std::vector<float>& outputBuffer
+) {
+  (void)desc; (void)batchSize; (void)nnXLen; (void)nnYLen;
+  (void)useFP16; (void)useNHWC; (void)inputBuffer; (void)maskBuffer; (void)outputBuffer;
+  return false;
+}
diff --git a/cpp/neuralnet/onnxmodelbuilder.cpp b/cpp/neuralnet/onnxmodelbuilder.cpp
new file mode 100644
index 000000000..006d62479
--- /dev/null
+++ b/cpp/neuralnet/onnxmodelbuilder.cpp
@@ -0,0 +1,774 @@
+// Builds an ONNX computational graph from a KataGo ModelDesc.
+// Uses the ONNX protobuf API (onnx-ml.pb.h) to construct a ModelProto
+// that can be loaded directly by ONNX Runtime.
+
+#include "../neuralnet/onnxmodelbuilder.h"
+#include "../neuralnet/activations.h"
+#include "../core/global.h"
+
+#include <onnx/onnx-ml.pb.h>
+
+#include <string>
+#include <vector>
+
+using namespace std;
+
+static string uniqueName(int& nameCounter, const string& prefix) {
+  return prefix + "_" + to_string(nameCounter++);
+}
+
+// =====================================================================
+// Helper: Add a float tensor initializer to the graph
+// =====================================================================
+static string addInitializer(
+  onnx::GraphProto* graph,
+  const string& name,
+  const vector<int64_t>& shape,
+  const float* data,
+  size_t numElements
+) {
+  onnx::TensorProto* tensor = graph->add_initializer();
+  tensor->set_name(name);
+  tensor->set_data_type(onnx::TensorProto_DataType_FLOAT);
+  for(int64_t d : shape)
+    tensor->add_dims(d);
+  tensor->set_raw_data(data, numElements * sizeof(float));
+  return name;
+}
+
+static string addInitializer(
+  onnx::GraphProto* graph,
+  const string& name,
+  const vector<int64_t>& shape,
+  const vector<float>& data
+) {
+  return addInitializer(graph, name, shape, data.data(), data.size());
+}
+
+// Add a scalar float constant
+static string addScalarInitializer(onnx::GraphProto* graph, const string& name, float value) {
+  return addInitializer(graph, name, {}, &value, 1);
+}
+
+// Add a 1D int64 constant tensor
+static string addInt64Initializer(
+  onnx::GraphProto* graph,
+  const string& name,
+  const vector<int64_t>& data
+) {
+  onnx::TensorProto* tensor = graph->add_initializer();
+  tensor->set_name(name);
+  tensor->set_data_type(onnx::TensorProto_DataType_INT64);
+  tensor->add_dims((int64_t)data.size());
+  tensor->set_raw_data(data.data(), data.size() * sizeof(int64_t));
+  return name;
+}
+
+// =====================================================================
+// Helper: Add ONNX graph node
+// =====================================================================
+
+// Generic node with n inputs, 1 output
+static onnx::NodeProto* addNode(
+  onnx::GraphProto* graph,
+  const string& opType,
+  const vector<string>& inputs,
+  const string& outputName
+) {
+  onnx::NodeProto* node = graph->add_node();
+  node->set_op_type(opType);
+  for(const auto& inp : inputs)
+    node->add_input(inp);
+  node->add_output(outputName);
+  return node;
+}
+
+// Add an attribute (int) to a node
+static void setAttrInt(onnx::NodeProto* node, const string& attrName, int64_t value) {
+  onnx::AttributeProto* attr = node->add_attribute();
+  attr->set_name(attrName);
+  
attr->set_type(onnx::AttributeProto_AttributeType_INT);
+  attr->set_i(value);
+}
+
+// Add an attribute (ints) to a node
+static void setAttrInts(onnx::NodeProto* node, const string& attrName, const vector<int64_t>& values) {
+  onnx::AttributeProto* attr = node->add_attribute();
+  attr->set_name(attrName);
+  attr->set_type(onnx::AttributeProto_AttributeType_INTS);
+  for(int64_t v : values)
+    attr->add_ints(v);
+}
+
+// =====================================================================
+// Convolution: Conv with zero-padding
+// =====================================================================
+static string addConvNode(
+  onnx::GraphProto* graph,
+  int& nameCounter,
+  const string& input,
+  const ConvLayerDesc& desc,
+  const string& prefix
+) {
+  string weightsName = addInitializer(
+    graph, prefix + "/w",
+    {desc.outChannels, desc.inChannels, desc.convYSize, desc.convXSize},
+    desc.weights
+  );
+
+  int padY = desc.convYSize / 2;
+  int padX = desc.convXSize / 2;
+  string output = uniqueName(nameCounter, prefix + "/out");
+
+  onnx::NodeProto* convNode = addNode(graph, "Conv", {input, weightsName}, output);
+  setAttrInts(convNode, "kernel_shape", {desc.convYSize, desc.convXSize});
+  setAttrInts(convNode, "pads", {padY, padX, padY, padX});
+  setAttrInts(convNode, "dilations", {desc.dilationY, desc.dilationX});
+  setAttrInts(convNode, "strides", {1, 1});
+
+  return output;
+}
+
+// =====================================================================
+// Merged Batch Norm: output = input * mergedScale + mergedBias
+// Applied channel-wise, broadcasting over [N, C, H, W]
+// =====================================================================
+static string addMergedBNNode(
+  onnx::GraphProto* graph,
+  int& nameCounter,
+  const string& input,
+  const BatchNormLayerDesc& desc,
+  const string& prefix
+) {
+  int C = desc.numChannels;
+  string scaleName = addInitializer(graph, prefix + "/scale", {C, 1, 1}, desc.mergedScale);
+  string biasName = addInitializer(graph, prefix + "/bias", {C, 1, 1}, desc.mergedBias);
+
+  string scaled = uniqueName(nameCounter, prefix + "/scaled");
+  addNode(graph, "Mul", {input, scaleName}, scaled);
+
+  string output = uniqueName(nameCounter, prefix + "/bn_out");
+  addNode(graph, "Add", {scaled, biasName}, output);
+
+  return output;
+}
+
+// =====================================================================
+// Activation: ReLU, Mish (softplus->tanh->mul), or Identity
+// =====================================================================
+static string addActivationNode(
+  onnx::GraphProto* graph,
+  int& nameCounter,
+  const string& input,
+  int activationType,
+  const string& prefix
+) {
+  if(activationType == ACTIVATION_RELU) {
+    string output = uniqueName(nameCounter, prefix + "/relu");
+    addNode(graph, "Relu", {input}, output);
+    return output;
+  } else if(activationType == ACTIVATION_MISH) {
+    // Mish = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+    string sp = uniqueName(nameCounter, prefix + "/softplus");
+    addNode(graph, "Softplus", {input}, sp);
+
+    string th = uniqueName(nameCounter, prefix + "/tanh");
+    addNode(graph, "Tanh", {sp}, th);
+
+    string output = uniqueName(nameCounter, prefix + "/mish");
+    addNode(graph, "Mul", {input, th}, output);
+    return output;
+  } else {
+    // ACTIVATION_IDENTITY — pass through
+    return input;
+  }
+}
+
+// =====================================================================
+// BN + Activation + Mask multiply
+// output = activation(input * scale + bias) * mask
+// 
===================================================================== +static string addBNActivationMask( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const BatchNormLayerDesc& bnDesc, + const ActivationLayerDesc& actDesc, + const string& mask, + const string& prefix +) { + string bn = addMergedBNNode(graph, nameCounter, input, bnDesc, prefix + "/bn"); + string act = addActivationNode(graph, nameCounter, bn, actDesc.activation, prefix + "/act"); + string output = uniqueName(nameCounter, prefix + "/masked"); + addNode(graph, "Mul", {act, mask}, output); + return output; +} + +// ===================================================================== +// MatMul: output = input @ W +// W is [inC, outC] +// ===================================================================== +static string addMatMulNode( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const MatMulLayerDesc& desc, + const string& prefix +) { + string weightsName = addInitializer(graph, prefix + "/w", {desc.inChannels, desc.outChannels}, desc.weights); + string output = uniqueName(nameCounter, prefix + "/matmul"); + addNode(graph, "MatMul", {input, weightsName}, output); + return output; +} + +// ===================================================================== +// Bias addition: output = input + bias +// bias is [C], broadcast over [N, C] or [N, C, H, W] +// ===================================================================== +static string addBiasNode( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const MatBiasLayerDesc& desc, + const string& prefix +) { + string biasName = addInitializer(graph, prefix + "/b", {desc.numChannels}, desc.weights); + string output = uniqueName(nameCounter, prefix + "/biased"); + addNode(graph, "Add", {input, biasName}, output); + return output; +} + +// ===================================================================== +// KataGPool: Global pooling producing 3 values per channel +// Pool 1: mean = ReduceSum(x * mask, [2,3]) / maskSum +// Pool 2: mean * (sqrt(maskSum) - 14.0) * 0.1 +// Pool 3: ReduceMax(x + (mask - 1.0), [2,3]) +// Output: [N, 3*C] +// ===================================================================== +static string addGlobalPool( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const string& mask, + const string& maskSumHW, + const string& prefix +) { + // x_masked = input * mask (already masked, but let's be safe) + string xMasked = uniqueName(nameCounter, prefix + "/gpool_xm"); + addNode(graph, "Mul", {input, mask}, xMasked); + + // sum = ReduceSum(xMasked, axes=[2,3]) + string axesName = addInt64Initializer(graph, uniqueName(nameCounter, prefix + "/axes23"), {2, 3}); + string sumOut = uniqueName(nameCounter, prefix + "/gpool_sum"); + onnx::NodeProto* sumNode = addNode(graph, "ReduceSum", {xMasked, axesName}, sumOut); + setAttrInt(sumNode, "keepdims", 0); + + // mean = sum / maskSumFlat + // maskSumHW is [N,1,1,1], we need [N,1] for division + string maskSumFlat = uniqueName(nameCounter, prefix + "/gpool_msf"); + string reshapeShape = addInt64Initializer(graph, uniqueName(nameCounter, prefix + "/shape_n1"), {0, 1}); + addNode(graph, "Reshape", {maskSumHW, reshapeShape}, maskSumFlat); + + string mean = uniqueName(nameCounter, prefix + "/gpool_mean"); + addNode(graph, "Div", {sumOut, maskSumFlat}, mean); + + // sqrtMaskSum = sqrt(maskSumFlat) + string sqrtMs = uniqueName(nameCounter, prefix + "/gpool_sqrt"); + addNode(graph, "Sqrt", {maskSumFlat}, sqrtMs); + + // sqrtMs - 
14.0 + string const14 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/c14"), 14.0f); + string sqrtMsSub = uniqueName(nameCounter, prefix + "/gpool_sqrtsub"); + addNode(graph, "Sub", {sqrtMs, const14}, sqrtMsSub); + + // * 0.1 + string const01 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/c01"), 0.1f); + string scaledSqrt = uniqueName(nameCounter, prefix + "/gpool_ssm"); + addNode(graph, "Mul", {sqrtMsSub, const01}, scaledSqrt); + + // pool2 = mean * scaledSqrt + string pool2 = uniqueName(nameCounter, prefix + "/gpool_p2"); + addNode(graph, "Mul", {mean, scaledSqrt}, pool2); + + // Pool3: max over (x + mask - 1) + string constNeg1 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/cn1"), -1.0f); + string maskBias = uniqueName(nameCounter, prefix + "/gpool_mb"); + addNode(graph, "Add", {mask, constNeg1}, maskBias); + + string xShifted = uniqueName(nameCounter, prefix + "/gpool_xs"); + addNode(graph, "Add", {input, maskBias}, xShifted); + + // ReduceMax over [2,3] + string axesName2 = addInt64Initializer(graph, uniqueName(nameCounter, prefix + "/axes23b"), {2, 3}); + string pool3 = uniqueName(nameCounter, prefix + "/gpool_max"); + onnx::NodeProto* maxNode = addNode(graph, "ReduceMax", {xShifted, axesName2}, pool3); + setAttrInt(maxNode, "keepdims", 0); + + // Concat [mean, pool2, pool3] along axis=1 + string output = uniqueName(nameCounter, prefix + "/gpool_out"); + onnx::NodeProto* concatNode = addNode(graph, "Concat", {mean, pool2, pool3}, output); + setAttrInt(concatNode, "axis", 1); + + return output; +} + +// ===================================================================== +// KataValueHeadGPool: Different third pool from KataGPool +// Pool 3: mean * ((sqrt(maskSum) - 14.0)^2 * 0.01 - 0.1) +// ===================================================================== +static string addValueHeadGPool( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const string& mask, + const string& maskSumHW, + const string& prefix +) { + // x for value head already has activation applied + // sum = ReduceSum(input * mask, [2,3]) + string xMasked = uniqueName(nameCounter, prefix + "/vgpool_xm"); + addNode(graph, "Mul", {input, mask}, xMasked); + + string axesName = addInt64Initializer(graph, uniqueName(nameCounter, prefix + "/axes23"), {2, 3}); + string sumOut = uniqueName(nameCounter, prefix + "/vgpool_sum"); + onnx::NodeProto* sumNode = addNode(graph, "ReduceSum", {xMasked, axesName}, sumOut); + setAttrInt(sumNode, "keepdims", 0); + + // mean + string maskSumFlat = uniqueName(nameCounter, prefix + "/vgpool_msf"); + string reshapeShape = addInt64Initializer(graph, uniqueName(nameCounter, prefix + "/shape_n1"), {0, 1}); + addNode(graph, "Reshape", {maskSumHW, reshapeShape}, maskSumFlat); + + string mean = uniqueName(nameCounter, prefix + "/vgpool_mean"); + addNode(graph, "Div", {sumOut, maskSumFlat}, mean); + + // sqrt(maskSum) + string sqrtMs = uniqueName(nameCounter, prefix + "/vgpool_sqrt"); + addNode(graph, "Sqrt", {maskSumFlat}, sqrtMs); + + // (sqrt(maskSum) - 14.0) + string const14 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/c14"), 14.0f); + string sqrtMsSub = uniqueName(nameCounter, prefix + "/vgpool_ss"); + addNode(graph, "Sub", {sqrtMs, const14}, sqrtMsSub); + + // pool2 = mean * (sqrtMsSub) * 0.1 + string const01 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/c01"), 0.1f); + string scaledSqrt = uniqueName(nameCounter, prefix + "/vgpool_ssm"); + addNode(graph, "Mul", {sqrtMsSub, 
const01}, scaledSqrt); + string pool2 = uniqueName(nameCounter, prefix + "/vgpool_p2"); + addNode(graph, "Mul", {mean, scaledSqrt}, pool2); + + // pool3 = mean * ((sqrtMsSub)^2 * 0.01 - 0.1) + string sqrtMsSubSq = uniqueName(nameCounter, prefix + "/vgpool_sq"); + addNode(graph, "Mul", {sqrtMsSub, sqrtMsSub}, sqrtMsSubSq); + + string constP01 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/cp01"), 0.01f); + string sqScaled = uniqueName(nameCounter, prefix + "/vgpool_sqs"); + addNode(graph, "Mul", {sqrtMsSubSq, constP01}, sqScaled); + + string constN01 = addScalarInitializer(graph, uniqueName(nameCounter, prefix + "/cn01"), -0.1f); + string sqShifted = uniqueName(nameCounter, prefix + "/vgpool_sqsh"); + addNode(graph, "Add", {sqScaled, constN01}, sqShifted); + + string pool3 = uniqueName(nameCounter, prefix + "/vgpool_p3"); + addNode(graph, "Mul", {mean, sqShifted}, pool3); + + // Concat [mean, pool2, pool3] along axis=1 + string output = uniqueName(nameCounter, prefix + "/vgpool_out"); + onnx::NodeProto* concatNode = addNode(graph, "Concat", {mean, pool2, pool3}, output); + setAttrInt(concatNode, "axis", 1); + + return output; +} + +// ===================================================================== +// Residual Block: BN→Act→Conv→BN→Act→Conv + skip +// ===================================================================== +static string addResidualBlock( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const string& mask, + const ResidualBlockDesc& desc, + const string& prefix +) { + string pre = addBNActivationMask(graph, nameCounter, input, desc.preBN, desc.preActivation, mask, prefix + "/pre"); + string mid = addConvNode(graph, nameCounter, pre, desc.regularConv, prefix + "/conv1"); + string midAct = addBNActivationMask(graph, nameCounter, mid, desc.midBN, desc.midActivation, mask, prefix + "/mid"); + string final_ = addConvNode(graph, nameCounter, midAct, desc.finalConv, prefix + "/conv2"); + + // Residual add + string output = uniqueName(nameCounter, prefix + "/resadd"); + addNode(graph, "Add", {input, final_}, output); + return output; +} + +// ===================================================================== +// Global Pooling Residual Block +// ===================================================================== +static string addGPoolResidualBlock( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const string& mask, + const string& maskSumHW, + const GlobalPoolingResidualBlockDesc& desc, + const string& prefix +) { + string pre = addBNActivationMask(graph, nameCounter, input, desc.preBN, desc.preActivation, mask, prefix + "/pre"); + + // Regular path + string regOut = addConvNode(graph, nameCounter, pre, desc.regularConv, prefix + "/reg"); + + // Global pooling path + string gpoolConvOut = addConvNode(graph, nameCounter, pre, desc.gpoolConv, prefix + "/gconv"); + string gpoolBNAct = addBNActivationMask(graph, nameCounter, gpoolConvOut, desc.gpoolBN, desc.gpoolActivation, mask, prefix + "/gbn"); + string gpoolResult = addGlobalPool(graph, nameCounter, gpoolBNAct, mask, maskSumHW, prefix + "/gpool"); + + // gpoolToBiasMul: [N, 3*gpoolC] → [N, regC] + string gpoolBias = addMatMulNode(graph, nameCounter, gpoolResult, desc.gpoolToBiasMul, prefix + "/g2b"); + + // Reshape bias to [N, C, 1, 1] for broadcasting + string biasShape = addInt64Initializer(graph, uniqueName(nameCounter, prefix + "/shape_nc11"), {0, -1, 1, 1}); + string gpoolBiasReshaped = uniqueName(nameCounter, prefix + "/gbr"); + addNode(graph, 
"Reshape", {gpoolBias, biasShape}, gpoolBiasReshaped); + + // Add bias to regular conv output + string regPlusBias = uniqueName(nameCounter, prefix + "/rpb"); + addNode(graph, "Add", {regOut, gpoolBiasReshaped}, regPlusBias); + + // Second half: BN→Act→Conv + string midAct = addBNActivationMask(graph, nameCounter, regPlusBias, desc.midBN, desc.midActivation, mask, prefix + "/mid"); + string final_ = addConvNode(graph, nameCounter, midAct, desc.finalConv, prefix + "/conv2"); + + // Residual add + string output = uniqueName(nameCounter, prefix + "/resadd"); + addNode(graph, "Add", {input, final_}, output); + return output; +} + +// ===================================================================== +// Nested Bottleneck Residual Block +// Pre: BN→Act→Mask→1x1Conv (c_main→c_mid) +// Inner: sequence of ordinary/gpool/nested_bottleneck sub-blocks at c_mid +// Post: BN→Act→Mask→1x1Conv (c_mid→c_main) + residual add +// ===================================================================== +static string addNestedBottleneckResidualBlock( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const string& mask, + const string& maskSumHW, + const NestedBottleneckResidualBlockDesc& desc, + const string& prefix +) { + // Pre: BN → Act → Mask → 1x1 Conv (c_main → c_mid) + string pre = addBNActivationMask(graph, nameCounter, input, desc.preBN, desc.preActivation, mask, prefix + "/pre"); + string midOut = addConvNode(graph, nameCounter, pre, desc.preConv, prefix + "/preconv"); + + // Inner sub-blocks at c_mid channels + for(int i = 0; i < desc.numBlocks; i++) { + int kind = desc.blocks[i].first; + string sub = prefix + "/sub" + to_string(i); + if(kind == ORDINARY_BLOCK_KIND) { + midOut = addResidualBlock(graph, nameCounter, midOut, mask, + *((const ResidualBlockDesc*)desc.blocks[i].second.get()), sub); + } else if(kind == GLOBAL_POOLING_BLOCK_KIND) { + midOut = addGPoolResidualBlock(graph, nameCounter, midOut, mask, maskSumHW, + *((const GlobalPoolingResidualBlockDesc*)desc.blocks[i].second.get()), sub); + } else if(kind == NESTED_BOTTLENECK_BLOCK_KIND) { + midOut = addNestedBottleneckResidualBlock(graph, nameCounter, midOut, mask, maskSumHW, + *((const NestedBottleneckResidualBlockDesc*)desc.blocks[i].second.get()), sub); + } else { + throw StringError("ONNX backend: unknown sub-block kind " + to_string(kind)); + } + } + + // Post: BN → Act → Mask → 1x1 Conv (c_mid → c_main) + string post = addBNActivationMask(graph, nameCounter, midOut, desc.postBN, desc.postActivation, mask, prefix + "/post"); + string postOut = addConvNode(graph, nameCounter, post, desc.postConv, prefix + "/postconv"); + + // Residual add: input + postOut + string output = uniqueName(nameCounter, prefix + "/resadd"); + addNode(graph, "Add", {input, postOut}, output); + return output; +} + +// ===================================================================== +// Add ValueInfo for graph input/output +// ===================================================================== +static void addGraphInput( + onnx::GraphProto* graph, + const string& name, + const vector& shape +) { + onnx::ValueInfoProto* input = graph->add_input(); + input->set_name(name); + onnx::TypeProto* type = input->mutable_type(); + onnx::TypeProto_Tensor* tensorType = type->mutable_tensor_type(); + tensorType->set_elem_type(onnx::TensorProto_DataType_FLOAT); + onnx::TensorShapeProto* shapeProto = tensorType->mutable_shape(); + for(int64_t d : shape) { + auto* dim = shapeProto->add_dim(); + if(d < 0) + dim->set_dim_param("N"); + else + 
+ +// ===================================================================== +// Nested Bottleneck Residual Block +// Pre: BN→Act→Mask→1x1Conv (c_main→c_mid) +// Inner: sequence of ordinary/gpool/nested_bottleneck sub-blocks at c_mid +// Post: BN→Act→Mask→1x1Conv (c_mid→c_main) + residual add +// ===================================================================== +static string addNestedBottleneckResidualBlock( + onnx::GraphProto* graph, + int& nameCounter, + const string& input, + const string& mask, + const string& maskSumHW, + const NestedBottleneckResidualBlockDesc& desc, + const string& prefix +) { + // Pre: BN → Act → Mask → 1x1 Conv (c_main → c_mid) + string pre = addBNActivationMask(graph, nameCounter, input, desc.preBN, desc.preActivation, mask, prefix + "/pre"); + string midOut = addConvNode(graph, nameCounter, pre, desc.preConv, prefix + "/preconv"); + + // Inner sub-blocks at c_mid channels + for(int i = 0; i < desc.numBlocks; i++) { + int kind = desc.blocks[i].first; + string sub = prefix + "/sub" + to_string(i); + if(kind == ORDINARY_BLOCK_KIND) { + midOut = addResidualBlock(graph, nameCounter, midOut, mask, + *((const ResidualBlockDesc*)desc.blocks[i].second.get()), sub); + } else if(kind == GLOBAL_POOLING_BLOCK_KIND) { + midOut = addGPoolResidualBlock(graph, nameCounter, midOut, mask, maskSumHW, + *((const GlobalPoolingResidualBlockDesc*)desc.blocks[i].second.get()), sub); + } else if(kind == NESTED_BOTTLENECK_BLOCK_KIND) { + midOut = addNestedBottleneckResidualBlock(graph, nameCounter, midOut, mask, maskSumHW, + *((const NestedBottleneckResidualBlockDesc*)desc.blocks[i].second.get()), sub); + } else { + throw StringError("ONNX backend: unknown sub-block kind " + to_string(kind)); + } + } + + // Post: BN → Act → Mask → 1x1 Conv (c_mid → c_main) + string post = addBNActivationMask(graph, nameCounter, midOut, desc.postBN, desc.postActivation, mask, prefix + "/post"); + string postOut = addConvNode(graph, nameCounter, post, desc.postConv, prefix + "/postconv"); + + // Residual add: input + postOut + string output = uniqueName(nameCounter, prefix + "/resadd"); + addNode(graph, "Add", {input, postOut}, output); + return output; +} + +// ===================================================================== +// Add ValueInfo for graph input/output +// ===================================================================== +static void addGraphInput( + onnx::GraphProto* graph, + const string& name, + const vector<int64_t>& shape +) { + onnx::ValueInfoProto* input = graph->add_input(); + input->set_name(name); + onnx::TypeProto* type = input->mutable_type(); + onnx::TypeProto_Tensor* tensorType = type->mutable_tensor_type(); + tensorType->set_elem_type(onnx::TensorProto_DataType_FLOAT); + onnx::TensorShapeProto* shapeProto = tensorType->mutable_shape(); + for(int64_t d : shape) { + auto* dim = shapeProto->add_dim(); + if(d < 0) + dim->set_dim_param("N"); + else + dim->set_dim_value(d); + } +} + +static void addGraphOutput( + onnx::GraphProto* graph, + const string& name, + const vector<int64_t>& shape +) { + onnx::ValueInfoProto* output = graph->add_output(); + output->set_name(name); + onnx::TypeProto* type = output->mutable_type(); + onnx::TypeProto_Tensor* tensorType = type->mutable_tensor_type(); + tensorType->set_elem_type(onnx::TensorProto_DataType_FLOAT); + onnx::TensorShapeProto* shapeProto = tensorType->mutable_shape(); + for(int64_t d : shape) { + auto* dim = shapeProto->add_dim(); + if(d < 0) + dim->set_dim_param("N"); + else + dim->set_dim_value(d); + } +} + +// ===================================================================== +// Main: Build the full ONNX model from ModelDesc +// ===================================================================== +string OnnxModelBuilder::buildOnnxModel(const ModelDesc& modelDesc, int nnXLen, int nnYLen) { + int nameCounter = 0; + + const int modelVersion = modelDesc.modelVersion; + const int numInputChannels = modelDesc.numInputChannels; + const int numInputGlobalChannels = modelDesc.numInputGlobalChannels; + const int numPolicyChannels = modelDesc.numPolicyChannels; + const int numValueChannels = modelDesc.numValueChannels; + const int numScoreValueChannels = modelDesc.numScoreValueChannels; + const int numOwnershipChannels = modelDesc.numOwnershipChannels; + + const TrunkDesc& trunk = modelDesc.trunk; + const PolicyHeadDesc& policyHead = modelDesc.policyHead; + const ValueHeadDesc& valueHead = modelDesc.valueHead; + + onnx::ModelProto model; + model.set_ir_version(8); + model.set_producer_name("KataGo"); + model.set_domain("ai.katago"); + + auto* opset = model.add_opset_import(); + opset->set_domain(""); + opset->set_version(18); + + onnx::GraphProto* graph = model.mutable_graph(); + graph->set_name("katago"); + + // ------------------------------------------------------------------ + // Graph Inputs + // ------------------------------------------------------------------ + addGraphInput(graph, "input_spatial", {-1, numInputChannels, nnYLen, nnXLen}); + addGraphInput(graph, "input_global", {-1, numInputGlobalChannels}); + if(modelDesc.numInputMetaChannels > 0) { + addGraphInput(graph, "input_meta", {-1, modelDesc.numInputMetaChannels}); + } + + // ------------------------------------------------------------------ + // Derive mask and maskSumHW from input_spatial. + // Channel 0 of the spatial input is the "on board" indicator: 1.0 for + // positions on the board, 0.0 for off-board padding. This is Feature 0 + // set by fillRowV3/V4/V5/V6/V7 in nninputs.cpp and holds across all + // supported input versions (V3-V7).
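+ // + // For example (assuming the usual top-left placement of smaller boards + // within the nnXLen x nnYLen tensor): for a 9x9 game evaluated at + // nnXLen = nnYLen = 19, channel 0 is 1.0 on the 81 on-board points and + // 0.0 on the padding, so the mask for that batch entry sums to 81.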
+ // + // mask = input_spatial[:, 0:1, :, :] → [N, 1, H, W] + // maskSumHW = ReduceSum(mask, [2, 3], keepdims=true) → [N, 1, 1, 1] + // ------------------------------------------------------------------ + + // Slice channel 0 to get mask + string sliceStarts = addInt64Initializer(graph, "mask_starts", {0}); + string sliceEnds = addInt64Initializer(graph, "mask_ends", {1}); + string sliceAxes = addInt64Initializer(graph, "mask_axes", {1}); + string mask = uniqueName(nameCounter, "mask"); + addNode(graph, "Slice", {"input_spatial", sliceStarts, sliceEnds, sliceAxes}, mask); + + // maskSumHW + string sumAxes = addInt64Initializer(graph, "mask_sum_axes", {2, 3}); + string maskSumHW = uniqueName(nameCounter, "maskSumHW"); + onnx::NodeProto* maskSumNode = addNode(graph, "ReduceSum", {mask, sumAxes}, maskSumHW); + setAttrInt(maskSumNode, "keepdims", 1); + + // ------------------------------------------------------------------ + // Trunk: Initial conv + matmul bias + // ------------------------------------------------------------------ + string trunkOut = addConvNode(graph, nameCounter, "input_spatial", trunk.initialConv, "trunk/init_conv"); + + // initialMatMul: global features → [N, trunkNumChannels] + string globalBias = addMatMulNode(graph, nameCounter, "input_global", trunk.initialMatMul, "trunk/init_matmul"); + + // Reshape to [N, C, 1, 1] for broadcasting + string biasShape = addInt64Initializer(graph, "trunk_bias_shape", {0, -1, 1, 1}); + string globalBiasReshaped = uniqueName(nameCounter, "trunk/gbr"); + addNode(graph, "Reshape", {globalBias, biasShape}, globalBiasReshaped); + + // Add global bias to conv output + string trunkCombined = uniqueName(nameCounter, "trunk/combined"); + addNode(graph, "Add", {trunkOut, globalBiasReshaped}, trunkCombined); + trunkOut = trunkCombined; + + // ------------------------------------------------------------------ + // Trunk: Metadata encoder (SGF metadata → trunk bias) + // ------------------------------------------------------------------ + if(trunk.metaEncoderVersion > 0) { + const SGFMetadataEncoderDesc& enc = trunk.sgfMetadataEncoder; + string metaOut = addMatMulNode(graph, nameCounter, "input_meta", enc.mul1, "trunk/meta_mul1"); + metaOut = addBiasNode(graph, nameCounter, metaOut, enc.bias1, "trunk/meta_b1"); + metaOut = addActivationNode(graph, nameCounter, metaOut, enc.act1.activation, "trunk/meta_a1"); + metaOut = addMatMulNode(graph, nameCounter, metaOut, enc.mul2, "trunk/meta_mul2"); + metaOut = addBiasNode(graph, nameCounter, metaOut, enc.bias2, "trunk/meta_b2"); + metaOut = addActivationNode(graph, nameCounter, metaOut, enc.act2.activation, "trunk/meta_a2"); + metaOut = addMatMulNode(graph, nameCounter, metaOut, enc.mul3, "trunk/meta_mul3"); + + // Reshape to [N, C, 1, 1] for spatial broadcasting + string metaBiasShape = addInt64Initializer(graph, "trunk_meta_bias_shape", {0, -1, 1, 1}); + string metaBiasReshaped = uniqueName(nameCounter, "trunk/mbr"); + addNode(graph, "Reshape", {metaOut, metaBiasShape}, metaBiasReshaped); + + // Add to trunk + string trunkWithMeta = uniqueName(nameCounter, "trunk/with_meta"); + addNode(graph, "Add", {trunkOut, metaBiasReshaped}, trunkWithMeta); + trunkOut = trunkWithMeta; + } + + // ------------------------------------------------------------------ + // Trunk: Residual blocks + // ------------------------------------------------------------------ + for(int i = 0; i < trunk.numBlocks; i++) { + int blockKind = trunk.blocks[i].first; + string blockPrefix = "trunk/block" + to_string(i); + + if(blockKind == 
ORDINARY_BLOCK_KIND) { + const ResidualBlockDesc& blockDesc = *((const ResidualBlockDesc*)trunk.blocks[i].second.get()); + trunkOut = addResidualBlock(graph, nameCounter, trunkOut, mask, blockDesc, blockPrefix); + } else if(blockKind == GLOBAL_POOLING_BLOCK_KIND) { + const GlobalPoolingResidualBlockDesc& blockDesc = *((const GlobalPoolingResidualBlockDesc*)trunk.blocks[i].second.get()); + trunkOut = addGPoolResidualBlock(graph, nameCounter, trunkOut, mask, maskSumHW, blockDesc, blockPrefix); + } else if(blockKind == NESTED_BOTTLENECK_BLOCK_KIND) { + const NestedBottleneckResidualBlockDesc& blockDesc = *((const NestedBottleneckResidualBlockDesc*)trunk.blocks[i].second.get()); + trunkOut = addNestedBottleneckResidualBlock(graph, nameCounter, trunkOut, mask, maskSumHW, blockDesc, blockPrefix); + } else { + throw StringError("ONNX backend: unknown block kind " + to_string(blockKind)); + } + } + + // Trunk tip: BN + activation + mask + trunkOut = addBNActivationMask(graph, nameCounter, trunkOut, trunk.trunkTipBN, trunk.trunkTipActivation, mask, "trunk/tip"); + + // ------------------------------------------------------------------ + // Policy Head + // ------------------------------------------------------------------ + + // p1Conv: spatial path + string p1Out = addConvNode(graph, nameCounter, trunkOut, policyHead.p1Conv, "policy/p1conv"); + + // g1Conv: global pooling path + string g1Out = addConvNode(graph, nameCounter, trunkOut, policyHead.g1Conv, "policy/g1conv"); + string g1BNAct = addBNActivationMask(graph, nameCounter, g1Out, policyHead.g1BN, policyHead.g1Activation, mask, "policy/g1bn"); + string g1Pool = addGlobalPool(graph, nameCounter, g1BNAct, mask, maskSumHW, "policy/g1pool"); + + // gpoolToBiasMul: [N, 3*g1C] → [N, p1C] + string policyBias = addMatMulNode(graph, nameCounter, g1Pool, policyHead.gpoolToBiasMul, "policy/g2b"); + + // Reshape to [N, C, 1, 1] + string pBiasShape = addInt64Initializer(graph, uniqueName(nameCounter, "policy/bias_shape"), {0, -1, 1, 1}); + string policyBiasReshaped = uniqueName(nameCounter, "policy/pbr"); + addNode(graph, "Reshape", {policyBias, pBiasShape}, policyBiasReshaped); + + // Add bias to p1 + string p1PlusBias = uniqueName(nameCounter, "policy/p1pb"); + addNode(graph, "Add", {p1Out, policyBiasReshaped}, p1PlusBias); + + // p1BN + activation + mask + string p1BNAct = addBNActivationMask(graph, nameCounter, p1PlusBias, policyHead.p1BN, policyHead.p1Activation, mask, "policy/p1bn"); + + // p2Conv: [N, p1C, H, W] → [N, policyChannels, H, W] + string p2Out = addConvNode(graph, nameCounter, p1BNAct, policyHead.p2Conv, "policy/p2conv"); + + // Reshape to [N, policyChannels, H*W] + string pSpatialShape = addInt64Initializer(graph, uniqueName(nameCounter, "policy/spat_shape"), {0, numPolicyChannels, -1}); + string policySpatial = uniqueName(nameCounter, "policy/spatial"); + addNode(graph, "Reshape", {p2Out, pSpatialShape}, policySpatial); + + // Pass move: gpoolToPassMul + string passOut; + if(modelVersion >= 15) { + // gpoolToPassMul → bias → activation → gpoolToPassMul2 + string passMul1 = addMatMulNode(graph, nameCounter, g1Pool, policyHead.gpoolToPassMul, "policy/pass_mul1"); + string passBiased = addBiasNode(graph, nameCounter, passMul1, policyHead.gpoolToPassBias, "policy/pass_bias"); + string passAct = addActivationNode(graph, nameCounter, passBiased, policyHead.passActivation.activation, "policy/pass_act"); + passOut = addMatMulNode(graph, nameCounter, passAct, policyHead.gpoolToPassMul2, "policy/pass_mul2"); + } else { + passOut = 
addMatMulNode(graph, nameCounter, g1Pool, policyHead.gpoolToPassMul, "policy/pass_mul"); + } + + // Reshape pass to [N, policyChannels, 1] + string passShape = addInt64Initializer(graph, uniqueName(nameCounter, "policy/pass_shape"), {0, numPolicyChannels, 1}); + string passReshaped = uniqueName(nameCounter, "policy/pass_r"); + addNode(graph, "Reshape", {passOut, passShape}, passReshaped); + + // Concat spatial + pass → out_policy [N, policyChannels, H*W+1] + onnx::NodeProto* policyConcatNode = addNode(graph, "Concat", {policySpatial, passReshaped}, "out_policy"); + setAttrInt(policyConcatNode, "axis", 2); + + // ------------------------------------------------------------------ + // Value Head + // ------------------------------------------------------------------ + + // v1Conv + string v1Out = addConvNode(graph, nameCounter, trunkOut, valueHead.v1Conv, "value/v1conv"); + + // v1BN + activation + mask + string v1BNAct = addBNActivationMask(graph, nameCounter, v1Out, valueHead.v1BN, valueHead.v1Activation, mask, "value/v1bn"); + + // Value head global pooling + string v1Pool = addValueHeadGPool(graph, nameCounter, v1BNAct, mask, maskSumHW, "value/vpool"); + + // v2Mul + v2Bias + v2Activation + string v2Out = addMatMulNode(graph, nameCounter, v1Pool, valueHead.v2Mul, "value/v2mul"); + string v2Biased = addBiasNode(graph, nameCounter, v2Out, valueHead.v2Bias, "value/v2bias"); + string v2Act = addActivationNode(graph, nameCounter, v2Biased, valueHead.v2Activation.activation, "value/v2act"); + + // v3Mul + v3Bias → out_value [N, 3] + string v3Out = addMatMulNode(graph, nameCounter, v2Act, valueHead.v3Mul, "value/v3mul"); + string v3Biased = addBiasNode(graph, nameCounter, v3Out, valueHead.v3Bias, "value/v3bias"); + addNode(graph, "Identity", {v3Biased}, "out_value"); + + // sv3Mul + sv3Bias → out_miscvalue [N, numScoreValueChannels] + string sv3Out = addMatMulNode(graph, nameCounter, v2Act, valueHead.sv3Mul, "value/sv3mul"); + string sv3Biased = addBiasNode(graph, nameCounter, sv3Out, valueHead.sv3Bias, "value/sv3bias"); + addNode(graph, "Identity", {sv3Biased}, "out_miscvalue"); + + // vOwnershipConv → out_ownership [N, 1, H, W] + string ownOut = addConvNode(graph, nameCounter, v1BNAct, valueHead.vOwnershipConv, "value/own_conv"); + addNode(graph, "Identity", {ownOut}, "out_ownership");
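+ + // Note: the Identity nodes above exist only to bind the value head's + // final tensors to the fixed output names that the backend looks up; + // the helper functions always generate fresh unique tensor names.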
+ + // ------------------------------------------------------------------ + // Graph Outputs + // ------------------------------------------------------------------ + int policyResultLen = nnXLen * nnYLen + 1; + addGraphOutput(graph, "out_policy", {-1, numPolicyChannels, policyResultLen}); + addGraphOutput(graph, "out_value", {-1, numValueChannels}); + addGraphOutput(graph, "out_miscvalue", {-1, numScoreValueChannels}); + addGraphOutput(graph, "out_ownership", {-1, numOwnershipChannels, nnYLen, nnXLen}); + + // ------------------------------------------------------------------ + // Serialize to string + // ------------------------------------------------------------------ + string serialized; + if(!model.SerializeToString(&serialized)) + throw StringError("ONNX backend: failed to serialize ONNX model to protobuf"); + + return serialized; +} diff --git a/cpp/neuralnet/onnxmodelbuilder.h b/cpp/neuralnet/onnxmodelbuilder.h new file mode 100644 index 000000000..96bc8e07a --- /dev/null +++ b/cpp/neuralnet/onnxmodelbuilder.h @@ -0,0 +1,14 @@ +#ifndef NEURALNET_ONNXMODELBUILDER_H_ +#define NEURALNET_ONNXMODELBUILDER_H_ + +#include <string> +#include "../neuralnet/desc.h" + +namespace OnnxModelBuilder { + // Builds a serialized ONNX ModelProto from a KataGo ModelDesc. + // The model is constructed for a fixed spatial size of nnXLen x nnYLen. + // Returns the protobuf-serialized bytes, ready for Ort::Session creation. + std::string buildOnnxModel(const ModelDesc& modelDesc, int nnXLen, int nnYLen); +} + +#endif // NEURALNET_ONNXMODELBUILDER_H_ diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 60baac228..2f4ea5571 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -20,6 +20,7 @@ std::vector<std::string> Setup::getBackendPrefixes() { prefixes.push_back("metal"); prefixes.push_back("opencl"); prefixes.push_back("eigen"); + prefixes.push_back("onnx"); prefixes.push_back("dummybackend"); return prefixes; } @@ -88,6 +89,8 @@ vector<NNEvaluator*> Setup::initializeNNEvaluators( string backendPrefix = "opencl"; #elif defined(USE_EIGEN_BACKEND) string backendPrefix = "eigen"; + #elif defined(USE_ONNX_BACKEND) + string backendPrefix = "onnx"; #else string backendPrefix = "dummybackend"; #endif @@ -141,7 +144,7 @@ vector<NNEvaluator*> Setup::initializeNNEvaluators( requireExactNNLen = cfg.getBool("requireMaxBoardSize"); } - bool inputsUseNHWC = backendPrefix == "opencl" || backendPrefix == "trt" || backendPrefix == "metal" ? false : true; + bool inputsUseNHWC = backendPrefix == "opencl" || backendPrefix == "trt" || backendPrefix == "metal" || backendPrefix == "onnx" ? false : true; if(cfg.contains(backendPrefix+"InputsUseNHWC"+idxStr)) inputsUseNHWC = cfg.getBool(backendPrefix+"InputsUseNHWC"+idxStr); else if(cfg.contains("inputsUseNHWC"+idxStr)) @@ -220,9 +223,32 @@ vector<NNEvaluator*> Setup::initializeNNEvaluators( string homeDataDirOverride = loadHomeDataDirOverride(cfg); - string openCLTunerFile; + string backendExtraParam; + #if defined(USE_ONNX_BACKEND) + string onnxProvider = cfg.contains("onnxProvider") ? cfg.getString("onnxProvider") : "cpu"; + { + backendExtraParam = "provider=" + onnxProvider; + if(cfg.contains("onnxInputSpatial")) + backendExtraParam += ";inputSpatial=" + cfg.getString("onnxInputSpatial"); + if(cfg.contains("onnxInputGlobal")) + backendExtraParam += ";inputGlobal=" + cfg.getString("onnxInputGlobal"); + if(cfg.contains("onnxInputMeta")) + backendExtraParam += ";inputMeta=" + cfg.getString("onnxInputMeta"); + if(cfg.contains("onnxOutputPolicy")) + backendExtraParam += ";outputPolicy=" + cfg.getString("onnxOutputPolicy"); + if(cfg.contains("onnxOutputValue")) + backendExtraParam += ";outputValue=" + cfg.getString("onnxOutputValue"); + if(cfg.contains("onnxOutputMiscvalue")) + backendExtraParam += ";outputMiscvalue=" + cfg.getString("onnxOutputMiscvalue"); + if(cfg.contains("onnxOutputOwnership")) + backendExtraParam += ";outputOwnership=" + cfg.getString("onnxOutputOwnership"); + if(cfg.contains("onnxModelVersion")) + backendExtraParam += ";modelVersion=" + cfg.getString("onnxModelVersion"); + } + #else if(cfg.contains("openclTunerFile")) - openCLTunerFile = cfg.getString("openclTunerFile"); + backendExtraParam = cfg.getString("openclTunerFile"); + #endif bool openCLReTunePerBoardSize = false; if(cfg.contains("openclReTunePerBoardSize")) openCLReTunePerBoardSize = cfg.getBool("openclReTunePerBoardSize"); @@ -275,7 +301,29 @@ vector<NNEvaluator*> Setup::initializeNNEvaluators( setupFor == SETUP_FOR_ANALYSIS ? 
17 : cfg.getInt("nnMutexPoolSizePowerOfTwo", -1, 24); -#ifndef USE_EIGEN_BACKEND +#if defined(USE_ONNX_BACKEND) + // ONNX backend: use a small fixed batch for the CPU provider (as with Eigen), + // and the normal batch size configuration for accelerator providers + int nnMaxBatchSize; + { + if(onnxProvider == "cpu" || onnxProvider.empty()) { + nnMaxBatchSize = 2; + cfg.markAllKeysUsedWithPrefix("nnMaxBatchSize"); + (void)defaultMaxBatchSize; + } else { + if(setupFor == SETUP_FOR_BENCHMARK || setupFor == SETUP_FOR_DISTRIBUTED) { + nnMaxBatchSize = defaultMaxBatchSize; + } + else if(defaultMaxBatchSize > 0) { + nnMaxBatchSize = + cfg.contains("nnMaxBatchSize") ? cfg.getInt("nnMaxBatchSize", 1, 65536) : + defaultMaxBatchSize; + } + else { + nnMaxBatchSize = cfg.getInt("nnMaxBatchSize", 1, 65536); + } + } + } +#elif !defined(USE_EIGEN_BACKEND) int nnMaxBatchSize; if(setupFor == SETUP_FOR_BENCHMARK || setupFor == SETUP_FOR_DISTRIBUTED) { nnMaxBatchSize = defaultMaxBatchSize; @@ -315,7 +363,7 @@ vector<NNEvaluator*> Setup::initializeNNEvaluators( nnCacheSizePowerOfTwo, nnMutexPoolSizePowerOfTwo, debugSkipNeuralNet, - openCLTunerFile, + backendExtraParam, homeDataDirOverride, openCLReTunePerBoardSize, useFP16Mode, diff --git a/cpp/runonnxtests.sh b/cpp/runonnxtests.sh new file mode 100755 index 000000000..2aff64733 --- /dev/null +++ b/cpp/runonnxtests.sh @@ -0,0 +1,43 @@ +#!/bin/bash -eux +set -o pipefail +{ +# --------------------------------------------------------------- +# ONNX backend integration tests +# +# Exercises three levels of the inference pipeline: +# 1. runtinynntests — tiny model, full pipeline (no external model) +# 2. testgpuerror -quick — FP32 unbatched vs batched comparison +# 3. runnnevalcanarytests — sanity checks on real game positions +# --------------------------------------------------------------- + +mkdir -p tests/scratch + +# 1. Tiny NN tests — self-contained, no external model needed +echo "=== runtinynntests ===" +./katago runtinynntests tests/scratch 1.0 \ + | grep -v ': nnRandSeed0 = ' \ + | grep -v 'finishing, processed' + +# 2. GPU error test (quick) — compares unbatched vs batched inference +# For CPU ONNX provider both paths are FP32, so errors should be near zero. +# Any ownership indexing bug would surface as large ownership error. +echo "=== testgpuerror -quick ===" +./katago testgpuerror \ + -config configs/gtp_example.cfg \ + -model tests/models/g170-b6c96-s175395328-d26788732.bin.gz \ + -quick \ + -override-config "nnRandSeed=forTesting,forDeterministicTesting=true" + +# 3. 
NN eval canary tests — sanity checks on 5 real game positions +# Uses symmetries 0, 3, 6 (same as runsearchtests.sh) +echo "=== runnnevalcanarytests ===" +./katago runnnevalcanarytests configs/gtp_example.cfg tests/models/g170e-b10c128-s1141046784-d204142634.bin.gz 0 \ + | grep -v ': nnRandSeed0 = ' +./katago runnnevalcanarytests configs/gtp_example.cfg tests/models/g170e-b10c128-s1141046784-d204142634.bin.gz 3 \ + | grep -v ': nnRandSeed0 = ' +./katago runnnevalcanarytests configs/gtp_example.cfg tests/models/g170e-b10c128-s1141046784-d204142634.bin.gz 6 \ + | grep -v ': nnRandSeed0 = ' + +echo "=== All ONNX tests passed ===" +exit 0 +} diff --git a/python/export_model_pytorch.py b/python/export_model_pytorch.py index 5c409d26d..4ae3768c8 100644 --- a/python/export_model_pytorch.py +++ b/python/export_model_pytorch.py @@ -35,9 +35,45 @@ parser.add_argument('-filename-prefix', help='filename prefix to save to within dir', required=True) parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False) parser.add_argument('-export-14-as-15', help='Export model version 14 as 15', action="store_true", required=False) +parser.add_argument('-export-onnx', help='Also export an ONNX model', action="store_true", required=False) args = vars(parser.parse_args()) +class OnnxExportWrapper(torch.nn.Module): + """Wrapper that selects the outputs needed by the C++ ONNX backend.""" + + def __init__(self, model, version): + super().__init__() + self.model = model + self.version = version + + def forward(self, input_spatial, input_global, input_meta=None): + outputs = self.model(input_spatial, input_global, input_meta, extra_outputs=None) + out_policy = outputs[0][0] + out_value = outputs[0][1] + out_miscvalue = outputs[0][2] + out_moremiscvalue = outputs[0][3] + out_ownership = outputs[0][4] + + # Select policy channels based on export version. 
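+ # Only the channels the C++ backend consumes at inference time are + # exported; the remaining raw policy channels (e.g. the opponent-move + # prediction, which serves only as a training target) are dropped.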
+ # Channel indices into the raw policy head output: + # 0 — main policy (move selection probabilities) + # 5 — short-term-optimistic policy + # 6 — Q-value winloss policy (v16+) + # 7 — Q-value score policy (v16+) + if self.version <= 11: + out_policy = out_policy[:, 0:1, :] + elif self.version <= 15: + out_policy = out_policy[:, [0, 5], :] + else: + out_policy = out_policy[:, [0, 5, 6, 7], :] + + # Combine miscvalue (first 4) and moremiscvalue (first 2) into 6 channels + miscvalue = torch.cat([out_miscvalue[:, :4], out_moremiscvalue[:, :2]], dim=1) + + return out_policy, out_value, miscvalue, out_ownership + + def main(args): checkpoint_file = args["checkpoint"] export_dir = args["export_dir"] @@ -45,6 +81,7 @@ def main(args): filename_prefix = args["filename_prefix"] use_swa = args["use_swa"] export_14_as_15 = args["export_14_as_15"] + export_onnx = args["export_onnx"] os.makedirs(export_dir,exist_ok=True) @@ -444,6 +481,50 @@ def write_model(model): write_model(model) f.close() + # ONNX EXPORT ------------------------------------------------------------------- + if export_onnx: + logging.info("Exporting ONNX model...") + onnx_model = swa_model if swa_model is not None else model + onnx_model.eval() + wrapper = OnnxExportWrapper(onnx_model, version) + wrapper.eval() + + # Build dummy inputs (fixed 19x19 board) + num_bin_features = modelconfigs.get_num_bin_input_features(model_config) + num_global_features = modelconfigs.get_num_global_input_features(model_config) + dummy_spatial = torch.zeros(1, num_bin_features, 19, 19) + dummy_global = torch.zeros(1, num_global_features) + input_names = ["input_spatial", "input_global"] + dynamic_axes = { + "input_spatial": {0: "batch", 2: "height", 3: "width"}, + "input_global": {0: "batch"}, + "out_policy": {0: "batch", 2: "spatial"}, + "out_value": {0: "batch"}, + "out_miscvalue": {0: "batch"}, + "out_ownership": {0: "batch", 2: "height", 3: "width"}, + } + + if onnx_model.metadata_encoder is not None: + dummy_meta = torch.zeros(1, onnx_model.metadata_encoder.c_input) + dummy_input = (dummy_spatial, dummy_global, dummy_meta) + input_names.append("input_meta") + dynamic_axes["input_meta"] = {0: "batch"} + else: + dummy_input = (dummy_spatial, dummy_global) + + onnx_path = os.path.join(export_dir, filename_prefix + ".onnx") + torch.onnx.export( + wrapper, + dummy_input, + onnx_path, + input_names=input_names, + output_names=["out_policy", "out_value", "out_miscvalue", "out_ownership"], + dynamic_axes=dynamic_axes, + opset_version=17, + do_constant_folding=True, + ) + logging.info(f"ONNX model exported to: {onnx_path}") + with open(os.path.join(export_dir,"metadata.json"),"w") as f: train_state = other_state_dict["train_state"] data = {}