79 changes: 79 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_GELU.hxx
@@ -0,0 +1,79 @@
#ifndef TMVA_SOFIE_ROPERATOR_GELU
#define TMVA_SOFIE_ROPERATOR_GELU

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

template <typename T>
class ROperator_GELU final : public ROperator
{

private:

std::string fNX;
std::string fNY;
std::vector<size_t> fShape;

public:
ROperator_GELU(){}
ROperator_GELU(std::string nameX, std::string nameY):
fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
fInputTensorNames = { fNX };
fOutputTensorNames = { fNY };
}

std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
return input;
}

std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
auto ret = input; //suggest copy to compiler
return ret;
}

void Initialize(RModel& model) override {
//input must be a graph input, or already initialized intermediate tensor
if (model.CheckIfTensorAlreadyExist(fNX) == false){
throw std::runtime_error("TMVA SOFIE GELU Op Input Tensor " + fNX + " is not found in model");
}
fShape = model.GetTensorShape(fNX);
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
}

std::string Generate(std::string OpName) override {
OpName = "op_" + OpName;
if (fShape.empty()){
throw std::runtime_error("TMVA SOFIE GELU operator called to Generate without being initialized first");
}
std::stringstream out;
size_t length = ConvertShapeToLength(fShape);

// GELU exact formula: y = 0.5 * x * (1 + erf(x / sqrt(2)))
// Hexfloat literals keep the constants bit-exact across compilers:
// 0x1.6a09e667f3bcdp-1 = 0.7071067811865476, the double closest to 1/sqrt(2)
// 0x1.0000000000000p-1 = 0.5

out << "\n//------ GELU\n";
out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
out << SP << SP << "tensor_" << fNY << "[id] = 0x1.0000000000000p-1 * tensor_" << fNX
<< "[id] * (1.0 + std::erf(tensor_" << fNX << "[id] * 0x1.6a09e667f3bcdp-1));\n";
out << SP << "}\n";
return out.str();
}

std::vector<std::string> GetStdLibs() override { return { std::string("cmath") };}
};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_GELU
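
For reference, a sketch of what Generate emits for a float tensor of shape {1, 10}; the tensor names input/output and the operator name are hypothetical, and SP is assumed to be SOFIE's usual indentation string:

//------ GELU
   for (int id = 0; id < 10 ; id++){
      tensor_output[id] = 0x1.0000000000000p-1 * tensor_input[id] * (1.0 + std::erf(tensor_input[id] * 0x1.6a09e667f3bcdp-1));
   }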
7 changes: 7 additions & 0 deletions tmva/sofie/test/CMakeLists.txt
@@ -103,6 +103,13 @@ if (BLAS_FOUND)
# Creating a Google Test for the automatic differentiation of Gemm_Call
ROOT_ADD_GTEST(TestGemmDerivative TestGemmDerivative.cxx LIBRARIES Core BLAS::BLAS)
endif()

# GELU Operator Unit Test
# Tests code generation with hexfloat constants for bit-exact reproducibility
ROOT_ADD_GTEST(TestSofieGELU TestSofieGELU.cxx
LIBRARIES
ROOTTMVASofie
)
endif()

# Look for needed Python modules
127 changes: 127 additions & 0 deletions tmva/sofie/test/TestSofieGELU.cxx
@@ -0,0 +1,127 @@
/// \file TestSofieGELU.cxx
/// \brief Unit tests for the SOFIE GELU operator
/// \author ROOT TMVA Team

#include "TMVA/ROperator_GELU.hxx"
#include "TMVA/RModel.hxx"

#include "gtest/gtest.h"

#include <cmath>
#include <string>
#include <vector>
#include <utility>
#include <algorithm>

using namespace TMVA::Experimental::SOFIE;

// Validate generation of hexfloat constants for bit-exact reproducibility
TEST(SOFIE_GELU, GenerateHexfloatConstants)
{
RModel model;
model.AddInputTensorInfo("input", ETensorType::FLOAT, {1, 10});
model.AddOutputTensorNameList({"output"});

ROperator_GELU<float> op("input", "output");
op.Initialize(model);

std::string code = op.Generate("gelu_test");

// Expect 1/sqrt(2) as hexfloat: 0x1.6a09e667f3bcdp-1
EXPECT_TRUE(code.find("0x1.6a09e667f3bcdp-1") != std::string::npos)
<< "Generated code missing optimized hexfloat constant for 1/sqrt(2)";

// Expect 0.5 as hexfloat
EXPECT_TRUE(code.find("0x1.0000000000000p-1") != std::string::npos)
<< "Generated code missing hexfloat constant for 0.5";
}

// Check structure of generated C++ code
TEST(SOFIE_GELU, GenerateStructure)
{
RModel model;
model.AddInputTensorInfo("X", ETensorType::FLOAT, {2, 5});
model.AddOutputTensorNameList({"Y"});

ROperator_GELU<float> op("X", "Y");
op.Initialize(model);

std::string code = op.Generate("gelu_struct_test");

EXPECT_TRUE(code.find("std::erf") != std::string::npos) << "Missing std::erf call";
EXPECT_TRUE(code.find("tensor_Y") != std::string::npos) << "Missing output tensor access";
EXPECT_TRUE(code.find("tensor_X") != std::string::npos) << "Missing input tensor access";
// Loop bound check: shape {2, 5} flattens to a length of 10
EXPECT_TRUE(code.find("id < 10") != std::string::npos) << "Incorrect loop limit generated";
}

// Compare implementation against SciPy reference values
TEST(SOFIE_GELU, NumericCorrectness)
{
// Reference values for the exact GELU, x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
const std::vector<std::pair<float, float>> referenceData = {
{-3.0f, -0.00404969f},
{-2.5f, -0.01552416f},
{-2.0f, -0.04550026f},
{-1.5f, -0.10021080f},
{-1.0f, -0.15865525f},
{-0.5f, -0.15426877f},
{ 0.0f, 0.00000000f},
{ 0.5f, 0.34573123f},
{ 1.0f, 0.84134475f},
{ 1.5f, 1.39978920f},
{ 2.0f, 1.95449974f},
{ 2.5f, 2.48447584f},
{ 3.0f, 2.99595031f},
{-10.0f, 0.0f}, // tail limit: GELU(x) -> 0 as x -> -inf
{ 10.0f, 10.0f} // tail limit: GELU(x) -> x as x -> +inf
};

// Mirror of the arithmetic emitted by ROperator_GELU::Generate
auto gelu_eval = [](float x) -> float {
constexpr double kInvSqrt2 = 0x1.6a09e667f3bcdp-1;
return 0.5f * x * (1.0 + std::erf(x * kInvSqrt2));
};

for (const auto& [input, expected] : referenceData) {
float computed = gelu_eval(input);
float tol = std::max(1e-6f * std::abs(expected), 1e-7f);

EXPECT_NEAR(computed, expected, tol)
<< "Mismatch at x = " << input;
}
}

TEST(SOFIE_GELU, StdLibDependencies)
{
ROperator_GELU<float> op("in", "out");
auto libs = op.GetStdLibs();
ASSERT_EQ(libs.size(), 1u);
EXPECT_EQ(libs[0], "cmath");
}

TEST(SOFIE_GELU, Inference)
{
ROperator_GELU<float> op("in", "out");

// Type inference
auto types = op.TypeInference({ETensorType::FLOAT});
EXPECT_EQ(types[0], ETensorType::FLOAT);

// Shape inference
std::vector<size_t> shape = {4, 16, 32};
auto shapes = op.ShapeInference({shape});
EXPECT_EQ(shapes[0], shape);
}

TEST(SOFIE_GELU, ErrorHandling)
{
ROperator_GELU<float> op("in", "out");

// Generate without Initialize
EXPECT_THROW(op.Generate("test"), std::runtime_error);

// Initialize with missing tensor
RModel model;
EXPECT_THROW(op.Initialize(model), std::runtime_error);
}
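
The corrected reference table above can be regenerated with a few lines of standalone C++ against std::erf; a minimal sketch, not part of the PR:

#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
   // Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
   for (double x : {-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0})
      std::printf("{%5.1ff, %.8ff},\n", x, 0.5 * x * (1.0 + std::erf(x * 0.7071067811865476)));
   return 0;
}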
23 changes: 23 additions & 0 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
@@ -10,6 +10,7 @@
#include <unordered_map>
#include <functional>
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator_GELU.hxx"

namespace TMVA {
namespace Experimental {
@@ -220,6 +221,28 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
RegisterOperator("Gather", ParseGather);
RegisterOperator("Erf", ParseErf);
RegisterOperator("Elu", ParseElu);
// GELU operator with inline lambda registration
RegisterOperator("Gelu", [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
// Check for unsupported tanh approximation attribute
for (int i = 0; i < nodeproto.attribute_size(); i++) {
if (nodeproto.attribute(i).name() == "approximate") {
if (nodeproto.attribute(i).s() == "tanh") {
throw std::runtime_error("TMVA::SOFIE GELU tanh approximation not implemented");
}
}
}
auto input_name = nodeproto.input(0);
if (!parser.IsRegisteredTensorType(input_name)) {
throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu op has input tensor " +
input_name + " but its type is not yet registered");
}
std::string output_name = nodeproto.output(0);
auto op = std::make_unique<ROperator_GELU<float>>(input_name, output_name);
if (!parser.IsRegisteredTensorType(output_name)) {
parser.RegisterTensorType(output_name, parser.GetTensorType(input_name));
}
return op;
});
RegisterOperator("EyeLike", ParseEyeLike);
RegisterOperator("Range", ParseRange);
RegisterOperator("TopK", ParseTopK);