From 93f8ddff9670adb9bb78533e6b9e7f668cbd56d5 Mon Sep 17 00:00:00 2001 From: Shlok Date: Tue, 30 Dec 2025 15:15:29 +0000 Subject: [PATCH 1/4] TMVA(SOFIE): add GELU activation operator --- tmva/sofie/inc/TMVA/ROperator_Gelu.hxx | 79 ++++++++++++++++++++ tmva/sofie_parsers/CMakeLists.txt | 1 + tmva/sofie_parsers/src/ParseGelu.cxx | 39 ++++++++++ tmva/sofie_parsers/src/RModelParser_ONNX.cxx | 2 + 4 files changed, 121 insertions(+) create mode 100644 tmva/sofie/inc/TMVA/ROperator_Gelu.hxx create mode 100644 tmva/sofie_parsers/src/ParseGelu.cxx diff --git a/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx b/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx new file mode 100644 index 0000000000000..76db67a803d3c --- /dev/null +++ b/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx @@ -0,0 +1,79 @@ +#ifndef TMVA_SOFIE_ROPERATOR_GELU +#define TMVA_SOFIE_ROPERATOR_GELU + +#include "TMVA/SOFIE_common.hxx" +#include "TMVA/ROperator.hxx" +#include "TMVA/RModel.hxx" + +#include + +namespace TMVA{ +namespace Experimental{ +namespace SOFIE{ + +template +class ROperator_Gelu final : public ROperator +{ + +private: + + std::string fNX; + std::string fNY; + std::vector fShape; + +public: + ROperator_Gelu(){} + ROperator_Gelu(std::string nameX, std::string nameY): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){ + fInputTensorNames = { fNX }; + fOutputTensorNames = { fNY }; + } + + std::vector TypeInference(std::vector input) override { + return input; + } + + std::vector> ShapeInference(std::vector> input) override { + auto ret = input; //suggest copy to compiler + return ret; + } + + void Initialize(RModel& model) override { + if (model.CheckIfTensorAlreadyExist(fNX) == false){ //input must be a graph input, or already initialized intermediate tensor + throw std::runtime_error("TMVA SOFIE Gelu Op Input Tensor " + fNX + " is not found in model"); + } + + fShape = model.GetDimTensorShape(fNX); + + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape); + if (model.Verbose()) { + std::cout << "Gelu : " << fNX << " -> " << fNY << " " << ConvertShapeToString(fShape) << std::endl; + } + } + + std::string Generate(std::string OpName) override { + OpName = "op_" + OpName; + if (fShape.empty()) { + throw std::runtime_error("TMVA SOFIE Operator Gelu called to Generate without being initialized first"); + } + std::stringstream out; + auto length = ConvertDynamicShapeToLength(fShape); + out << "\n//------ GELU\n"; + out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; + out << SP << SP + << "tensor_" << fNY << "[id] = 0.5 * tensor_" << fNX << "[id] * " + << "(1 + std::tanh(0.7978845608 * " + << "(tensor_" << fNX << "[id] + 0.044715 * " + << "tensor_" << fNX << "[id] * tensor_" << fNX << "[id] * tensor_" << fNX << "[id])));\n"; + out << SP << "}\n"; + return out.str(); + } + +}; + +}//SOFIE +}//Experimental +}//TMVA + + +#endif //TMVA_SOFIE_ROPERATOR_GELU diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt index 4ad063693fbe5..56a69f4d43e64 100644 --- a/tmva/sofie_parsers/CMakeLists.txt +++ b/tmva/sofie_parsers/CMakeLists.txt @@ -41,6 +41,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser src/ParsePool.cxx src/ParseReduce.cxx src/ParseRelu.cxx + src/ParseGelu.cxx src/ParseReshape.cxx src/ParseRNN.cxx src/ParseSelu.cxx diff --git a/tmva/sofie_parsers/src/ParseGelu.cxx b/tmva/sofie_parsers/src/ParseGelu.cxx new file mode 100644 index 0000000000000..d7c9269451de6 --- /dev/null +++ b/tmva/sofie_parsers/src/ParseGelu.cxx @@ -0,0 +1,39 @@ +#include "TMVA/RModelParser_ONNX.hxx" +#include "TMVA/ROperator_Gelu.hxx" +#include "onnx_proto3.pb.h" + +namespace TMVA { +namespace Experimental { +namespace SOFIE { + +ParserFuncSignature ParseGelu = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + ETensorType input_type; + + auto input_name = nodeproto.input(0); + if (parser.IsRegisteredTensorType(input_name)) { + input_type = parser.GetTensorType(input_name); + } else { + throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu op has input tensor" + input_name + + " but its type is not yet registered"); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + switch (input_type) { + case ETensorType::FLOAT: op.reset(new ROperator_Gelu(input_name, output_name)); break; + default: + throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gelu does not yet support input type " + + std::to_string(static_cast(input_type))); + } + + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index 7b4ade2b6bc09..e3c2bc4664626 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -51,6 +51,7 @@ extern ParserFuncSignature ParseReduceProd; extern ParserFuncSignature ParseBatchNormalization; extern ParserFuncSignature ParseConstant; extern ParserFuncSignature ParseTranspose; +extern ParserFuncSignature ParseGelu; extern ParserFuncSignature ParseRelu; extern ParserFuncSignature ParseTanh; extern ParserFuncSignature ParseConv; @@ -201,6 +202,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("AveragePool", ParsePool); RegisterOperator("GlobalAveragePool", ParsePool); RegisterOperator("MaxPool", ParsePool); + RegisterOperator("Gelu", ParseGelu); RegisterOperator("Relu", ParseRelu); RegisterOperator("Reshape", ParseReshape); RegisterOperator("Flatten", ParseReshape); From 43122f662f1c2666544339a479b79bb7771f55fd Mon Sep 17 00:00:00 2001 From: Shlok Date: Mon, 5 Jan 2026 14:23:33 +0000 Subject: [PATCH 2/4] SOFIE: register GELU in operator list --- tmva/sofie/inc/TMVA/OperatorList.hxx | 1 + 1 file changed, 1 insertion(+) diff --git a/tmva/sofie/inc/TMVA/OperatorList.hxx b/tmva/sofie/inc/TMVA/OperatorList.hxx index 309a0fc703147..1eb72874cf15d 100644 --- a/tmva/sofie/inc/TMVA/OperatorList.hxx +++ b/tmva/sofie/inc/TMVA/OperatorList.hxx @@ -1,6 +1,7 @@ #include "TMVA/ROperator_Transpose.hxx" #include "TMVA/ROperator_Gemm.hxx" #include "TMVA/ROperator_Relu.hxx" +#include "TMVA/ROperator_Gelu.hxx" #include "TMVA/ROperator_Tanh.hxx" #include "TMVA/ROperator_LeakyRelu.hxx" #include "TMVA/ROperator_Selu.hxx" From 76e8fc346e906cc0cfb35d18ad53ee93f04e021a Mon Sep 17 00:00:00 2001 From: Shlok Date: Fri, 9 Jan 2026 14:30:57 +0000 Subject: [PATCH 3/4] [tmva][sofie] Use erf-based GELU definition --- tmva/sofie/inc/TMVA/ROperator_Gelu.hxx | 45 ++++++++++++++------------ 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx b/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx index 76db67a803d3c..db10d1bb4c252 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Gelu.hxx @@ -6,6 +6,7 @@ #include "TMVA/RModel.hxx" #include +#include namespace TMVA{ namespace Experimental{ @@ -34,39 +35,44 @@ public: } std::vector> ShapeInference(std::vector> input) override { - auto ret = input; //suggest copy to compiler + auto ret = input; // suggest copy to compiler return ret; } void Initialize(RModel& model) override { - if (model.CheckIfTensorAlreadyExist(fNX) == false){ //input must be a graph input, or already initialized intermediate tensor - throw std::runtime_error("TMVA SOFIE Gelu Op Input Tensor " + fNX + " is not found in model"); + if (model.CheckIfTensorAlreadyExist(fNX) == false){ + throw std::runtime_error( + "TMVA SOFIE Gelu Op Input Tensor " + fNX + " is not found in model"); } fShape = model.GetDimTensorShape(fNX); model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape); if (model.Verbose()) { - std::cout << "Gelu : " << fNX << " -> " << fNY << " " << ConvertShapeToString(fShape) << std::endl; + std::cout << "Gelu : " << fNX << " -> " << fNY << " " + << ConvertShapeToString(fShape) << std::endl; } } std::string Generate(std::string OpName) override { - OpName = "op_" + OpName; - if (fShape.empty()) { - throw std::runtime_error("TMVA SOFIE Operator Gelu called to Generate without being initialized first"); - } - std::stringstream out; - auto length = ConvertDynamicShapeToLength(fShape); - out << "\n//------ GELU\n"; - out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; - out << SP << SP - << "tensor_" << fNY << "[id] = 0.5 * tensor_" << fNX << "[id] * " - << "(1 + std::tanh(0.7978845608 * " - << "(tensor_" << fNX << "[id] + 0.044715 * " - << "tensor_" << fNX << "[id] * tensor_" << fNX << "[id] * tensor_" << fNX << "[id])));\n"; - out << SP << "}\n"; - return out.str(); + OpName = "op_" + OpName; + if (fShape.empty()) { + throw std::runtime_error( + "TMVA SOFIE Operator Gelu called to Generate without being initialized first"); + } + + std::stringstream out; + auto length = ConvertDynamicShapeToLength(fShape); + + out << "\n//------ GELU (exact, erf-based)\n"; + out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; + out << SP << SP + << "tensor_" << fNY << "[id] = " + << "tensor_" << fNX << "[id] * 0.5 * " + << "(1.0 + std::erf(tensor_" << fNX << "[id] / std::sqrt(2.0)));\n"; + out << SP << "}\n"; + + return out.str(); } }; @@ -75,5 +81,4 @@ public: }//Experimental }//TMVA - #endif //TMVA_SOFIE_ROPERATOR_GELU From 2194bdbb26131548f2424ac23640adda4457df59 Mon Sep 17 00:00:00 2001 From: Shlok Date: Tue, 13 Jan 2026 21:26:00 +0000 Subject: [PATCH 4/4] [tmva][sofie] Handle ONNX Gelu approximate attribute --- tmva/sofie_parsers/src/ParseGelu.cxx | 29 +++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tmva/sofie_parsers/src/ParseGelu.cxx b/tmva/sofie_parsers/src/ParseGelu.cxx index d7c9269451de6..df68c88d81961 100644 --- a/tmva/sofie_parsers/src/ParseGelu.cxx +++ b/tmva/sofie_parsers/src/ParseGelu.cxx @@ -2,6 +2,9 @@ #include "TMVA/ROperator_Gelu.hxx" #include "onnx_proto3.pb.h" +#include +#include + namespace TMVA { namespace Experimental { namespace SOFIE { @@ -9,11 +12,35 @@ namespace SOFIE { ParserFuncSignature ParseGelu = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { ETensorType input_type; + // --- Handle ONNX Gelu attribute: approximate --- + // ONNX Gelu has attribute "approximate": "none" (default) or "tanh" + std::string approximate = "none"; + for (const auto &attr : nodeproto.attribute()) { + if (attr.name() == "approximate") { + if (attr.type() != onnx::AttributeProto::STRING) + throw std::runtime_error( + "TMVA::SOFIE ONNX Parser: Gelu attribute 'approximate' must be a string"); + + approximate = attr.s(); + } + } + + if (approximate != "none") { + if (approximate == "tanh") { + throw std::runtime_error( + "TMVA::SOFIE ONNX Parser: Gelu attribute approximate='tanh' not supported yet"); + } + + throw std::runtime_error( + "TMVA::SOFIE ONNX Parser: Gelu attribute approximate='" + approximate + + "' not supported (expected 'none' or 'tanh')"); + } + auto input_name = nodeproto.input(0); if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); } else { - throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu op has input tensor" + input_name + + throw std::runtime_error("TMVA::SOFIE ONNX Parser Gelu op has input tensor " + input_name + " but its type is not yet registered"); }