Skip to content

Commit 0049661

Browse files
committed
Add constant folding and some default in switch (helped by AI)
1 parent 8512084 commit 0049661

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed

lib/nlexpr.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "pyoptinterface/nlexpr.hpp"
22

33
#include <cassert>
4+
#include <cmath>
45
#include "fmt/core.h"
56

67
bool ExpressionHandle::operator==(const ExpressionHandle &x) const
@@ -26,6 +27,8 @@ std::string ExpressionHandle::to_string() const
2627
return fmt::format("t{}", id);
2728
case ArrayType::Nary:
2829
return fmt::format("n{}", id);
30+
default:
31+
return fmt::format("?{}", id);
2932
}
3033
}
3134

@@ -147,13 +150,92 @@ ExpressionHandle ExpressionGraph::add_parameter(EntityId id)
147150

148151
ExpressionHandle ExpressionGraph::add_unary(UnaryOperator op, ExpressionHandle operand)
149152
{
153+
// Constant folding: if the operand is a constant, compute the result directly
154+
if (operand.array == ArrayType::Constant)
155+
{
156+
double val = m_constants[operand.id];
157+
double result;
158+
switch (op)
159+
{
160+
case UnaryOperator::Neg:
161+
result = -val;
162+
break;
163+
case UnaryOperator::Sin:
164+
result = std::sin(val);
165+
break;
166+
case UnaryOperator::Cos:
167+
result = std::cos(val);
168+
break;
169+
case UnaryOperator::Tan:
170+
result = std::tan(val);
171+
break;
172+
case UnaryOperator::Asin:
173+
result = std::asin(val);
174+
break;
175+
case UnaryOperator::Acos:
176+
result = std::acos(val);
177+
break;
178+
case UnaryOperator::Atan:
179+
result = std::atan(val);
180+
break;
181+
case UnaryOperator::Abs:
182+
result = std::abs(val);
183+
break;
184+
case UnaryOperator::Sqrt:
185+
result = std::sqrt(val);
186+
break;
187+
case UnaryOperator::Exp:
188+
result = std::exp(val);
189+
break;
190+
case UnaryOperator::Log:
191+
result = std::log(val);
192+
break;
193+
case UnaryOperator::Log10:
194+
result = std::log10(val);
195+
break;
196+
default:
197+
// Unknown operator, fall through to create the node
198+
goto create_node;
199+
}
200+
return add_constant(result);
201+
}
202+
create_node:
150203
m_unaries.emplace_back(op, operand);
151204
return {ArrayType::Unary, static_cast<NodeId>(m_unaries.size() - 1)};
152205
}
153206

154207
ExpressionHandle ExpressionGraph::add_binary(BinaryOperator op, ExpressionHandle left,
155208
ExpressionHandle right)
156209
{
210+
// Constant folding: if both operands are constants, compute the result directly
211+
// Note: comparison operators are not folded as they produce boolean results
212+
if (left.array == ArrayType::Constant && right.array == ArrayType::Constant &&
213+
!is_binary_compare_op(op))
214+
{
215+
double lval = m_constants[left.id];
216+
double rval = m_constants[right.id];
217+
double result;
218+
switch (op)
219+
{
220+
case BinaryOperator::Sub:
221+
result = lval - rval;
222+
break;
223+
case BinaryOperator::Div:
224+
result = lval / rval;
225+
break;
226+
case BinaryOperator::Pow:
227+
result = std::pow(lval, rval);
228+
break;
229+
case BinaryOperator::Mul2:
230+
result = lval * rval;
231+
break;
232+
default:
233+
// Comparison operators or unknown, fall through to create the node
234+
goto create_node;
235+
}
236+
return add_constant(result);
237+
}
238+
create_node:
157239
m_binaries.emplace_back(op, left, right);
158240
return {ArrayType::Binary, static_cast<NodeId>(m_binaries.size() - 1)};
159241
}
@@ -168,6 +250,43 @@ ExpressionHandle ExpressionGraph::add_ternary(TernaryOperator op, ExpressionHand
168250
ExpressionHandle ExpressionGraph::add_nary(NaryOperator op,
169251
const std::vector<ExpressionHandle> &operands)
170252
{
253+
// Constant folding: if all operands are constants, compute the result directly
254+
bool all_constants = true;
255+
for (const auto &operand : operands)
256+
{
257+
if (operand.array != ArrayType::Constant)
258+
{
259+
all_constants = false;
260+
break;
261+
}
262+
}
263+
264+
if (all_constants && !operands.empty())
265+
{
266+
double result;
267+
switch (op)
268+
{
269+
case NaryOperator::Add:
270+
result = 0.0;
271+
for (const auto &operand : operands)
272+
{
273+
result += m_constants[operand.id];
274+
}
275+
break;
276+
case NaryOperator::Mul:
277+
result = 1.0;
278+
for (const auto &operand : operands)
279+
{
280+
result *= m_constants[operand.id];
281+
}
282+
break;
283+
default:
284+
goto create_node;
285+
}
286+
return add_constant(result);
287+
}
288+
289+
create_node:
171290
m_naries.emplace_back(op, operands);
172291
return {ArrayType::Nary, static_cast<NodeId>(m_naries.size() - 1)};
173292
}
@@ -419,6 +538,8 @@ std::string unary_operator_to_string(UnaryOperator op)
419538
return "Log";
420539
case UnaryOperator::Log10:
421540
return "Log10";
541+
default:
542+
return "UnknownUnary";
422543
}
423544
}
424545

@@ -444,6 +565,10 @@ std::string binary_operator_to_string(BinaryOperator op)
444565
return "GreaterEqual";
445566
case BinaryOperator::GreaterThan:
446567
return "GreaterThan";
568+
case BinaryOperator::Mul2:
569+
return "Mul2";
570+
default:
571+
return "UnknownBinary";
447572
}
448573
}
449574

@@ -453,6 +578,8 @@ std::string ternary_operator_to_string(TernaryOperator op)
453578
{
454579
case TernaryOperator::IfThenElse:
455580
return "IfThenElse";
581+
default:
582+
return "UnknownTernary";
456583
}
457584
}
458585

@@ -464,6 +591,8 @@ std::string nary_operator_to_string(NaryOperator op)
464591
return "Add";
465592
case NaryOperator::Mul:
466593
return "Mul";
594+
default:
595+
return "UnknownNary";
467596
}
468597
}
469598

0 commit comments

Comments
 (0)