Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_Softplus.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#ifndef TMVA_SOFIE_ROPERATOR_SOFTPLUS
#define TMVA_SOFIE_ROPERATOR_SOFTPLUS

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

template <typename T>
class ROperator_Softplus final : public ROperator
{

private:

   std::string fNX;            // cleaned name of the input tensor
   std::string fNY;            // cleaned name of the output tensor
   std::vector<size_t> fShape; // shape of the input (and output) tensor

public:
   /// Default constructor, required by the operator factory machinery.
   ROperator_Softplus(){}

   /// Construct from the ONNX input/output tensor names.
   /// Names are sanitized so they can be embedded in generated C++ identifiers.
   ROperator_Softplus(std::string nameX, std::string nameY):
      fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
      fInputTensorNames = { fNX };
      fOutputTensorNames = { fNY };
   }

   /// Softplus is elementwise: the output type is the input type.
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
      return input;
   }

   /// Softplus is elementwise: the output shape is the input shape.
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      return input;
   }

   /// Validate the input tensor and register the output tensor with the model.
   void Initialize(RModel& model) override {
      // Input must be a graph input, or an already initialized intermediate tensor.
      if (!model.CheckIfTensorAlreadyExist(fNX)) {
         throw std::runtime_error("TMVA SOFIE Softplus Op Input Tensor " + fNX + " is not found in model");
      }
      fShape = model.GetTensorShape(fNX);
      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
   }

   /// Emit the inference code for Softplus: y = log(1 + exp(x)).
   /// For x >= 20 the result is numerically indistinguishable from x in float,
   /// so the generated code passes x through, avoiding overflow in exp().
   std::string Generate(std::string OpName) override {
      OpName = "op_" + OpName; // kept for parity with other SOFIE operators (unused here)
      if (fShape.empty()) {
         throw std::runtime_error("TMVA SOFIE Softplus operator called to Generate without being initialized first");
      }
      const size_t nElems = ConvertShapeToLength(fShape);
      std::stringstream out;

      out << "\n//------ Softplus\n";
      out << SP << "for (int id = 0; id < " << nElems << " ; id++){\n";
      // NOTE(review): the emitted element type is `float` regardless of the
      // template parameter T — presumably only the float instantiation is
      // ever used (as in the ONNX parser); confirm before adding other types.
      out << SP << SP << "float x = tensor_" << fNX << "[id];\n";
      // 0x1.4000000000000p+4f is exactly 20.0f (1.25 * 2^4) spelled as a hexfloat.
      out << SP << SP << "tensor_" << fNY << "[id] = (x >= 0x1.4000000000000p+4f) "
          << "? x : std::log1p(std::exp(x));\n";
      out << SP << "}\n";
      return out.str();
   }

   /// The generated code calls std::exp / std::log1p, hence needs <cmath>.
   std::vector<std::string> GetStdLibs() override { return { std::string("cmath") };}
};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_SOFTPLUS
7 changes: 7 additions & 0 deletions tmva/sofie/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ if (BLAS_FOUND)
# Creating a Google Test for the automatic differentiation of Gemm_Call
ROOT_ADD_GTEST(TestGemmDerivative TestGemmDerivative.cxx LIBRARIES Core BLAS::BLAS)
endif()

# Softplus operator unit test.
# Covers the hexfloat threshold constant, numerical stability (log1p/exp)
# and overflow protection for large inputs.
ROOT_ADD_GTEST(TestSofieSoftplus TestSofieSoftplus.cxx
LIBRARIES
ROOTTMVASofie
)
endif()

# Look for needed Python modules
Expand Down
208 changes: 208 additions & 0 deletions tmva/sofie/test/TestSofieSoftplus.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#include "TMVA/ROperator_Softplus.hxx"
#include "TMVA/RModel.hxx"

#include "gtest/gtest.h"

#include <cmath>
#include <string>
#include <vector>
#include <utility>
#include <algorithm>

using namespace TMVA::Experimental::SOFIE;

