Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions extensions/standard/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,76 @@ Value replace(const std::vector<Value>& args) {
arr = result;
return arr;
}


// 拼接给定的所有列表
Value concat(const std::vector<Value>& args) {
std::vector<Value> result;
for (const auto &i : args) {
if (!i.is_array()) {
L_ERR("Given parameter(s) have no-array element");
return LAMINA_NULL;
}
for (const auto &j : std::get<std::vector<Value> >(i.data)) {
result.push_back(j);
}
}
return Value(result);
}

// 列表切片
Value slice(const std::vector<Value>& args) {
check_cpp_function_argv_x(args, 3, 4);
if (!args[0].is_array()) {
L_ERR("slice() requires a list");
return LAMINA_NULL;
}
std::vector<Value> val = std::get<std::vector<Value> >(args[0].data);
std::vector<Value> result;
int begin = (int) args[1].as_number(), end = (int) args[2].as_number(), step = 1;
if (end < begin) step = -1;
if (args.size() >= 4) {
step = (int) args[3].as_number();
}
for (int i = begin; i != end; i += step) {
int it = (i >= 0) ? i : (int(val.size()) + i);
if (it < 0 || it >= val.size()) {
break;
}
result.push_back(val[it]);
}
return Value(result);
}

// 排序给定的列表,第二个参数表示比较器
Value _sort(const std::vector<Value>& args) {
check_cpp_function_argv_x(args, 1, 2);
if (!args[0].is_array()) {
L_ERR("sort() requires a list");
return LAMINA_NULL;
}
std::vector<Value> val = std::get<std::vector<Value> >(args[0].data);
for (auto &i : val) {
if (!i.is_comparable()) {
L_ERR("Array has uncomparable object");
return args[0]; // A failure, not an error.
}
}
std::function<bool(const Value &a, const Value &b)> comparer;
if (args.size() >= 2) {
if (!args[1].is_lambda()) {
L_ERR("Comparer must be a lambda/function");
return LAMINA_NULL;
}
const auto func = std::get<std::shared_ptr<LambdaDeclExpr>>(args[1].data);
comparer = [&func](const Value &a, const Value &b) -> bool {
return Interpreter::call_function(func.get(), {a, b}).as_bool();
};
} else {
comparer = [](const Value &a, const Value &b) -> bool {
return a < b;
};
}
sort(val.begin(), val.end(), comparer);
return Value(val);
}
13 changes: 13 additions & 0 deletions extensions/standard/standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ Value map(const std::vector<Value>& args);
// 替换内容:需3个参数(原字符串/容器、目标值、替换值),替换所有匹配的目标值并返回新结果
Value replace(const std::vector<Value>& args);

// 拼接给定的所有列表
Value concat(const std::vector<Value>& args);

// 切片,可以为 [a:b] 或 [a:b:c]
Value slice(const std::vector<Value>& args);

// 排序给定的列表,第二个参数表示比较器
Value _sort(const std::vector<Value>& args);

// 变量表
Value vars(const std::vector<Value>& args);

Expand Down Expand Up @@ -298,6 +307,10 @@ inline std::unordered_map<std::string, Value> register_builtins =
LAMINA_FUNC("find", find),
LAMINA_FUNC("map", map),
LAMINA_FUNC("replace", replace),

LAMINA_FUNC("concat", concat),
LAMINA_FUNC("slice", slice),
LAMINA_FUNC("sort", _sort),

