11#include " pyoptinterface/nlexpr.hpp"
22
33#include < cassert>
4+ #include < cmath>
45#include " fmt/core.h"
56
67bool ExpressionHandle::operator ==(const ExpressionHandle &x) const
@@ -26,6 +27,8 @@ std::string ExpressionHandle::to_string() const
2627 return fmt::format (" t{}" , id);
2728 case ArrayType::Nary:
2829 return fmt::format (" n{}" , id);
30+ default :
31+ return fmt::format (" ?{}" , id);
2932 }
3033}
3134
@@ -147,13 +150,92 @@ ExpressionHandle ExpressionGraph::add_parameter(EntityId id)
147150
148151ExpressionHandle ExpressionGraph::add_unary (UnaryOperator op, ExpressionHandle operand)
149152{
153+ // Constant folding: if the operand is a constant, compute the result directly
154+ if (operand.array == ArrayType::Constant)
155+ {
156+ double val = m_constants[operand.id ];
157+ double result;
158+ switch (op)
159+ {
160+ case UnaryOperator::Neg:
161+ result = -val;
162+ break ;
163+ case UnaryOperator::Sin:
164+ result = std::sin (val);
165+ break ;
166+ case UnaryOperator::Cos:
167+ result = std::cos (val);
168+ break ;
169+ case UnaryOperator::Tan:
170+ result = std::tan (val);
171+ break ;
172+ case UnaryOperator::Asin:
173+ result = std::asin (val);
174+ break ;
175+ case UnaryOperator::Acos:
176+ result = std::acos (val);
177+ break ;
178+ case UnaryOperator::Atan:
179+ result = std::atan (val);
180+ break ;
181+ case UnaryOperator::Abs:
182+ result = std::abs (val);
183+ break ;
184+ case UnaryOperator::Sqrt:
185+ result = std::sqrt (val);
186+ break ;
187+ case UnaryOperator::Exp:
188+ result = std::exp (val);
189+ break ;
190+ case UnaryOperator::Log:
191+ result = std::log (val);
192+ break ;
193+ case UnaryOperator::Log10:
194+ result = std::log10 (val);
195+ break ;
196+ default :
197+ // Unknown operator, fall through to create the node
198+ goto create_node;
199+ }
200+ return add_constant (result);
201+ }
202+ create_node:
150203 m_unaries.emplace_back (op, operand);
151204 return {ArrayType::Unary, static_cast <NodeId>(m_unaries.size () - 1 )};
152205}
153206
154207ExpressionHandle ExpressionGraph::add_binary (BinaryOperator op, ExpressionHandle left,
155208 ExpressionHandle right)
156209{
210+ // Constant folding: if both operands are constants, compute the result directly
211+ // Note: comparison operators are not folded as they produce boolean results
212+ if (left.array == ArrayType::Constant && right.array == ArrayType::Constant &&
213+ !is_binary_compare_op (op))
214+ {
215+ double lval = m_constants[left.id ];
216+ double rval = m_constants[right.id ];
217+ double result;
218+ switch (op)
219+ {
220+ case BinaryOperator::Sub:
221+ result = lval - rval;
222+ break ;
223+ case BinaryOperator::Div:
224+ result = lval / rval;
225+ break ;
226+ case BinaryOperator::Pow:
227+ result = std::pow (lval, rval);
228+ break ;
229+ case BinaryOperator::Mul2:
230+ result = lval * rval;
231+ break ;
232+ default :
233+ // Comparison operators or unknown, fall through to create the node
234+ goto create_node;
235+ }
236+ return add_constant (result);
237+ }
238+ create_node:
157239 m_binaries.emplace_back (op, left, right);
158240 return {ArrayType::Binary, static_cast <NodeId>(m_binaries.size () - 1 )};
159241}
@@ -168,6 +250,43 @@ ExpressionHandle ExpressionGraph::add_ternary(TernaryOperator op, ExpressionHand
168250ExpressionHandle ExpressionGraph::add_nary (NaryOperator op,
169251 const std::vector<ExpressionHandle> &operands)
170252{
253+ // Constant folding: if all operands are constants, compute the result directly
254+ bool all_constants = true ;
255+ for (const auto &operand : operands)
256+ {
257+ if (operand.array != ArrayType::Constant)
258+ {
259+ all_constants = false ;
260+ break ;
261+ }
262+ }
263+
264+ if (all_constants && !operands.empty ())
265+ {
266+ double result;
267+ switch (op)
268+ {
269+ case NaryOperator::Add:
270+ result = 0.0 ;
271+ for (const auto &operand : operands)
272+ {
273+ result += m_constants[operand.id ];
274+ }
275+ break ;
276+ case NaryOperator::Mul:
277+ result = 1.0 ;
278+ for (const auto &operand : operands)
279+ {
280+ result *= m_constants[operand.id ];
281+ }
282+ break ;
283+ default :
284+ goto create_node;
285+ }
286+ return add_constant (result);
287+ }
288+
289+ create_node:
171290 m_naries.emplace_back (op, operands);
172291 return {ArrayType::Nary, static_cast <NodeId>(m_naries.size () - 1 )};
173292}
@@ -419,6 +538,8 @@ std::string unary_operator_to_string(UnaryOperator op)
419538 return " Log" ;
420539 case UnaryOperator::Log10:
421540 return " Log10" ;
541+ default :
542+ return " UnknownUnary" ;
422543 }
423544}
424545
@@ -444,6 +565,10 @@ std::string binary_operator_to_string(BinaryOperator op)
444565 return " GreaterEqual" ;
445566 case BinaryOperator::GreaterThan:
446567 return " GreaterThan" ;
568+ case BinaryOperator::Mul2:
569+ return " Mul2" ;
570+ default :
571+ return " UnknownBinary" ;
447572 }
448573}
449574
@@ -453,6 +578,8 @@ std::string ternary_operator_to_string(TernaryOperator op)
453578 {
454579 case TernaryOperator::IfThenElse:
455580 return " IfThenElse" ;
581+ default :
582+ return " UnknownTernary" ;
456583 }
457584}
458585
@@ -464,6 +591,8 @@ std::string nary_operator_to_string(NaryOperator op)
464591 return " Add" ;
465592 case NaryOperator::Mul:
466593 return " Mul" ;
594+ default :
595+ return " UnknownNary" ;
467596 }
468597}
469598
0 commit comments