// Testing hexfloat threshold constant for overflow protection
TEST(SOFIE_Softplus, GenerateHexfloatConstants)
{
   RModel model;
   model.AddInputTensorInfo("input", ETensorType::FLOAT, std::vector<size_t>{1, 10});
   model.AddOutputTensorNameList({"output"});

   ROperator_Softplus<float> op("input", "output");
   op.Initialize(model);

   const std::string code = op.Generate("softplus_test");

   // The emitted threshold must be the exact hexfloat spelling of 20.0f.
   EXPECT_NE(code.find("0x1.4000000000000p+4f"), std::string::npos)
       << "Generated code missing hexfloat threshold constant 0x1.4000000000000p+4f (20.0f)";

   // Each element is loaded into a named local, keeping the emitted code readable.
   EXPECT_NE(code.find("float x ="), std::string::npos)
       << "Generated code should declare local variable 'float x' for clean emitted code";

   // The threshold branch is expressed as a ternary conditional.
   EXPECT_NE(code.find("?"), std::string::npos)
       << "Generated code should use ternary operator for threshold branch";
   EXPECT_NE(code.find(":"), std::string::npos)
       << "Generated code should use ternary operator for threshold branch";
}

// Testing numerical stability functions
TEST(SOFIE_Softplus, GenerateStabilityFunctions)
{
   RModel model;
   model.AddInputTensorInfo("X", ETensorType::FLOAT, std::vector<size_t>{2, 5});
   model.AddOutputTensorNameList({"Y"});

   ROperator_Softplus<float> op("X", "Y");
   op.Initialize(model);

   std::string code = op.Generate("softplus_stability_test");

   // Verify std::log1p is used (not std::log)
   EXPECT_TRUE(code.find("std::log1p") != std::string::npos)
       << "Generated code must use std::log1p for numerical stability";

   // Verify std::exp is used
   EXPECT_TRUE(code.find("std::exp") != std::string::npos)
       << "Generated code must use std::exp";

   // Verify a plain std::log call never appears. The previous check only
   // required std::log1p to occur BEFORE any "std::log(", which would still
   // accept generated code that mixed in a later std::log call. Note that
   // "std::log(" cannot match "std::log1p(" (the '1' breaks the match), so
   // requiring npos here is exact.
   EXPECT_EQ(code.find("std::log("), std::string::npos)
       << "Generated code should use std::log1p, not std::log";
}

// Testing numeric correctness in stable region
TEST(SOFIE_Softplus, NumericCorrectnessStableRegion)
{
   // Proxy for the generated element-wise logic, including the threshold branch.
   const auto softplusRef = [](float v) -> float {
      if (v >= 0x1.4000000000000p+4f) {
         return v;
      }
      return std::log1p(std::exp(v));
   };

   // All sample points lie below the threshold, so the log1p/exp branch runs.
   const float samples[] = {-10.0f, -5.0f, -1.0f, 0.0f, 1.0f, 5.0f, 10.0f, 15.0f};
   const float tol = 1e-6f;

   for (float x : samples) {
      // Reference softplus; at x == 0 this is ln(2) ≈ 0.693.
      const float expected = std::log1p(std::exp(x));
      EXPECT_NEAR(softplusRef(x), expected, tol)
          << "Stable region mismatch at x = " << x;
   }
}

// Testing threshold behavior for overflow protection
TEST(SOFIE_Softplus, NumericCorrectnessThreshold)
{
   // Proxy for the generated element-wise logic, including the threshold branch.
   const auto softplusRef = [](float v) -> float {
      return (v >= 0x1.4000000000000p+4f) ? v : std::log1p(std::exp(v));
   };

   // At and above the threshold softplus degenerates to the identity; without
   // the branch, exp() would overflow float for the larger inputs (>= ~88).
   const float passthrough[] = {20.0f, 25.0f, 50.0f, 100.0f, 1000.0f};
   const float tol = 1e-6f;

   for (float x : passthrough) {
      const float y = softplusRef(x);

      EXPECT_NEAR(y, x, tol)
          << "Threshold behavior mismatch at x = " << x;

      // Ensure no NaN or Inf
      EXPECT_FALSE(std::isnan(y)) << "NaN at x = " << x;
      EXPECT_FALSE(std::isinf(y)) << "Inf at x = " << x;
   }
}

// Testing specific known values
TEST(SOFIE_Softplus, KnownValues)
{
   // Proxy for the generated element-wise logic.
   const auto softplusRef = [](float v) -> float {
      return (v >= 0x1.4000000000000p+4f) ? v : std::log1p(std::exp(v));
   };

   const float tol = 1e-6f;

   // ln(1 + e^0) = ln(2)
   EXPECT_NEAR(softplusRef(0.0f), 0.6931471805599453f, tol);

   // For large negative x: ln(1 + e^x) ≈ e^x ≈ 0
   EXPECT_NEAR(softplusRef(-20.0f), std::exp(-20.0f), tol);

   // At threshold: exact passthrough
   EXPECT_NEAR(softplusRef(20.0f), 20.0f, tol);

   // Just below threshold the log1p/exp branch is still taken
   const float nearThreshold = 19.9f;
   EXPECT_NEAR(softplusRef(nearThreshold), std::log1p(std::exp(nearThreshold)), tol);
}