// CAS数学模块:封装符号计算相关的解析、化简、求导等功能
LAMINA_MODULE("cas", LAMINA_VERSION, {
Expand Down
84 changes: 79 additions & 5 deletions interpreter/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Value HANDLE_BINARYEXPR_ADD(Value* l, Value* r) {
return Value(l->to_string() + r->to_string());
} else if (ltype & VALUE_IS_ARRAY && rtype & VALUE_IS_ARRAY) {
// Vector addition
return l->vector_add(r);
return l->vector_add(*r);
// 只要有一方是 Irrational 或 Symbolic,优先生成符号表达式
} else if (((ltype & VALUE_IS_IRRATIONAL) || (ltype & VALUE_IS_SYMBOLIC) || (rtype & VALUE_IS_IRRATIONAL) || (rtype & VALUE_IS_SYMBOLIC)) && (ltype & VALUE_IS_NUMERIC) && (rtype & VALUE_IS_NUMERIC)) {
std::shared_ptr<SymbolicExpr> leftExpr = GET_SYMBOLICEXPR(l, ltype);
Expand Down Expand Up @@ -111,6 +111,7 @@ Value HANDLE_BINARYEXPR_ADD(Value* l, Value* r) {
Value HANDLE_BINARYEXPR_STR_ADD_STR(Value* l, Value* r) {
std::string ls = l->to_string();
std::string rs = r->to_string();
/*
try {
// Try to parse both strings as CAS expressions and combine them symbolically
LaminaCAS::Parser pl(ls);
Expand Down Expand Up @@ -257,6 +258,8 @@ Value HANDLE_BINARYEXPR_STR_ADD_STR(Value* l, Value* r) {
} catch (...) {
// parsing failed for one or both strings; fall back to normal concatenation
}
*/
return Value(ls + rs);
}

Value Interpreter::eval_LiteralExpr(const LiteralExpr* node) {
Expand Down Expand Up @@ -337,11 +340,13 @@ Value Interpreter::eval_CallExpr(const CallExpr* call) {
return {};
}
// User function
return Interpreter::call_function(func.get(), args, self);
return Interpreter::call_function(func.get(), args, self, left.in_module ? Value(left.in_module) : LAMINA_NULL);
}

if (std::holds_alternative<std::shared_ptr<LmCppFunction>>(left.data)) {
push_frame("<cpp function>", " ");
push_scope();
set_module_as(left.in_module ? Value(left.in_module) : LAMINA_NULL);

Value result;
std::shared_ptr<LmCppFunction> func;
Expand All @@ -350,17 +355,19 @@ Value Interpreter::eval_CallExpr(const CallExpr* call) {
result = func->function(args);
} catch (...) {
pop_frame();
pop_scope();
throw;
}
pop_frame();
pop_scope();
return result;
}

std::cerr << "Type is not a callable object " << std::endl;
return {};
}

Value Interpreter::call_function(const LambdaDeclExpr* func, const std::vector<Value>& args, Value self) {
Value Interpreter::call_function(const LambdaDeclExpr* func, const std::vector<Value>& args, Value self, Value module) {
if (func == nullptr) {
std::cerr << "Error: Function at '" << func << "' is null" << std::endl;
return Value("<func error>");
Expand All @@ -369,6 +376,7 @@ Value Interpreter::call_function(const LambdaDeclExpr* func, const std::vector<V
Interpreter::push_frame(func->name, "<script>", 0);// Add to call stack

Interpreter::push_scope();// Create scope here
Interpreter::set_module_as(module); // After pushing scope!! If it's null, then there's nothing to consider.
// Pass arguments
for (size_t j = 0; j < func->params.size(); ++j) {
set_variable(func->params[j], args[j]);
Expand Down Expand Up @@ -429,6 +437,11 @@ Value Interpreter::eval_BinaryExpr(const BinaryExpr* bin) {
// Arithmetic operations (require numeric operands or vector operations)
if (bin->op == "-" || bin->op == "*" || bin->op == "/" ||
bin->op == "%" || bin->op == "^") {

if (l.is_infinity() || r.is_infinity()) {
L_ERR("Error: Infinity cannot participate in evaluations");
}

// Special handling for multiplication
if (bin->op == "*") {
// Vector and matrix operations
Expand Down Expand Up @@ -711,10 +724,44 @@ Value Interpreter::eval_BinaryExpr(const BinaryExpr* bin) {
// Comparison operators
if (bin->op == "==" || bin->op == "!=" || bin->op == "<" ||
bin->op == "<=" || bin->op == ">" || bin->op == ">=") {
if (l.is_infinity() && r.is_infinity()) {
int lt = std::get<int>(l.data), rt = std::get<int>(r.data);
if (bin->op == "==") return lt == rt;
else if (bin->op == "!=") return lt != rt;
else if (bin->op == "<") return lt < rt;
else if (bin->op == "<=") return lt <= rt;
else if (bin->op == ">") return lt > rt;
else if (bin->op == ">=") return lt >= rt;
else return false;
}
if (l.is_infinity()) {
if (bin->op == "==") return false;
if (bin->op == "!=") return true;
if (bin->op == ">" || bin->op == ">=") return (std::get<int>(l.data) > 0);
else return !(std::get<int>(l.data) > 0);
}
if (r.is_infinity()) {
if (bin->op == "==") return false;
if (bin->op == "!=") return true;
if (bin->op == "<" || bin->op == "<=") return (std::get<int>(r.data) > 0);
else return !(std::get<int>(r.data) > 0);
}
// Handle different type combinations
if (l.is_numeric() && r.is_numeric()) {
// BigInt 比较优先
if (l.is_bigint() || r.is_bigint()) {
if (l.is_symbolic() || r.is_symbolic()) {
auto ls = l.as_symbolic();
auto rs = SymbolicExpr::multiply(SymbolicExpr::number(-1), r.as_symbolic());
auto res = SymbolicExpr::add(ls, rs)->simplify();
double rd = res->to_double();

if (bin->op == "==") return Value(rd == 0);
if (bin->op == "!=") return Value(rd != 0);
if (bin->op == "<") return Value(rd < 0);
if (bin->op == "<=") return Value(rd <= 0);
if (bin->op == ">") return Value(rd > 0);
if (bin->op == ">=") return Value(rd >= 0);
} else if (l.is_bigint() || r.is_bigint()) {
::BigInt lb = l.is_bigint() ? std::get<::BigInt>(l.data) : ::BigInt(l.as_number());
::BigInt rb = r.is_bigint() ? std::get<::BigInt>(r.data) : ::BigInt(r.as_number());

Expand Down Expand Up @@ -760,7 +807,24 @@ Value Interpreter::eval_BinaryExpr(const BinaryExpr* bin) {
if (bin->op == ">") return Value(!result_less && ls != rs);
if (bin->op == ">=") return Value(!result_less);
}
} else {
} else if (l.is_rational() || r.is_rational()) {
const auto& ld = l.as_rational();
const auto& rd = r.as_rational();
if (bin->op == "==") return Value(ld == rd);
if (bin->op == "!=") return Value(ld != rd);
if (bin->op == "<") return Value(ld < rd);
if (bin->op == "<=") return Value(ld <= rd);
if (bin->op == ">") return Value(ld > rd);
if (bin->op == ">=") return Value(ld >= rd);
} else if ((l.is_array() || l.is_matrix()) && (r.is_array() || r.is_matrix())) {
// Implemented in value.hpp
bool judge = (l == r);
if (bin->op == "==") return Value(judge);
if (bin->op == "!=") return Value(!judge);

L_ERR("Cannot compare array or matrix with operator '" + bin->op + "'");
return Value();
} else {
double ld = l.as_number();
double rd = r.as_number();

Expand Down Expand Up @@ -805,12 +869,22 @@ Value Interpreter::eval_BinaryExpr(const BinaryExpr* bin) {
}
}

if (bin->op == "and" || bin->op == "or") {
bool lb = l.as_bool(), rb = r.as_bool();
if (bin->op == "and") return lb and rb;
else if (bin->op == "or") return lb or rb;
}

L_ERR("Unknown binary operator '" + bin->op + "'");
return {};
}

Value Interpreter::eval_UnaryExpr(const UnaryExpr* unary) {
Value v = eval(unary->operand.get());

if (unary->op == "not") {
return Value(!v.as_bool());
}

if (unary->op == "-") {
if (v.type != Value::Type::Int && v.type != Value::Type::BigInt && v.type != Value::Type::Float && v.type != Value::Type::Rational) {
Expand Down
Loading