// StdLib dependencies
TEST(SOFIE_Softplus, StdLibDependencies)
{
   ROperator_Softplus<float> op("in", "out");

   // The operator must declare exactly one stdlib dependency: <cmath>.
   const auto libs = op.GetStdLibs();
   ASSERT_EQ(libs.size(), 1u);
   EXPECT_EQ(libs.front(), "cmath");
}

// Type and Shape Inference
TEST(SOFIE_Softplus, Inference)
{
   ROperator_Softplus<float> op("in", "out");

   // Softplus is elementwise: the type passes through unchanged...
   const auto inferredTypes = op.TypeInference({ETensorType::FLOAT});
   EXPECT_EQ(inferredTypes.front(), ETensorType::FLOAT);

   // ...and so does the shape.
   const std::vector<size_t> inShape = {4, 16, 32};
   const auto inferredShapes = op.ShapeInference({inShape});
   EXPECT_EQ(inferredShapes.front(), inShape);
}

// Error Handling
TEST(SOFIE_Softplus, ErrorHandling)
{
   ROperator_Softplus<float> op("in", "out");

   // Calling Generate before Initialize must be rejected.
   EXPECT_THROW(op.Generate("test"), std::runtime_error);

   // Initializing against a model that lacks the input tensor must be rejected.
   RModel emptyModel;
   EXPECT_THROW(op.Initialize(emptyModel), std::runtime_error);
}

// Loop structure verification
TEST(SOFIE_Softplus, GenerateStructure)
{
   RModel model;
   model.AddInputTensorInfo("X", ETensorType::FLOAT, std::vector<size_t>{2, 5});
   model.AddOutputTensorNameList({"Y"});

   ROperator_Softplus<float> op("X", "Y");
   op.Initialize(model);

   std::string code = op.Generate("softplus_struct_test");

   EXPECT_TRUE(code.find("tensor_Y") != std::string::npos) << "Missing output tensor access";
   EXPECT_TRUE(code.find("tensor_X") != std::string::npos) << "Missing input tensor access";
   // Loop limit check for shape {2, 5}: anchor the search on the loop
   // condition ("id < 10 ") rather than a bare "10", which could match a
   // stray "10" anywhere in the emitted code (e.g. inside a tensor or op
   // name) and mask a wrong bound.
   EXPECT_TRUE(code.find("id < 10 ") != std::string::npos) << "Incorrect loop limit generated";
   // Operator comment
   EXPECT_TRUE(code.find("Softplus") != std::string::npos) << "Missing operator comment";
}

// Threshold constant verification (20.0f as hexfloat)
TEST(SOFIE_Softplus, ThresholdConstantValue)
{
   // 0x1.4p+4 = 1.25 * 2^4 = 20; the long-mantissa spelling is the same value.
   constexpr float kThreshold = 0x1.4000000000000p+4f;
   EXPECT_FLOAT_EQ(kThreshold, 20.0f);
}
16 changes: 16 additions & 0 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <unordered_map>
#include <functional>
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator_Softplus.hxx"

namespace TMVA {
namespace Experimental {
Expand Down Expand Up @@ -234,6 +235,21 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
RegisterOperator("RandomUniform", ParseRandom);
RegisterOperator("RandomUniformLike", ParseRandom);
RegisterOperator("ScatterElements", ParseScatterElements);

// Softplus operator with inline lambda registration (no attributes)
RegisterOperator("Softplus", [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
   const std::string inputName = nodeproto.input(0);
   const std::string outputName = nodeproto.output(0);
   // The input type must already be known before it can be propagated.
   if (!parser.IsRegisteredTensorType(inputName)) {
      throw std::runtime_error("TMVA::SOFIE ONNX Parser Softplus op has input tensor " +
                               inputName + " but its type is not yet registered");
   }
   // Softplus is elementwise, so the output simply inherits the input's type.
   if (!parser.IsRegisteredTensorType(outputName)) {
      parser.RegisterTensorType(outputName, parser.GetTensorType(inputName));
   }
   return std::make_unique<ROperator_Softplus<float>>(inputName, outputName);
});
}

// Destructor of the parser
Expand Down