From e81a4eb840753f21fb8b7e859bc2861f433294fa Mon Sep 17 00:00:00 2001 From: Jonathan Powell Date: Mon, 11 May 2026 11:42:01 -0400 Subject: [PATCH 1/5] First pass - pretty_print function --- mongosql/src/mir/mod.rs | 1 + mongosql/src/mir/pretty_print.rs | 1161 ++++++++++++++++++++++++++++++ 2 files changed, 1162 insertions(+) create mode 100644 mongosql/src/mir/pretty_print.rs diff --git a/mongosql/src/mir/mod.rs b/mongosql/src/mir/mod.rs index 5c2eb36df..15c71c092 100644 --- a/mongosql/src/mir/mod.rs +++ b/mongosql/src/mir/mod.rs @@ -1,5 +1,6 @@ pub mod definitions; pub use definitions::*; +pub mod pretty_print; pub mod schema; pub use mongosql_datastructures::binding_tuple; diff --git a/mongosql/src/mir/pretty_print.rs b/mongosql/src/mir/pretty_print.rs new file mode 100644 index 000000000..2daa91c75 --- /dev/null +++ b/mongosql/src/mir/pretty_print.rs @@ -0,0 +1,1161 @@ +//! Pretty-printer for MIR pipeline trees. +//! +//! Renders MIR stages as an indented tree — the same shape as database EXPLAIN output — +//! so that optimizer pass changes are easy to audit visually. +//! +//! # Example +//! +//! ```text +//! FILTER [Lt(foo.int, 42i32)] +//! COLLECTION [db.foo] +//! ``` + +// This module is a diagnostic utility with no production callers; suppress the +// dead_code lint so the module compiles cleanly until callers are added. +#![allow(dead_code)] + +use super::{ + AggregationExpr, DatePart, ElemMatch, Expression, FieldAccess, FieldPath, JoinType, + LiteralValue, MatchFalse, MatchLanguageComparison, MatchLanguageComparisonOp, + MatchLanguageLogical, MatchLanguageLogicalOp, MatchLanguageRegex, MatchLanguageType, + MatchQuery, MqlStage, OptionallyAliasedExpr, ReferenceExpr, SetOperation, SortSpecification, + Stage, SubqueryComparisonOp, SubqueryExpr, SubqueryModifier, Type, TypeOrMissing, +}; +use mongosql_datastructures::binding_tuple::{DatasourceName, Key}; + +/// Pretty-prints a MIR data structure to a human-readable indented tree. +pub trait PrettyPrint { + /// Returns a pretty-printed representation of this value rooted at column 0. + /// + /// # Errors + /// + /// MIR pretty-printing is infallible; this method never returns `Err`. + fn pretty_print(&self) -> Result; +} + +/// Error type for MIR pretty-printing. +/// +/// MIR pretty-printing is infallible, so this enum has no variants. +#[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)] +pub enum Error {} + +type Result = std::result::Result; + +/// Indents every line of `s` by `depth` levels (2 spaces each). +fn indent_lines(s: &str, depth: usize) -> String { + let prefix = " ".repeat(depth); + s.lines() + .map(|line| format!("{prefix}{line}")) + .collect::>() + .join("\n") +} + +fn type_str(ty: Type) -> &'static str { + match ty { + Type::Array => "Array", + Type::BinData => "BinData", + Type::Boolean => "Boolean", + Type::Datetime => "Datetime", + Type::DbPointer => "DbPointer", + Type::Decimal128 => "Decimal128", + Type::Document => "Document", + Type::Double => "Double", + Type::Int32 => "Int32", + Type::Int64 => "Int64", + Type::Javascript => "Javascript", + Type::JavascriptWithScope => "JavascriptWithScope", + Type::MaxKey => "MaxKey", + Type::MinKey => "MinKey", + Type::Null => "Null", + Type::ObjectId => "ObjectId", + Type::RegularExpression => "RegularExpression", + Type::String => "String", + Type::Symbol => "Symbol", + Type::Timestamp => "Timestamp", + Type::Undefined => "Undefined", + } +} + +fn type_or_missing_str(tom: &TypeOrMissing) -> String { + match tom { + TypeOrMissing::Missing => "Missing".to_string(), + TypeOrMissing::Number => "Number".to_string(), + TypeOrMissing::Type(ty) => type_str(*ty).to_string(), + } +} + +fn date_part_str(dp: &DatePart) -> &'static str { + match dp { + DatePart::Year => "Year", + DatePart::Quarter => "Quarter", + DatePart::Month => "Month", + DatePart::Week => "Week", + DatePart::Day => "Day", + DatePart::Hour => "Hour", + DatePart::Minute => "Minute", + DatePart::Second => "Second", + DatePart::Millisecond => "Millisecond", + } +} + +fn agg_expr_pp(agg: &AggregationExpr) -> Result { + match agg { + AggregationExpr::CountStar(distinct) => { + if *distinct { + Ok("CountStarDistinct".to_string()) + } else { + Ok("CountStar".to_string()) + } + } + AggregationExpr::Function(f) => { + let arg = f.arg.pretty_print()?; + if f.distinct { + Ok(format!("{}Distinct({})", f.function.as_str(), arg)) + } else { + Ok(format!("{}({})", f.function.as_str(), arg)) + } + } + } +} + +fn subquery_expr_pp(se: &SubqueryExpr) -> Result { + let output = se.output_expr.pretty_print()?; + let stage = indent_lines(&se.subquery.pretty_print()?, 1); + Ok(format!("Subquery({output})\n{stage}")) +} + +fn join_type_str(jt: JoinType) -> &'static str { + match jt { + JoinType::Inner => "Inner", + JoinType::Left => "Left", + } +} + +fn subquery_comparison_op_str(op: SubqueryComparisonOp) -> &'static str { + match op { + SubqueryComparisonOp::Lt => "Lt", + SubqueryComparisonOp::Lte => "Lte", + SubqueryComparisonOp::Neq => "Neq", + SubqueryComparisonOp::Eq => "Eq", + SubqueryComparisonOp::Gt => "Gt", + SubqueryComparisonOp::Gte => "Gte", + } +} + +fn subquery_modifier_str(m: SubqueryModifier) -> &'static str { + match m { + SubqueryModifier::Any => "Any", + SubqueryModifier::All => "All", + } +} + +fn match_comparison_op_str(op: MatchLanguageComparisonOp) -> &'static str { + match op { + MatchLanguageComparisonOp::Lt => "Lt", + MatchLanguageComparisonOp::Lte => "Lte", + MatchLanguageComparisonOp::Ne => "Ne", + MatchLanguageComparisonOp::Eq => "Eq", + MatchLanguageComparisonOp::Gt => "Gt", + MatchLanguageComparisonOp::Gte => "Gte", + } +} + +impl PrettyPrint for Key { + fn pretty_print(&self) -> Result { + let name = match &self.datasource { + DatasourceName::Bottom => "__bot__", + DatasourceName::Named(n) => n.as_str(), + }; + if self.scope == 0 { + Ok(name.to_string()) + } else { + Ok(format!("{name}@{}", self.scope)) + } + } +} + +impl PrettyPrint for LiteralValue { + fn pretty_print(&self) -> Result { + let s = match self { + LiteralValue::Null => "null".to_string(), + LiteralValue::Undefined => "undefined".to_string(), + LiteralValue::Boolean(b) => b.to_string(), + LiteralValue::Integer(i) => format!("{i}i32"), + LiteralValue::Long(l) => format!("{l}i64"), + LiteralValue::Double(d) => format!("{d}f64"), + LiteralValue::String(s) => format!("\"{s}\""), + LiteralValue::Decimal128(d) => format!("Decimal128({d})"), + LiteralValue::RegularExpression(r) => { + format!("Regex({}, {})", r.pattern, r.options) + } + LiteralValue::JavaScriptCode(code) => format!("Javascript({code})"), + LiteralValue::JavaScriptCodeWithScope(j) => { + format!("JavascriptWithScope({})", j.code) + } + LiteralValue::Timestamp(t) => format!("Timestamp({}, {})", t.time, t.increment), + LiteralValue::Binary(b) => { + format!("Binary({:?}, <{} bytes>)", b.subtype, b.bytes.len()) + } + LiteralValue::ObjectId(oid) => format!("ObjectId({oid})"), + LiteralValue::DateTime(dt) => format!("DateTime({})", dt.timestamp_millis()), + LiteralValue::Symbol(s) => format!("Symbol({s})"), + LiteralValue::MaxKey => "MaxKey".to_string(), + LiteralValue::MinKey => "MinKey".to_string(), + LiteralValue::DbPointer(dp) => format!("DbPointer({dp:?})"), + }; + Ok(s) + } +} + +impl PrettyPrint for FieldPath { + fn pretty_print(&self) -> Result { + let key = self.key.pretty_print()?; + if self.fields.is_empty() { + Ok(key) + } else { + Ok(format!("{key}.{}", self.fields.join("."))) + } + } +} + +impl PrettyPrint for SortSpecification { + fn pretty_print(&self) -> Result { + match self { + SortSpecification::Asc(fp) => Ok(format!("{} ASC", fp.pretty_print()?)), + SortSpecification::Desc(fp) => Ok(format!("{} DESC", fp.pretty_print()?)), + } + } +} + +impl PrettyPrint for Expression { + fn pretty_print(&self) -> Result { + match self { + Expression::Literal(lit) => lit.pretty_print(), + Expression::Reference(ReferenceExpr { key }) => Ok(format!("${}", key.pretty_print()?)), + Expression::FieldAccess(FieldAccess { expr, field, .. }) => { + Ok(format!("{}.{field}", expr.pretty_print()?)) + } + Expression::MqlIntrinsicFieldExistence(FieldAccess { expr, field, .. }) => { + Ok(format!("FieldExists({}.{field})", expr.pretty_print()?)) + } + Expression::Array(arr) => { + let elems = arr + .array + .iter() + .map(|e| e.pretty_print()) + .collect::>>()? + .join(", "); + Ok(format!("[{elems}]")) + } + Expression::Document(doc) => { + let fields = doc + .document + .iter() + .map(|(k, v)| v.pretty_print().map(|vpp| format!("{k}: {vpp}"))) + .collect::>>()? + .join(", "); + Ok(format!("{{{fields}}}")) + } + Expression::ScalarFunction(sf) => { + let args = sf + .args + .iter() + .map(|a| a.pretty_print()) + .collect::>>()? + .join(", "); + Ok(format!("{}({args})", sf.function.as_str())) + } + Expression::DateFunction(df) => { + let args = df + .args + .iter() + .map(|a| a.pretty_print()) + .collect::>>()? + .join(", "); + Ok(format!( + "{}({}, {args})", + df.function.as_str(), + date_part_str(&df.date_part) + )) + } + Expression::Is(is_expr) => Ok(format!( + "Is({}, {})", + is_expr.expr.pretty_print()?, + type_or_missing_str(&is_expr.target_type) + )), + Expression::Like(like_expr) => { + let expr = like_expr.expr.pretty_print()?; + let pattern = like_expr.pattern.pretty_print()?; + if let Some(escape) = like_expr.escape { + Ok(format!("Like({expr}, {pattern}, '{escape}')")) + } else { + Ok(format!("Like({expr}, {pattern})")) + } + } + Expression::Cast(cast) => Ok(format!( + "Cast({}, {}, on_null={}, on_error={})", + cast.expr.pretty_print()?, + type_str(cast.to), + cast.on_null.pretty_print()?, + cast.on_error.pretty_print()? + )), + Expression::TypeAssertion(ta) => Ok(format!( + "TypeAssertion({}, {})", + ta.expr.pretty_print()?, + type_str(ta.target_type) + )), + Expression::SearchedCase(sc) => { + let branches = sc + .when_branch + .iter() + .map(|wb| { + let when = wb.when.pretty_print()?; + let then = wb.then.pretty_print()?; + Ok(format!("When({when}) Then({then})")) + }) + .collect::>>()? + .join(", "); + let else_br = sc.else_branch.pretty_print()?; + Ok(format!("SearchedCase({branches}, Else({else_br}))")) + } + Expression::SimpleCase(sc) => { + let operand = sc.expr.pretty_print()?; + let branches = sc + .when_branch + .iter() + .map(|wb| { + let when = wb.when.pretty_print()?; + let then = wb.then.pretty_print()?; + Ok(format!("When({when}) Then({then})")) + }) + .collect::>>()? + .join(", "); + let else_br = sc.else_branch.pretty_print()?; + Ok(format!( + "SimpleCase({operand}, {branches}, Else({else_br}))" + )) + } + Expression::Exists(exists) => { + let stage = indent_lines(&exists.stage.pretty_print()?, 1); + Ok(format!("Exists\n{stage}")) + } + Expression::Subquery(se) => subquery_expr_pp(se), + Expression::SubqueryComparison(sc) => { + let op = subquery_comparison_op_str(sc.operator); + let modifier = subquery_modifier_str(sc.modifier); + let arg = sc.argument.pretty_print()?; + let subquery = indent_lines(&sc.subquery_expr.subquery.pretty_print()?, 1); + let output = sc.subquery_expr.output_expr.pretty_print()?; + Ok(format!( + "SubqueryComparison({op}, {modifier}, {arg}, {output})\n{subquery}" + )) + } + } + } +} + +impl PrettyPrint for MatchQuery { + fn pretty_print(&self) -> Result { + match self { + MatchQuery::Logical(ml) => ml.pretty_print(), + MatchQuery::Type(mt) => mt.pretty_print(), + MatchQuery::Regex(mr) => mr.pretty_print(), + MatchQuery::ElemMatch(em) => em.pretty_print(), + MatchQuery::Comparison(mc) => mc.pretty_print(), + MatchQuery::False(mf) => mf.pretty_print(), + } + } +} + +impl PrettyPrint for MatchLanguageLogical { + fn pretty_print(&self) -> Result { + let op = match self.op { + MatchLanguageLogicalOp::And => "And", + MatchLanguageLogicalOp::Or => "Or", + }; + let args = self + .args + .iter() + .map(|a| a.pretty_print()) + .collect::>>()? + .join(", "); + Ok(format!("{op}({args})")) + } +} + +impl PrettyPrint for MatchLanguageType { + fn pretty_print(&self) -> Result { + let ty = type_or_missing_str(&self.target_type); + match &self.input { + Some(fp) => Ok(format!("MatchType({}, {ty})", fp.pretty_print()?)), + None => Ok(format!("MatchType({ty})")), + } + } +} + +impl PrettyPrint for MatchLanguageRegex { + fn pretty_print(&self) -> Result { + match &self.input { + Some(fp) => Ok(format!( + "MatchRegex({}, {}, {})", + fp.pretty_print()?, + self.regex, + self.options + )), + None => Ok(format!("MatchRegex({}, {})", self.regex, self.options)), + } + } +} + +impl PrettyPrint for ElemMatch { + fn pretty_print(&self) -> Result { + Ok(format!( + "ElemMatch({}, {})", + self.input.pretty_print()?, + self.condition.pretty_print()? + )) + } +} + +impl PrettyPrint for MatchLanguageComparison { + fn pretty_print(&self) -> Result { + let op = match_comparison_op_str(self.function); + let arg = self.arg.pretty_print()?; + match &self.input { + Some(fp) => Ok(format!("{op}({}, {arg})", fp.pretty_print()?)), + None => Ok(format!("{op}({arg})")), + } + } +} + +impl PrettyPrint for MatchFalse { + fn pretty_print(&self) -> Result { + Ok("MatchFalse".to_string()) + } +} + +impl PrettyPrint for Stage { + fn pretty_print(&self) -> Result { + match self { + Stage::Collection(c) => Ok(format!("COLLECTION [{}.{}]", c.db, c.collection)), + Stage::Array(arr) => { + let alias = &arr.alias; + let elems = arr + .array + .iter() + .map(|e| e.pretty_print()) + .collect::>>()? + .join(", "); + Ok(format!("ARRAY [{elems}] AS {alias}")) + } + Stage::Sentinel => Ok("SENTINEL".to_string()), + Stage::Filter(f) => { + let cond = f.condition.pretty_print()?; + let source = indent_lines(&f.source.pretty_print()?, 1); + Ok(format!("FILTER [{cond}]\n{source}")) + } + Stage::Project(p) => { + let header = if p.is_add_fields { + "ADD_FIELDS" + } else { + "PROJECT" + }; + let bindings = p + .expression + .0 + .iter() + .map(|(k, v)| { + let kpp = k.pretty_print()?; + let vpp = v.pretty_print()?; + Ok(format!(" {kpp} => {vpp}")) + }) + .collect::>>()? + .join("\n"); + let source = indent_lines(&p.source.pretty_print()?, 1); + Ok(format!("{header}\n{bindings}\n{source}")) + } + Stage::Group(g) => { + let keys_str = if g.keys.is_empty() { + " (none)".to_string() + } else { + g.keys + .iter() + .map(|oae| match oae { + OptionallyAliasedExpr::Aliased(ae) => ae + .expr + .pretty_print() + .map(|e| format!(" {} = {e}", ae.alias)), + OptionallyAliasedExpr::Unaliased(e) => { + e.pretty_print().map(|e| format!(" {e}")) + } + }) + .collect::>>()? + .join("\n") + }; + let aggs_str = if g.aggregations.is_empty() { + " (none)".to_string() + } else { + g.aggregations + .iter() + .map(|aa| { + agg_expr_pp(&aa.agg_expr).map(|a| format!(" {} = {a}", aa.alias)) + }) + .collect::>>()? + .join("\n") + }; + let source = indent_lines(&g.source.pretty_print()?, 1); + Ok(format!( + "GROUP [scope={}]\n keys:\n{keys_str}\n aggs:\n{aggs_str}\n{source}", + g.scope + )) + } + Stage::Limit(l) => { + let source = indent_lines(&l.source.pretty_print()?, 1); + Ok(format!("LIMIT [{}]\n{source}", l.limit)) + } + Stage::Offset(o) => { + let source = indent_lines(&o.source.pretty_print()?, 1); + Ok(format!("OFFSET [{}]\n{source}", o.offset)) + } + Stage::Sort(s) => { + let specs = s + .specs + .iter() + .map(|sp| sp.pretty_print()) + .collect::>>()? + .join(", "); + let source = indent_lines(&s.source.pretty_print()?, 1); + Ok(format!("SORT [{specs}]\n{source}")) + } + Stage::Unwind(u) => { + let path = u.path.pretty_print()?; + let index = match &u.index { + Some(i) => i.as_str(), + None => "none", + }; + let source = indent_lines(&u.source.pretty_print()?, 1); + Ok(format!( + "UNWIND [path={path}, outer={}, index={index}, prefiltered={}]\n{source}", + u.outer, u.is_prefiltered + )) + } + Stage::Join(j) => { + let jt = join_type_str(j.join_type); + let cond_str = match &j.condition { + Some(c) => format!(", condition={}", c.pretty_print()?), + None => String::new(), + }; + let left = indent_lines(&j.left.pretty_print()?, 2); + let right = indent_lines(&j.right.pretty_print()?, 2); + Ok(format!( + "JOIN [{jt}{cond_str}]\n LEFT:\n{left}\n RIGHT:\n{right}" + )) + } + Stage::Set(s) => { + let op = match s.operation { + SetOperation::UnionAll => "UNION_ALL", + }; + let left = indent_lines(&s.left.pretty_print()?, 2); + let right = indent_lines(&s.right.pretty_print()?, 2); + Ok(format!("{op}\n LEFT:\n{left}\n RIGHT:\n{right}")) + } + Stage::Derived(d) => { + let source = indent_lines(&d.source.pretty_print()?, 1); + Ok(format!("DERIVED\n{source}")) + } + Stage::MqlIntrinsic(mql) => mql.pretty_print(), + } + } +} + +impl PrettyPrint for MqlStage { + fn pretty_print(&self) -> Result { + match self { + MqlStage::MatchFilter(mf) => { + let cond = mf.condition.pretty_print()?; + let source = indent_lines(&mf.source.pretty_print()?, 1); + Ok(format!("MATCH_FILTER [{cond}]\n{source}")) + } + MqlStage::EquiJoin(ej) => { + let jt = join_type_str(ej.join_type); + let local = ej.local_field.pretty_print()?; + let foreign = ej.foreign_field.pretty_print()?; + let source = indent_lines(&ej.source.pretty_print()?, 2); + let from = indent_lines(&ej.from.pretty_print()?, 2); + Ok(format!( + "EQUI_JOIN [{jt}, local={local}, foreign={foreign}]\n SOURCE:\n{source}\n FROM:\n{from}" + )) + } + MqlStage::LateralJoin(lj) => { + let jt = join_type_str(lj.join_type); + let source = indent_lines(&lj.source.pretty_print()?, 2); + let subquery = indent_lines(&lj.subquery.pretty_print()?, 2); + Ok(format!( + "LATERAL_JOIN [{jt}]\n SOURCE:\n{source}\n SUBQUERY:\n{subquery}" + )) + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + map, + mir::{ + self, schema::SchemaCache, AggregationFunction, AggregationFunctionApplication, + ArrayExpr, ArraySource, Derived, DocumentExpr, ElemMatch, ExistsExpr, FieldAccess, + Filter, Group, IsExpr, Join, JoinType, LateralJoin, LikeExpr, Limit, MatchFalse, + MatchFilter, MatchLanguageComparison, MatchLanguageComparisonOp, MatchLanguageLogical, + MatchLanguageLogicalOp, MatchLanguageRegex, MatchLanguageType, MatchQuery, MqlStage, + Offset, OptionallyAliasedExpr, Project, ReferenceExpr, ScalarFunction, + ScalarFunctionApplication, SearchedCaseExpr, Set, SetOperation, SimpleCaseExpr, Sort, + SortSpecification, Stage, TypeAssertionExpr, TypeOrMissing, Unwind, WhenBranch, + }, + schema::Satisfaction, + }; + use mongosql_datastructures::binding_tuple::{BindingTuple, Key}; + + fn collection(db: &str, coll: &str) -> Box { + Box::new(Stage::Collection(mir::Collection { + db: db.to_string(), + collection: coll.to_string(), + cache: SchemaCache::new(), + })) + } + + fn ref_expr(name: &str) -> Expression { + Expression::Reference(ReferenceExpr { + key: Key::named(name, 0), + }) + } + + fn field_access(name: &str, field: &str) -> Expression { + Expression::FieldAccess(FieldAccess { + expr: Box::new(ref_expr(name)), + field: field.to_string(), + is_nullable: false, + }) + } + + // ── Literals ────────────────────────────────────────────────────────────── + + mod literal { + use super::*; + + #[test] + fn null_renders_without_type_suffix() { + assert_eq!(LiteralValue::Null.pretty_print().unwrap(), "null"); + } + + #[test] + fn undefined_renders() { + assert_eq!(LiteralValue::Undefined.pretty_print().unwrap(), "undefined"); + } + + #[test] + fn boolean_true() { + assert_eq!(LiteralValue::Boolean(true).pretty_print().unwrap(), "true"); + } + + #[test] + fn boolean_false() { + assert_eq!( + LiteralValue::Boolean(false).pretty_print().unwrap(), + "false" + ); + } + + #[test] + fn integer_has_i32_suffix() { + assert_eq!(LiteralValue::Integer(42).pretty_print().unwrap(), "42i32"); + } + + #[test] + fn long_has_i64_suffix() { + assert_eq!(LiteralValue::Long(42).pretty_print().unwrap(), "42i64"); + } + + #[test] + fn double_has_f64_suffix() { + assert_eq!( + LiteralValue::Double(3.14).pretty_print().unwrap(), + "3.14f64" + ); + } + + #[test] + fn string_is_double_quoted() { + assert_eq!( + LiteralValue::String("hello".to_string()) + .pretty_print() + .unwrap(), + r#""hello""# + ); + } + + #[test] + fn max_key_renders() { + assert_eq!(LiteralValue::MaxKey.pretty_print().unwrap(), "MaxKey"); + } + + #[test] + fn min_key_renders() { + assert_eq!(LiteralValue::MinKey.pretty_print().unwrap(), "MinKey"); + } + } + + // ── Key ─────────────────────────────────────────────────────────────────── + + mod key { + use super::*; + + #[test] + fn named_scope_zero_omits_scope() { + let k = Key::named("foo", 0); + assert_eq!(k.pretty_print().unwrap(), "foo"); + } + + #[test] + fn named_scope_nonzero_shows_at_suffix() { + let k = Key::named("foo", 1); + assert_eq!(k.pretty_print().unwrap(), "foo@1"); + } + + #[test] + fn bottom_key_renders_as_dunder_bot() { + let k = Key::bot(0); + assert_eq!(k.pretty_print().unwrap(), "__bot__"); + } + } + + // ── FieldPath ───────────────────────────────────────────────────────────── + + mod field_path { + use super::*; + use crate::mir::FieldPath; + + #[test] + fn single_field() { + let fp = FieldPath { + key: Key::named("orders", 0), + fields: vec!["status".to_string()], + is_nullable: false, + }; + assert_eq!(fp.pretty_print().unwrap(), "orders.status"); + } + + #[test] + fn multi_field_chain() { + let fp = FieldPath { + key: Key::named("orders", 0), + fields: vec!["customer".to_string(), "id".to_string()], + is_nullable: false, + }; + assert_eq!(fp.pretty_print().unwrap(), "orders.customer.id"); + } + + #[test] + fn bot_key_with_field() { + let fp = FieldPath { + key: Key::bot(0), + fields: vec!["x".to_string()], + is_nullable: false, + }; + assert_eq!(fp.pretty_print().unwrap(), "__bot__.x"); + } + } + + // ── Sort ────────────────────────────────────────────────────────────────── + + mod sort_spec { + use super::*; + use crate::mir::FieldPath; + + fn fp(name: &str, field: &str) -> FieldPath { + FieldPath { + key: Key::named(name, 0), + fields: vec![field.to_string()], + is_nullable: false, + } + } + + #[test] + fn asc_spec() { + assert_eq!( + SortSpecification::Asc(fp("orders", "status")) + .pretty_print() + .unwrap(), + "orders.status ASC" + ); + } + + #[test] + fn desc_spec() { + assert_eq!( + SortSpecification::Desc(fp("orders", "amount")) + .pretty_print() + .unwrap(), + "orders.amount DESC" + ); + } + } + + // ── Expression::Reference ───────────────────────────────────────────────── + + mod reference { + use super::*; + + #[test] + fn named_reference_has_dollar_sigil() { + let expr = ref_expr("foo"); + assert_eq!(expr.pretty_print().unwrap(), "$foo"); + } + + #[test] + fn bot_reference_uses_dunder_bot() { + let expr = Expression::Reference(ReferenceExpr { key: Key::bot(0) }); + assert_eq!(expr.pretty_print().unwrap(), "$__bot__"); + } + } + + // ── Expression::FieldAccess ─────────────────────────────────────────────── + + mod field_access_expr { + use super::*; + + #[test] + fn single_level() { + let expr = field_access("foo", "bar"); + assert_eq!(expr.pretty_print().unwrap(), "$foo.bar"); + } + + #[test] + fn multi_level() { + let expr = Expression::FieldAccess(FieldAccess { + expr: Box::new(field_access("foo", "bar")), + field: "baz".to_string(), + is_nullable: false, + }); + assert_eq!(expr.pretty_print().unwrap(), "$foo.bar.baz"); + } + } + + // ── ScalarFunction ──────────────────────────────────────────────────────── + + mod scalar_function { + use super::*; + + #[test] + fn unary_function() { + let expr = Expression::ScalarFunction(ScalarFunctionApplication { + function: ScalarFunction::Upper, + args: vec![field_access("foo", "name")], + is_nullable: false, + }); + assert_eq!(expr.pretty_print().unwrap(), "Upper($foo.name)"); + } + + #[test] + fn binary_function() { + let expr = Expression::ScalarFunction(ScalarFunctionApplication { + function: ScalarFunction::Lt, + args: vec![ + field_access("foo", "int"), + Expression::Literal(LiteralValue::Integer(10)), + ], + is_nullable: false, + }); + assert_eq!(expr.pretty_print().unwrap(), "Lt($foo.int, 10i32)"); + } + } + + // ── Stage::Collection ───────────────────────────────────────────────────── + + mod collection_stage { + use super::*; + + #[test] + fn renders_db_dot_collection() { + let stage = collection("mydb", "mycoll"); + assert_eq!(stage.pretty_print().unwrap(), "COLLECTION [mydb.mycoll]"); + } + } + + // ── Stage::Filter ───────────────────────────────────────────────────────── + + mod filter_stage { + use super::*; + + #[test] + fn literal_condition() { + let stage = Stage::Filter(Filter { + condition: Expression::Literal(LiteralValue::Boolean(true)), + source: collection("db", "foo"), + cache: SchemaCache::new(), + }); + assert_eq!( + stage.pretty_print().unwrap(), + "FILTER [true]\n COLLECTION [db.foo]" + ); + } + + #[test] + fn scalar_function_condition() { + let stage = Stage::Filter(Filter { + condition: Expression::ScalarFunction(ScalarFunctionApplication { + function: ScalarFunction::Lt, + args: vec![ + field_access("foo", "int"), + Expression::Literal(LiteralValue::Integer(10)), + ], + is_nullable: false, + }), + source: collection("db", "foo"), + cache: SchemaCache::new(), + }); + assert_eq!( + stage.pretty_print().unwrap(), + "FILTER [Lt($foo.int, 10i32)]\n COLLECTION [db.foo]" + ); + } + + #[test] + fn nested_filters() { + let inner = Box::new(Stage::Filter(Filter { + condition: Expression::Literal(LiteralValue::Boolean(true)), + source: collection("db", "foo"), + cache: SchemaCache::new(), + })); + let stage = Stage::Filter(Filter { + condition: Expression::Literal(LiteralValue::Boolean(false)), + source: inner, + cache: SchemaCache::new(), + }); + assert_eq!( + stage.pretty_print().unwrap(), + "FILTER [false]\n FILTER [true]\n COLLECTION [db.foo]" + ); + } + } + + // ── Stage::Project ──────────────────────────────────────────────────────── + + mod project_stage { + use super::*; + + #[test] + fn single_binding() { + let stage = Stage::Project(Project { + is_add_fields: false, + source: collection("db", "foo"), + expression: BindingTuple(map! { + Key::named("foo", 0) => ref_expr("foo") + }), + cache: SchemaCache::new(), + }); + assert_eq!( + stage.pretty_print().unwrap(), + "PROJECT\n foo => $foo\n COLLECTION [db.foo]" + ); + } + + #[test] + fn add_fields_uses_different_header() { + let stage = Stage::Project(Project { + is_add_fields: true, + source: collection("db", "foo"), + expression: BindingTuple(map! { + Key::named("foo", 0) => ref_expr("foo") + }), + cache: SchemaCache::new(), + }); + assert!(stage.pretty_print().unwrap().starts_with("ADD_FIELDS")); + } + } + + // ── Stage::Group ────────────────────────────────────────────────────────── + + mod group_stage { + use super::*; + use crate::mir::{AliasedAggregation, AliasedExpr, FieldPath}; + + fn fp(name: &str, field: &str) -> FieldPath { + FieldPath { + key: Key::named(name, 0), + fields: vec![field.to_string()], + is_nullable: false, + } + } + + #[test] + fn no_keys_count_star() { + let stage = Stage::Group(Group { + source: collection("db", "orders"), + keys: vec![], + aggregations: vec![AliasedAggregation { + alias: "total".to_string(), + agg_expr: AggregationExpr::CountStar(false), + }], + cache: SchemaCache::new(), + scope: 0, + }); + let pp = stage.pretty_print().unwrap(); + assert!(pp.contains("keys:\n (none)"), "got: {pp}"); + assert!(pp.contains("total = CountStar"), "got: {pp}"); + } + + #[test] + fn aliased_key_and_sum_aggregation() { + let stage = Stage::Group(Group { + source: collection("db", "orders"), + keys: vec![OptionallyAliasedExpr::Aliased(AliasedExpr { + alias: "status".to_string(), + expr: Expression::FieldAccess(FieldAccess { + expr: Box::new(ref_expr("orders")), + field: "status".to_string(), + is_nullable: false, + }), + })], + aggregations: vec![AliasedAggregation { + alias: "total".to_string(), + agg_expr: AggregationExpr::Function(AggregationFunctionApplication { + function: AggregationFunction::Sum, + distinct: false, + arg: Box::new(Expression::FieldAccess(FieldAccess { + expr: Box::new(ref_expr("orders")), + field: "amount".to_string(), + is_nullable: false, + })), + arg_is_possibly_doc: Satisfaction::Not, + }), + }], + cache: SchemaCache::new(), + scope: 0, + }); + let pp = stage.pretty_print().unwrap(); + assert!(pp.contains("status = $orders.status"), "got: {pp}"); + assert!(pp.contains("total = Sum($orders.amount)"), "got: {pp}"); + } + } + + // ── Stage::Join ─────────────────────────────────────────────────────────── + + mod join_stage { + use super::*; + + #[test] + fn inner_join_no_condition() { + let stage = Stage::Join(Join { + join_type: JoinType::Inner, + left: collection("db", "orders"), + right: collection("db", "customers"), + condition: None, + cache: SchemaCache::new(), + }); + let pp = stage.pretty_print().unwrap(); + assert_eq!( + pp, + "JOIN [Inner]\n LEFT:\n COLLECTION [db.orders]\n RIGHT:\n COLLECTION [db.customers]" + ); + } + + #[test] + fn left_join_with_condition() { + let stage = Stage::Join(Join { + join_type: JoinType::Left, + left: collection("db", "orders"), + right: collection("db", "customers"), + condition: Some(Expression::ScalarFunction(ScalarFunctionApplication { + function: ScalarFunction::Eq, + args: vec![ + field_access("orders", "id"), + field_access("customers", "order_id"), + ], + is_nullable: false, + })), + cache: SchemaCache::new(), + }); + let pp = stage.pretty_print().unwrap(); + assert!(pp.starts_with("JOIN [Left, condition=Eq("), "got: {pp}"); + } + } + + // ── MqlStage::MatchFilter ───────────────────────────────────────────────── + + mod match_filter_stage { + use super::*; + + #[test] + fn renders_match_filter_with_condition() { + let stage = Stage::MqlIntrinsic(MqlStage::MatchFilter(Box::new(MatchFilter { + source: collection("db", "orders"), + condition: MatchQuery::Comparison(MatchLanguageComparison { + function: MatchLanguageComparisonOp::Gt, + input: Some(crate::mir::FieldPath { + key: Key::named("orders", 0), + fields: vec!["amount".to_string()], + is_nullable: false, + }), + arg: LiteralValue::Integer(100), + cache: SchemaCache::new(), + }), + cache: SchemaCache::new(), + }))); + let pp = stage.pretty_print().unwrap(); + assert!( + pp.starts_with("MATCH_FILTER [Gt(orders.amount, 100i32)]"), + "got: {pp}" + ); + } + } + + // ── Integration: multi-stage pipelines ──────────────────────────────────── + + mod integration { + use super::*; + + #[test] + fn filter_over_collection() { + let stage = Stage::Filter(Filter { + condition: Expression::ScalarFunction(ScalarFunctionApplication { + function: ScalarFunction::Lt, + args: vec![ + field_access("foo", "int"), + Expression::Literal(LiteralValue::Integer(10)), + ], + is_nullable: false, + }), + source: collection("db", "foo"), + cache: SchemaCache::new(), + }); + assert_eq!( + stage.pretty_print().unwrap(), + "FILTER [Lt($foo.int, 10i32)]\n COLLECTION [db.foo]" + ); + } + + #[test] + fn sort_over_filter_over_collection() { + let filter = Box::new(Stage::Filter(Filter { + condition: Expression::Literal(LiteralValue::Boolean(true)), + source: collection("db", "orders"), + cache: SchemaCache::new(), + })); + let stage = Stage::Sort(Sort { + specs: vec![SortSpecification::Asc(crate::mir::FieldPath { + key: Key::named("orders", 0), + fields: vec!["status".to_string()], + is_nullable: false, + })], + source: filter, + cache: SchemaCache::new(), + }); + assert_eq!( + stage.pretty_print().unwrap(), + "SORT [orders.status ASC]\n FILTER [true]\n COLLECTION [db.orders]" + ); + } + } +} From 61f62d6de0bd101d1ab6b8628b141e7da3784596 Mon Sep 17 00:00:00 2001 From: Jonathan Powell Date: Mon, 11 May 2026 15:16:09 -0400 Subject: [PATCH 2/5] Add a --stage flag that can print out the different stages --- README.md | 26 +++++++++ mongosql-cli/src/main.rs | 115 +++++++++++++++++++++++++++++++++++---- mongosql/src/lib.rs | 85 +++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 596e2d7ed..b9df5a28f 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,32 @@ cargo run --package mongosql-cli -- --db mydb --schema-file schema.yaml "SELECT ./target/debug/mongosql-cli --db mydb --execute --translation "SELECT * FROM products" ``` +**Inspect an intermediate compilation stage:** + +Use `--stage` to stop translation at a specific point in the pipeline and print the intermediate representation. This is useful for debugging query compilation issues. + +```bash +# Print the rewritten AST (no MongoDB connection required) +./target/debug/mongosql-cli --db mydb --schema-file schema.yaml --stage ast "SELECT name FROM users" + +# Print the optimized MIR tree +./target/debug/mongosql-cli --db mydb --schema-file schema.yaml --stage mir "SELECT name FROM users" + +# Full MQL pipeline (same as omitting --stage) +./target/debug/mongosql-cli --db mydb --schema-file schema.yaml --stage mql "SELECT name FROM users" +``` + +Available stages (in pipeline order): + +| Stage | Description | +|-------|-------------| +| `ast` | SQL parsed and syntactically rewritten; prints the AST as a Rust debug tree (`{:#?}`). No schema or MongoDB connection required. | +| `mir` | Algebrized and optimizer-pass output; prints the MIR tree. Requires schema. | +| `air` | MIR translated and desugared to AIR; prints the Rust struct tree (`{:#?}`). Requires schema. | +| `mql` | Full translation to a MongoDB aggregation pipeline (default). | + +> **Note:** `--execute` is only valid with `--stage mql` or when `--stage` is omitted. + ### Schema Files When `--schema-file` is provided, the CLI reads collection schemas from a local file. diff --git a/mongosql-cli/src/main.rs b/mongosql-cli/src/main.rs index 0fe808401..278323e4b 100644 --- a/mongosql-cli/src/main.rs +++ b/mongosql-cli/src/main.rs @@ -27,6 +27,18 @@ where } } +#[derive(clap::ValueEnum, Debug, Clone, Copy)] +enum TranslationCheckpoint { + /// Stop after SQL parsing and AST rewrites; print the AST. + Ast, + /// Stop after algebrizing to MIR and running optimizer passes; print the MIR tree. + Mir, + /// AIR pretty-printing is not yet implemented. + Air, + /// Full translation to MQL; print the generated pipeline (default when --stage is omitted). + Mql, +} + #[derive(Parser, Debug)] #[command(version, about, long_about=None)] struct Cli { @@ -64,6 +76,12 @@ struct Cli { help = "A sql file to use instead of passing the query as an argument. The sql-file argument takes precedence over sql query text." )] sql_file: Option, + #[arg( + long, + value_enum, + help = "Stop at the specified compilation stage and print its intermediate representation. When omitted the CLI behaves as normal." + )] + stage: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -94,14 +112,13 @@ fn parse_query_from_args( } } -fn main() -> Result<(), CliError> { - let args = Cli::parse(); - - let uri = args.uri.unwrap_or("mongodb://localhost:27017".to_string()); - let current_db = args.db.unwrap_or("test".to_string()); - let query = parse_query_from_args(args.query, args.sql_file)?; - let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; - let catalog = if let Some(schema_file) = args.schema_file { +fn build_catalog( + uri: &str, + current_db: &str, + namespaces: std::collections::BTreeSet, + schema_file: Option, +) -> Result { + if let Some(schema_file) = schema_file { let contents = std::fs::read_to_string(&schema_file)?; let path = std::path::Path::new(&schema_file); let extension = path @@ -114,14 +131,88 @@ fn main() -> Result<(), CliError> { Some("json") => serde_json::from_str(&contents)?, _ => { return Err(CliError(format!( - "Unsupported schema file extension: {extension:?}. Supported formats are .yml, .yaml, .json" + "Unsupported schema file extension: {extension:?}. Supported formats are .yml, .yaml, .json" ))) } }; - build_catalog_from_catalog_schema(catalog.schemas)? + Ok(build_catalog_from_catalog_schema(catalog.schemas)?) } else { - get_schema_catalog(uri.as_str(), current_db.as_str(), namespaces)? - }; + get_schema_catalog(uri, current_db, namespaces) + } +} + +fn main() -> Result<(), CliError> { + let args = Cli::parse(); + + let uri = args.uri.unwrap_or("mongodb://localhost:27017".to_string()); + let current_db = args.db.unwrap_or("test".to_string()); + let query = parse_query_from_args(args.query, args.sql_file)?; + let stage = args.stage.unwrap_or(TranslationCheckpoint::Mql); + + if args.execute && !matches!(stage, TranslationCheckpoint::Mql) { + return Err(CliError( + "--execute is only valid with --stage mql or without --stage".to_string(), + )); + } + + match stage { + TranslationCheckpoint::Ast => { + let output = mongosql::translate_sql_to_ast_repr(query.as_str())?; + println!("{output}"); + return Ok(()); + } + TranslationCheckpoint::Mir => { + let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; + let catalog = build_catalog( + uri.as_str(), + current_db.as_str(), + namespaces, + args.schema_file, + )?; + let options = mongosql::options::SqlOptions { + allow_order_by_missing_columns: true, + ..Default::default() + }; + let output = mongosql::translate_sql_to_mir_repr( + current_db.as_str(), + query.as_str(), + &catalog, + options, + )?; + println!("{output}"); + return Ok(()); + } + TranslationCheckpoint::Air => { + let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; + let catalog = build_catalog( + uri.as_str(), + current_db.as_str(), + namespaces, + args.schema_file, + )?; + let options = mongosql::options::SqlOptions { + allow_order_by_missing_columns: true, + ..Default::default() + }; + let output = mongosql::translate_sql_to_air_repr( + current_db.as_str(), + query.as_str(), + &catalog, + options, + )?; + println!("{output}"); + return Ok(()); + } + TranslationCheckpoint::Mql => {} + } + + let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; + let catalog = build_catalog( + uri.as_str(), + current_db.as_str(), + namespaces, + args.schema_file, + )?; let options = mongosql::options::SqlOptions { allow_order_by_missing_columns: true, ..Default::default() diff --git a/mongosql/src/lib.rs b/mongosql/src/lib.rs index acd11b1d8..c1a0c9a4c 100644 --- a/mongosql/src/lib.rs +++ b/mongosql/src/lib.rs @@ -131,6 +131,91 @@ pub fn translate_sql( }) } +/// Returns the debug-formatted AST for the provided SQL after syntactic rewrites. +/// +/// Prints the Rust struct tree of the rewritten AST using `{:#?}`. +/// Intended for debugging; output format is not stable across versions. +/// +/// # Errors +/// +/// Returns `Err` if parsing or rewriting fails. +pub fn translate_sql_to_ast_repr(sql: &str) -> Result { + let ast = parser::parse_query(sql)?; + let ast = ast::rewrites::rewrite_query(ast)?; + Ok(format!("{ast:#?}")) +} + +/// Returns the debug-formatted MIR for the provided SQL after algebrizing and optimization. +/// +/// Prints the Rust struct tree of the optimizer-output MIR using `{:#?}`. +/// Intended for debugging; output format is not stable across versions. +/// +/// # Errors +/// +/// Returns `Err` if parsing, algebrizing, or schema-checking fails. +pub fn translate_sql_to_mir_repr( + current_db: &str, + sql: &str, + catalog: &Catalog, + sql_options: SqlOptions, +) -> Result { + let ast = parser::parse_query(sql)?; + let ast = ast::rewrites::rewrite_query(ast)?; + + let algebrizer = Algebrizer::new( + current_db, + catalog, + 0u16, + sql_options.schema_checking_mode, + sql_options.allow_order_by_missing_columns, + crate::algebrizer::ClauseType::Unintialized, + ); + let plan = algebrizer.algebrize_query(ast)?; + let plan = mir::optimizer::optimize_plan( + plan, + sql_options.schema_checking_mode, + &algebrizer.schema_inference_state(), + ); + + Ok(format!("{plan:#?}")) +} + +/// Returns the debug-formatted AIR for the provided SQL after MIR translation and desugaring. +/// +/// Prints the Rust struct tree of the fully desugared AIR stage using `{:#?}`. +/// Intended for debugging; output format is not stable across versions. +/// +/// # Errors +/// +/// Returns `Err` if parsing, algebrizing, translating, or desugaring fails. +pub fn translate_sql_to_air_repr( + current_db: &str, + sql: &str, + catalog: &Catalog, + sql_options: SqlOptions, +) -> Result { + let ast = parser::parse_query(sql)?; + let ast = ast::rewrites::rewrite_query(ast)?; + let algebrizer = Algebrizer::new( + current_db, + catalog, + 0u16, + sql_options.schema_checking_mode, + sql_options.allow_order_by_missing_columns, + crate::algebrizer::ClauseType::Unintialized, + ); + let plan = algebrizer.algebrize_query(ast)?; + let plan = mir::optimizer::optimize_plan( + plan, + sql_options.schema_checking_mode, + &algebrizer.schema_inference_state(), + ); + let mut translator = MqlTranslator::new(sql_options); + let air_plan = translator.translate_plan(plan)?; + let air_plan = air::desugarer::desugar_pipeline(air_plan)?; + Ok(format!("{air_plan:#?}")) +} + #[allow(clippy::result_large_err)] pub fn get_namespaces( current_db: &str, From b20cf4c19265ffddc4ac0ae86520941be5e46be2 Mon Sep 17 00:00:00 2001 From: Jonathan Powell Date: Mon, 11 May 2026 17:35:06 -0400 Subject: [PATCH 3/5] In-Progress: First pass at LSP for .mir and .air --- .gitignore | 1 + Cargo.lock | 237 +++++++++++++++- Cargo.toml | 1 + README.md | 41 +++ client/package-lock.json | 103 +++++++ client/package.json | 11 + client/src/extension.ts | 43 +++ mongosql-lsp/Cargo.toml | 13 + mongosql-lsp/src/hover.rs | 135 +++++++++ mongosql-lsp/src/main.rs | 31 +++ mongosql-lsp/src/parser.rs | 554 +++++++++++++++++++++++++++++++++++++ mongosql-lsp/src/server.rs | 302 ++++++++++++++++++++ package-lock.json | 391 ++++++++++++++++++++++++++ package.json | 34 +++ rolldown.config.mjs | 10 + test_output.mir | 69 +++++ test_where.mir | 144 ++++++++++ tsconfig.json | 14 + 18 files changed, 2128 insertions(+), 6 deletions(-) create mode 100644 client/package-lock.json create mode 100644 client/package.json create mode 100644 client/src/extension.ts create mode 100644 mongosql-lsp/Cargo.toml create mode 100644 mongosql-lsp/src/hover.rs create mode 100644 mongosql-lsp/src/main.rs create mode 100644 mongosql-lsp/src/parser.rs create mode 100644 mongosql-lsp/src/server.rs create mode 100644 package-lock.json create mode 100644 package.json create mode 100644 rolldown.config.mjs create mode 100644 test_output.mir create mode 100644 test_where.mir create mode 100644 tsconfig.json diff --git a/.gitignore b/.gitignore index 1e5292bff..c1ad7e7ac 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ target/ +node_modules/ mongosqlrun .idea .DS_Store diff --git a/Cargo.lock b/Cargo.lock index e1cfac3c0..487910444 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,7 +89,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" dependencies = [ "anstyle", - "anstyle-parse", + "anstyle-parse 0.2.7", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse 1.0.0", "anstyle-query", "anstyle-wincon", "colorchoice", @@ -112,6 +127,15 @@ dependencies = [ "utf8parse", ] +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + [[package]] name = "anstyle-query" version = "1.1.5" @@ -190,6 +214,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "auto_impl" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -216,7 +251,7 @@ dependencies = [ "pin-project-lite", "serde_core", "sync_wrapper 1.0.2", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", ] @@ -475,7 +510,7 @@ version = "4.5.54" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" dependencies = [ - "anstream", + "anstream 0.6.21", "anstyle", "clap_lex", "strsim", @@ -688,6 +723,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" +[[package]] +name = "crop" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fa28d21af9044d49bcfb1eceddeeb9fadc6a17207c6a1bf0d841ef0bd9ad36e" +dependencies = [ + "str_indices", +] + [[package]] name = "crossbeam-channel" version = "0.5.15" @@ -794,6 +838,33 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core 0.9.12", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core 0.9.12", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -1001,6 +1072,29 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "env_filter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +dependencies = [ + "anstream 1.0.0", + "anstyle", + "env_filter", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1308,6 +1402,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -1843,6 +1943,30 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "jiff" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "jni" version = "0.21.1" @@ -1970,6 +2094,19 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lsp-types" +version = "0.94.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1" +dependencies = [ + "bitflags 1.2.1", + "serde", + "serde_json", + "serde_repr", + "url", +] + [[package]] name = "macro_magic" version = "0.5.1" @@ -2264,6 +2401,19 @@ dependencies = [ "thiserror 2.0.17", ] +[[package]] +name = "mongosql-lsp" +version = "0.1.0" +dependencies = [ + "crop", + "dashmap 6.1.0", + "env_logger", + "log", + "serde_json", + "tokio", + "tower-lsp", +] + [[package]] name = "mongosqltranslate" version = "0.0.0" @@ -2705,6 +2855,15 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" +[[package]] +name = "portable-atomic-util" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -3233,7 +3392,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.2", "tokio", - "tower", + "tower 0.5.3", "tower-http", "tower-service", "url", @@ -3530,6 +3689,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "serde_stacker" version = "0.1.14" @@ -3722,6 +3892,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "str_indices" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d08889ec5408683408db66ad89e0e1f93dff55c73a4ccc71c427d5b277ee47e6" + [[package]] name = "str_stack" version = "0.1.0" @@ -4195,7 +4371,7 @@ dependencies = [ "sync_wrapper 1.0.2", "tokio", "tokio-stream", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", "tracing", @@ -4254,6 +4430,20 @@ dependencies = [ "tonic-prost", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower" version = "0.5.3" @@ -4286,7 +4476,7 @@ dependencies = [ "http-body 1.0.1", "iri-string", "pin-project-lite", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", ] @@ -4297,6 +4487,40 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" +[[package]] +name = "tower-lsp" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508" +dependencies = [ + "async-trait", + "auto_impl", + "bytes", + "dashmap 5.5.3", + "futures", + "httparse", + "lsp-types", + "memchr", + "serde", + "serde_json", + "tokio", + "tokio-util", + "tower 0.4.13", + "tower-lsp-macros", + "tracing", +] + +[[package]] +name = "tower-lsp-macros" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "tower-service" version = "0.3.3" @@ -4492,6 +4716,7 @@ dependencies = [ "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 72e4bdc07..3efc0cedb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "mongosql-datastructures", "mongosql-c", "mongosql", + "mongosql-lsp", "schema-builder-library", "service", "test-utils", diff --git a/README.md b/README.md index b9df5a28f..54bf1f304 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,47 @@ When `--schema-file` is omitted, the CLI connects to MongoDB and reads schema fr > > This CLI tool is __only__ available for MongoDB Enterprise Advanced (EA) customers. > Refer to the [Schema Builder documentation](https://www.mongodb.com/docs/sql-interface/sql-interface-install/) for more information. +## MongoSQL LSP + +The `mongosql-lsp` binary is a Language Server Protocol (LSP) server for `.mir` and `.air` debug-tree files produced by `mongosql-cli --stage mir` and `--stage air`. It provides syntax highlighting, code folding, and hover tooltips explaining each node type. + +### Building the server + +```bash +cargo build -p mongosql-lsp +``` + +### Installing the VS Code extension + +The repository includes a thin VS Code extension that spawns the LSP server. Install its dependencies once from the repo root (the `postinstall` script automatically installs the client dependencies as well): + +```bash +npm install +``` + +Then build the extension bundle: + +```bash +npm run build +``` + +### Launching in VS Code + +1. Build the server: `cargo build -p mongosql-lsp` +2. Press **F5** in VS Code — this runs the **Launch Client** configuration in `.vscode/launch.json`, which opens an Extension Development Host window with `SERVER_PATH` pointed at `target/debug/mongosql-lsp`. +3. Open any `.mir` or `.air` file to get: + - **Syntax highlighting** — enum variant names, struct field keys, string/number literals, and keywords each in a distinct colour. + - **Code folding** — fold struct bodies and arrays with the editor's fold shortcut. + - **Hover tooltips** — hover over a node name (e.g. `Filter`, `Project`, `Lookup`) for a Markdown description of that compilation stage. + +### Using with other editors + +Because `mongosql-lsp` speaks plain JSON-RPC over stdio it works with any LSP-capable editor. For Neovim: + +```lua +vim.lsp.start({ cmd = { "mongosql-lsp" }, filetypes = { "mir", "air" } }) +``` + ## Rust testing There are several types of tests for the Rust code: unit tests, fuzz tests, index usage tests, e2e diff --git a/client/package-lock.json b/client/package-lock.json new file mode 100644 index 000000000..de125831a --- /dev/null +++ b/client/package-lock.json @@ -0,0 +1,103 @@ +{ + "name": "mongosql-lsp-client", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "mongosql-lsp-client", + "version": "0.1.0", + "dependencies": { + "vscode-languageclient": "^9.0.1" + }, + "devDependencies": { + "@types/vscode": "^1.75.0" + } + }, + "node_modules/@types/vscode": { + "version": "1.118.0", + "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.118.0.tgz", + "integrity": "sha512-Ah6eTlqDcwIMELEVwQMO++rJAFBRz/oLluLD/vWdYrH1KuI9kfpaM+7pg0OvvascgcJy+ghLCERAYouM4QbzGw==", + "dev": true, + "license": "MIT" + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "license": "MIT" + }, + "node_modules/brace-expansion": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.0.tgz", + "integrity": "sha512-TN1kCZAgdgweJhWWpgKYrQaMNHcDULHkWwQIspdtjV4Y5aurRdZpjAqn6yX3FPqTA9ngHCc4hJxMAMgGfve85w==", + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/minimatch": { + "version": "5.1.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.9.tgz", + "integrity": "sha512-7o1wEA2RyMP7Iu7GNba9vc0RWWGACJOCZBJX2GJWip0ikV+wcOsgVuY9uE8CPiyQhkGFSlhuSkZPavN7u1c2Fw==", + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/semver": { + "version": "7.8.0", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.8.0.tgz", + "integrity": "sha512-AcM7dV/5ul4EekoQ29Agm5vri8JNqRyj39o0qpX6vDF2GZrtutZl5RwgD1XnZjiTAfncsJhMI48QQH3sN87YNA==", + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/vscode-jsonrpc": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/vscode-jsonrpc/-/vscode-jsonrpc-8.2.0.tgz", + "integrity": "sha512-C+r0eKJUIfiDIfwJhria30+TYWPtuHJXHtI7J0YlOmKAo7ogxP20T0zxB7HZQIFhIyvoBPwWskjxrvAtfjyZfA==", + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/vscode-languageclient": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/vscode-languageclient/-/vscode-languageclient-9.0.1.tgz", + "integrity": "sha512-JZiimVdvimEuHh5olxhxkht09m3JzUGwggb5eRUkzzJhZ2KjCN0nh55VfiED9oez9DyF8/fz1g1iBV3h+0Z2EA==", + "license": "MIT", + "dependencies": { + "minimatch": "^5.1.0", + "semver": "^7.3.7", + "vscode-languageserver-protocol": "3.17.5" + }, + "engines": { + "vscode": "^1.82.0" + } + }, + "node_modules/vscode-languageserver-protocol": { + "version": "3.17.5", + "resolved": "https://registry.npmjs.org/vscode-languageserver-protocol/-/vscode-languageserver-protocol-3.17.5.tgz", + "integrity": "sha512-mb1bvRJN8SVznADSGWM9u/b07H7Ecg0I3OgXDuLdn307rl/J3A9YD6/eYOssqhecL27hK1IPZAsaqh00i/Jljg==", + "license": "MIT", + "dependencies": { + "vscode-jsonrpc": "8.2.0", + "vscode-languageserver-types": "3.17.5" + } + }, + "node_modules/vscode-languageserver-types": { + "version": "3.17.5", + "resolved": "https://registry.npmjs.org/vscode-languageserver-types/-/vscode-languageserver-types-3.17.5.tgz", + "integrity": "sha512-Ld1VelNuX9pdF39h2Hgaeb5hEZM2Z3jUrrMgWQAu82jMtZp7p3vJT3BzToKtZI7NgQssZje5o0zryOrhQvzQAg==", + "license": "MIT" + } + } +} diff --git a/client/package.json b/client/package.json new file mode 100644 index 000000000..91c63729a --- /dev/null +++ b/client/package.json @@ -0,0 +1,11 @@ +{ + "name": "mongosql-lsp-client", + "version": "0.1.0", + "private": true, + "dependencies": { + "vscode-languageclient": "^9.0.1" + }, + "devDependencies": { + "@types/vscode": "^1.75.0" + } +} diff --git a/client/src/extension.ts b/client/src/extension.ts new file mode 100644 index 000000000..998786975 --- /dev/null +++ b/client/src/extension.ts @@ -0,0 +1,43 @@ +import * as path from "path"; +import * as vscode from "vscode"; +import { + LanguageClient, + LanguageClientOptions, + ServerOptions, + TransportKind, +} from "vscode-languageclient/node"; + +let client: LanguageClient; + +export function activate(context: vscode.ExtensionContext): void { + const command = + process.env["SERVER_PATH"] ?? + context.asAbsolutePath(path.join("..", "target", "debug", "mongosql-lsp")); + + const run: ServerOptions = { + command, + transport: TransportKind.stdio, + options: { env: { ...process.env, RUST_LOG: "debug" } }, + }; + + const clientOptions: LanguageClientOptions = { + documentSelector: [ + { scheme: "file", language: "mongosql-mir" }, + { scheme: "file", language: "mongosql-air" }, + ], + }; + + client = new LanguageClient( + "mongosql-lsp", + "MongoSQL LSP", + run, + clientOptions + ); + + context.subscriptions.push(client); + client.start(); +} + +export function deactivate(): Thenable | undefined { + return client?.stop(); +} diff --git a/mongosql-lsp/Cargo.toml b/mongosql-lsp/Cargo.toml new file mode 100644 index 000000000..61d09651a --- /dev/null +++ b/mongosql-lsp/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mongosql-lsp" +version = "0.1.0" +edition = "2021" + +[dependencies] +tower-lsp = { version = "0.20", features = ["proposed"] } +tokio = { version = "1", features = ["full"] } +dashmap = "6" +crop = "0.4" +serde_json = "1" +log = "0.4" +env_logger = "0.11" diff --git a/mongosql-lsp/src/hover.rs b/mongosql-lsp/src/hover.rs new file mode 100644 index 000000000..31f563af5 --- /dev/null +++ b/mongosql-lsp/src/hover.rs @@ -0,0 +1,135 @@ +//! Static hover-tooltip map for MIR and AIR node names. + +use std::collections::HashMap; +use std::sync::LazyLock; + +/// Maps a node name (as it appears in `{:#?}` output) to a Markdown tooltip. +pub static HOVER_MAP: LazyLock> = LazyLock::new(|| { + let mut m = HashMap::new(); + + // ── MIR (Mid-level Intermediate Representation) nodes ──────────────────── + m.insert( + "Project", + "**MIR Project** — Evaluates the SELECT-list expressions and produces one output \ + tuple per input row. The `expression` field is a `BindingTuple` whose keys become \ + the output column names.", + ); + m.insert( + "Filter", + "**MIR Filter** — Retains only rows for which `condition` evaluates to a truthy \ + value. Corresponds to a SQL `WHERE` or `HAVING` clause.", + ); + m.insert( + "Collection", + "**MIR Collection** — Scans every document in a MongoDB collection. \ + The `db` and `collection` fields identify the target namespace.", + ); + m.insert( + "Group", + "**MIR Group** — Partitions input rows into groups (SQL `GROUP BY`). \ + The `keys` field lists the grouping expressions; `aggregations` contains \ + the per-group aggregate functions.", + ); + m.insert( + "Sort", + "**MIR Sort** — Orders the result set (SQL `ORDER BY`). \ + Each `SortSpecification` carries an expression and a direction.", + ); + m.insert( + "Limit", + "**MIR Limit** — Restricts the number of output rows (SQL `LIMIT n`).", + ); + m.insert( + "Offset", + "**MIR Offset** — Skips the first `n` rows (SQL `OFFSET n`).", + ); + m.insert( + "Unwind", + "**MIR Unwind** — Unnests an array-valued field into individual rows, one per \ + array element. Equivalent to MongoDB `$unwind`.", + ); + m.insert( + "Join", + "**MIR Join** — Combines rows from two sub-trees (SQL `JOIN`). \ + The `join_type` field indicates `Inner`, `Left`, `Cross`, etc.", + ); + m.insert( + "Subquery", + "**MIR Subquery** — A correlated or uncorrelated sub-query used as an expression \ + or a data source.", + ); + m.insert( + "Set", + "**MIR Set** — A set operation node (`UNION ALL`, `UNION`, `INTERSECT`, `EXCEPT`).", + ); + m.insert( + "Derived", + "**MIR Derived** — An inline view (sub-select in a `FROM` clause).", + ); + m.insert( + "ExpressionCollection", + "**MIR ExpressionCollection** — A virtual collection built from an inline array \ + expression rather than a stored MongoDB collection.", + ); + + // ── AIR (Aggregation Intermediate Representation) nodes ────────────────── + m.insert( + "ReplaceWith", + "**AIR ReplaceWith** — Emits a MongoDB `$replaceWith` aggregation stage that \ + replaces each document with the result of an expression.", + ); + m.insert( + "Lookup", + "**AIR Lookup** — Emits a `$lookup` aggregation stage to implement a SQL `JOIN`.", + ); + m.insert( + "Unset", + "**AIR Unset** — Emits a `$unset` aggregation stage to remove fields from \ + documents.", + ); + m.insert( + "AddFields", + "**AIR AddFields** — Emits a `$addFields` aggregation stage to add or overwrite \ + document fields.", + ); + m.insert( + "Match", + "**AIR Match** — Emits a `$match` stage (corresponds to a MIR `Filter`).", + ); + m.insert( + "Project", // AIR also has Project + "**AIR / MIR Project** — Emits a `$project` aggregation stage that reshapes \ + documents to the desired output fields.", + ); + m.insert( + "Group", // AIR also has Group + "**AIR Group** — Emits a `$group` aggregation stage (SQL `GROUP BY`).", + ); + m.insert( + "Sort", // AIR also has Sort + "**AIR Sort** — Emits a `$sort` aggregation stage (SQL `ORDER BY`).", + ); + m.insert( + "Limit", // AIR also has Limit + "**AIR Limit** — Emits a `$limit` aggregation stage.", + ); + m.insert( + "Skip", + "**AIR Skip** — Emits a `$skip` aggregation stage (SQL `OFFSET`).", + ); + m.insert( + "Unwind", // AIR also has Unwind + "**AIR Unwind** — Emits a `$unwind` aggregation stage.", + ); + m.insert( + "Documents", + "**AIR Documents** — Emits a `$documents` stage for inline data.", + ); + m.insert( + "Collection", // AIR root source + "**AIR Collection** — The root source of an aggregation pipeline, identifying the \ + MongoDB collection to run the pipeline against.", + ); + + m +}); diff --git a/mongosql-lsp/src/main.rs b/mongosql-lsp/src/main.rs new file mode 100644 index 000000000..36cb5dc6d --- /dev/null +++ b/mongosql-lsp/src/main.rs @@ -0,0 +1,31 @@ +//! mongosql-lsp — Language Server Protocol server for `.mir` and `.air` debug files. +//! +//! Reads from stdin and writes to stdout (JSON-RPC over stdio), making it compatible +//! with any LSP-capable editor. Start it with the VS Code extension or configure it +//! directly in Neovim, Helix, etc. +//! +//! # Usage +//! +//! ```text +//! cargo build -p mongosql-lsp +//! # The VS Code extension launches this automatically via SERVER_PATH. +//! ``` + +mod hover; +mod parser; +mod server; + +use server::Backend; +use tower_lsp::{LspService, Server}; + +#[tokio::main] +async fn main() { + env_logger::init(); + + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + + let (service, socket) = LspService::build(Backend::new).finish(); + + Server::new(stdin, stdout, socket).serve(service).await; +} diff --git a/mongosql-lsp/src/parser.rs b/mongosql-lsp/src/parser.rs new file mode 100644 index 000000000..0515fdc8d --- /dev/null +++ b/mongosql-lsp/src/parser.rs @@ -0,0 +1,554 @@ +//! Recursive-descent parser for Rust `{:#?}` pretty-debug output. +//! +//! Produces a concrete syntax tree (CST) with byte-offset ranges so that the +//! LSP server can map cursor positions back to node names for hover and can +//! emit semantic token spans for syntax highlighting. + +use std::ops::Range; + +// ── CST node types ──────────────────────────────────────────────────────────── + +/// A node in the `{:#?}` concrete syntax tree. +#[derive(Debug)] +pub enum Node { + /// `Name(payload)` — enum variant with a tuple payload. + EnumVariant { + /// Byte range of the variant name. + name: Range, + payload: Option>, + }, + /// `Name { field: value, … }` — struct with named fields. + Struct { + /// `Some` when the struct is prefixed by an identifier (e.g. `Filter {`). + name: Option>, + fields: Vec, + }, + /// `{ Key { … }: value, … }` — Rust `HashMap`/`BTreeMap` debug output. + /// + /// Each entry is `(key_node, value_node)` where the key is itself a struct. + Map { entries: Vec<(Node, Node)> }, + /// `[ item, … ]` — sequence / array. + Sequence { items: Vec }, + /// Any atomic leaf: string literal, number, keyword (`true`, `false`, `None`, `Some(…)`). + Leaf { range: Range }, +} + +/// A `key: value` pair inside a struct. +#[derive(Debug)] +pub struct Field { + /// Byte range of the field name (before the `:`). + pub key: Range, + pub value: Node, +} + +/// The fully parsed document, keeping the original source text for range → `&str` look-ups. +#[derive(Debug)] +pub struct ParsedDoc { + pub root: Node, + /// The original source text (kept so ranges can be resolved later). + pub text: String, +} + +impl ParsedDoc { + /// Returns the name of the innermost named node that contains `offset`, if any. + pub fn node_name_at(&self, offset: usize) -> Option<&str> { + find_name_at(&self.root, offset, &self.text) + } +} + +// ── Entry point ─────────────────────────────────────────────────────────────── + +/// Parse a full `{:#?}` document. Never fails — unknown syntax is stored as `Leaf`. +pub fn parse(text: &str) -> ParsedDoc { + let mut p = Parser { src: text, pos: 0 }; + let root = p.parse_value(); + ParsedDoc { + root, + text: text.to_owned(), + } +} + +// ── Parser internals ────────────────────────────────────────────────────────── + +struct Parser<'a> { + src: &'a str, + pos: usize, +} + +impl Parser<'_> { + // ── helpers ────────────────────────────────────────────────────────────── + + fn remaining(&self) -> &str { + &self.src[self.pos..] + } + + fn peek(&self) -> Option { + self.remaining().chars().next() + } + + fn advance(&mut self, bytes: usize) { + self.pos += bytes; + } + + fn skip_whitespace(&mut self) { + while let Some(c) = self.peek() { + if c.is_whitespace() { + self.advance(c.len_utf8()); + } else { + break; + } + } + } + + fn eat(&mut self, ch: char) -> bool { + if self.remaining().starts_with(ch) { + self.advance(ch.len_utf8()); + true + } else { + false + } + } + + // ── identifier (ASCII word chars) ───────────────────────────────────────── + + fn parse_ident(&mut self) -> Option> { + let start = self.pos; + while let Some(c) = self.peek() { + if c.is_alphanumeric() || c == '_' { + self.advance(c.len_utf8()); + } else { + break; + } + } + if self.pos > start { + Some(start..self.pos) + } else { + None + } + } + + // ── string literal `"…"` ───────────────────────────────────────────────── + + fn parse_string(&mut self) -> Node { + let start = self.pos; + // consume opening `"` + self.advance(1); + let mut escaped = false; + while let Some(c) = self.peek() { + self.advance(c.len_utf8()); + if escaped { + escaped = false; + } else if c == '\\' { + escaped = true; + } else if c == '"' { + break; + } + } + Node::Leaf { + range: start..self.pos, + } + } + + // ── number literal ──────────────────────────────────────────────────────── + + fn parse_number(&mut self) -> Node { + let start = self.pos; + // optional leading `-` + if self.remaining().starts_with('-') { + self.advance(1); + } + while let Some(c) = self.peek() { + if c.is_ascii_digit() + || c == '.' + || c == '_' + || c == 'e' + || c == 'E' + || c == '+' + || c == '-' + { + self.advance(c.len_utf8()); + } else { + break; + } + } + Node::Leaf { + range: start..self.pos, + } + } + + // ── sequence `[ … ]` ───────────────────────────────────────────────────── + + fn parse_sequence(&mut self) -> Node { + // consume `[` + self.advance(1); + let mut items = Vec::new(); + loop { + self.skip_whitespace(); + if self.remaining().starts_with(']') { + self.advance(1); + break; + } + if self.remaining().is_empty() { + break; + } + let item = self.parse_value(); + items.push(item); + self.skip_whitespace(); + if self.remaining().starts_with(',') { + self.advance(1); + } + } + Node::Sequence { items } + } + + // ── struct body `{ field: value, … }` or map `{ Key {…}: val, … }` ───────── + + /// Parse the body of a `{…}` block, returning either a `Node::Struct` (for + /// normal named-field structs) or a `Node::Map` (for Rust `HashMap`/`BTreeMap` + /// debug output where each key is itself a struct expression). + /// + /// The opening `{` must not yet have been consumed. + fn parse_struct_body(&mut self) -> Node { + self.advance(1); // consume `{` + let mut fields: Vec = Vec::new(); + let mut entries: Vec<(Node, Node)> = Vec::new(); + + loop { + self.skip_whitespace(); + if self.remaining().starts_with('}') { + self.advance(1); + break; + } + if self.remaining().is_empty() { + break; + } + + if let Some(key_range) = self.parse_ident() { + self.skip_whitespace(); + if self.eat(':') { + // Normal struct field: `field_name: value` + self.skip_whitespace(); + let value = self.parse_value(); + fields.push(Field { + key: key_range, + value, + }); + } else if self.remaining().starts_with('{') { + // Map-key pattern: `StructName { inner_fields }: map_value`. + // The `ident { … }` is the HashMap key; what follows `:` is the value. + let inner = self.parse_struct_body(); // consumes inner `{…}` + let key_node = match inner { + Node::Struct { + fields: inner_fields, + .. + } => Node::Struct { + name: Some(key_range), + fields: inner_fields, + }, + other => other, + }; + self.skip_whitespace(); + self.eat(':'); // consume the map-entry separator + self.skip_whitespace(); + let map_value = self.parse_value(); + entries.push((key_node, map_value)); + } + // else: bare ident not followed by `:` or `{` — skip silently + } else if self.remaining().starts_with('<') { + // Skip opaque tokens such as ``. + while let Some(c) = self.peek() { + self.advance(c.len_utf8()); + if c == '>' { + break; + } + } + } else if let Some(c) = self.peek() { + self.advance(c.len_utf8()); + } + + self.skip_whitespace(); + if self.remaining().starts_with(',') { + self.advance(1); + } + } + + if entries.is_empty() { + Node::Struct { name: None, fields } + } else { + Node::Map { entries } + } + } + + // ── tuple payload `( … )` ───────────────────────────────────────────────── + + fn parse_paren_payload(&mut self) -> Node { + // consume `(` + self.advance(1); + self.skip_whitespace(); + if self.remaining().starts_with(')') { + self.advance(1); + return Node::Sequence { items: vec![] }; + } + let inner = self.parse_value(); + self.skip_whitespace(); + if self.remaining().starts_with(',') { + self.advance(1); + } + self.skip_whitespace(); + if self.remaining().starts_with(')') { + self.advance(1); + } + inner + } + + // ── top-level value dispatcher ──────────────────────────────────────────── + + fn parse_value(&mut self) -> Node { + self.skip_whitespace(); + match self.peek() { + Some('"') => self.parse_string(), + Some('[') => self.parse_sequence(), + Some('{') => { + // Anonymous `{…}` — may be a struct body or a map body. + self.parse_struct_body() + } + Some('(') => self.parse_paren_payload(), + Some('<') => { + // Opaque token such as `` — consume through `>`. + let start = self.pos; + while let Some(c) = self.peek() { + self.advance(c.len_utf8()); + if c == '>' { + break; + } + } + Node::Leaf { + range: start..self.pos, + } + } + Some(c) if c == '-' || c.is_ascii_digit() => self.parse_number(), + Some(_) => { + // identifier: could be keyword, enum variant, or struct name + if let Some(ident_range) = self.parse_ident() { + self.skip_whitespace(); + match self.peek() { + Some('(') => { + // enum variant with tuple payload — or `Some(…)` / `None` + let payload = self.parse_paren_payload(); + Node::EnumVariant { + name: ident_range, + payload: Some(Box::new(payload)), + } + } + Some('{') => { + // Named struct: `StructName { field: value, … }`. + // `parse_struct_body` returns `Node::Struct { name: None, … }`; + // we attach the name here. + let body = self.parse_struct_body(); + match body { + Node::Struct { fields, .. } => Node::Struct { + name: Some(ident_range), + fields, + }, + other => other, + } + } + _ => { + // bare keyword: `None`, `true`, `false`, etc. + Node::Leaf { range: ident_range } + } + } + } else { + // skip one unrecognised char and return an empty leaf + let start = self.pos; + if let Some(c) = self.peek() { + self.advance(c.len_utf8()); + } + Node::Leaf { + range: start..self.pos, + } + } + } + None => Node::Leaf { + range: self.pos..self.pos, + }, + } + } +} + +// ── Visitor helpers ─────────────────────────────────────────────────────────── + +/// Walk the CST returning all semantic token spans as `(range, token_type_index)` pairs. +/// +/// Token type indices match the legend declared in `server::TOKEN_TYPES`: +/// - 0 `type` — variant / struct names +/// - 1 `property` — struct field keys +/// - 2 `string` — string literals +/// - 3 `number` — numeric literals +/// - 4 `keyword` — `true`, `false`, `None`, `Some` +pub fn collect_tokens(node: &Node, src: &str) -> Vec<(Range, u32)> { + let mut out = Vec::new(); + collect_tokens_inner(node, src, &mut out); + out +} + +fn collect_tokens_inner(node: &Node, src: &str, out: &mut Vec<(Range, u32)>) { + match node { + Node::EnumVariant { name, payload } => { + out.push((name.clone(), 0)); // `type` + if let Some(p) = payload { + collect_tokens_inner(p, src, out); + } + } + Node::Struct { name, fields } => { + if let Some(n) = name { + out.push((n.clone(), 0)); // `type` + } + for f in fields { + out.push((f.key.clone(), 1)); // `property` + collect_tokens_inner(&f.value, src, out); + } + } + Node::Map { entries } => { + for (key, value) in entries { + // The key is typically a named struct; walk it to emit its name + // as a `type` token and its inner fields as `property` tokens. + collect_tokens_inner(key, src, out); + collect_tokens_inner(value, src, out); + } + } + Node::Sequence { items } => { + for item in items { + collect_tokens_inner(item, src, out); + } + } + Node::Leaf { range } => { + if range.is_empty() { + return; + } + let text = &src[range.clone()]; + let ty = match text { + "true" | "false" | "None" | "Some" => 4, // `keyword` + s if s.starts_with('"') => 2, // `string` + s if s.starts_with(|c: char| c.is_ascii_digit() || c == '-') => 3, // `number` + _ => 4, // treat unknown bare idents as keywords + }; + out.push((range.clone(), ty)); + } + } +} + +/// Walk the CST and collect byte-offset ranges of every `{`, `(`, or `[` … matching closer. +/// Used to produce folding ranges. +pub fn collect_fold_ranges(node: &Node, src: &str) -> Vec<(usize, usize)> { + let mut out = Vec::new(); + collect_fold_inner(node, src, &mut out); + out +} + +#[expect( + clippy::only_used_in_recursion, + reason = "src is forwarded to recursive calls for API symmetry with collect_tokens_inner" +)] +fn collect_fold_inner(node: &Node, src: &str, out: &mut Vec<(usize, usize)>) { + // We don't track brace positions directly in the CST; instead we record a range + // that spans a multi-token node by inspecting the text around its children. + // For simplicity we emit one fold per struct/sequence/variant-with-payload that + // contains at least one child. + match node { + Node::EnumVariant { + payload: Some(p), .. + } => { + collect_fold_inner(p, src, out); + } + Node::Struct { fields, .. } if !fields.is_empty() => { + // Estimate the byte range of this struct by first/last field + let first = fields.first().map(|f| f.key.start); + let last = fields.last().map(|f| node_end(&f.value)); + if let (Some(start), Some(end)) = (first, last) { + out.push((start, end)); + } + for f in fields { + collect_fold_inner(&f.value, src, out); + } + } + Node::Sequence { items } if items.len() > 1 => { + let first = items.first().map(node_start); + let last = items.last().map(node_end); + if let (Some(start), Some(end)) = (first, last) { + out.push((start, end)); + } + for item in items { + collect_fold_inner(item, src, out); + } + } + Node::Map { entries } if !entries.is_empty() => { + let first = entries.first().map(|(k, _)| node_start(k)); + let last = entries.last().map(|(_, v)| node_end(v)); + if let (Some(start), Some(end)) = (first, last) { + out.push((start, end)); + } + for (key, value) in entries { + collect_fold_inner(key, src, out); + collect_fold_inner(value, src, out); + } + } + _ => {} + } +} + +fn node_start(n: &Node) -> usize { + match n { + Node::EnumVariant { name, .. } => name.start, + Node::Struct { name: Some(r), .. } => r.start, + Node::Struct { fields, .. } => fields.first().map_or(0, |f| f.key.start), + Node::Map { entries } => entries.first().map_or(0, |(k, _)| node_start(k)), + Node::Sequence { items } => items.first().map_or(0, node_start), + Node::Leaf { range } => range.start, + } +} + +fn node_end(n: &Node) -> usize { + match n { + Node::EnumVariant { + payload: Some(p), .. + } => node_end(p), + Node::EnumVariant { name, .. } => name.end, + Node::Struct { fields, .. } => fields.last().map_or(0, |f| node_end(&f.value)), + Node::Map { entries } => entries.last().map_or(0, |(_, v)| node_end(v)), + Node::Sequence { items } => items.last().map_or(0, node_end), + Node::Leaf { range } => range.end, + } +} + +fn find_name_at<'s>(node: &Node, offset: usize, src: &'s str) -> Option<&'s str> { + match node { + Node::EnumVariant { name, payload } => { + if name.contains(&offset) { + return Some(&src[name.clone()]); + } + if let Some(p) = payload { + return find_name_at(p, offset, src); + } + None + } + Node::Struct { name, fields } => { + if let Some(n) = name { + if n.contains(&offset) { + return Some(&src[n.clone()]); + } + } + for f in fields { + if let Some(r) = find_name_at(&f.value, offset, src) { + return Some(r); + } + } + None + } + Node::Map { entries } => entries.iter().find_map(|(k, v)| { + find_name_at(k, offset, src).or_else(|| find_name_at(v, offset, src)) + }), + Node::Sequence { items } => items.iter().find_map(|i| find_name_at(i, offset, src)), + Node::Leaf { .. } => None, + } +} diff --git a/mongosql-lsp/src/server.rs b/mongosql-lsp/src/server.rs new file mode 100644 index 000000000..15d779a91 --- /dev/null +++ b/mongosql-lsp/src/server.rs @@ -0,0 +1,302 @@ +//! LSP backend for `.mir` and `.air` debug-tree files. + +use crop::Rope; +use dashmap::DashMap; +use tower_lsp::jsonrpc::Result; +use tower_lsp::lsp_types::{ + DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams, + DidSaveTextDocumentParams, FoldingRange, FoldingRangeParams, FoldingRangeProviderCapability, + Hover, HoverContents, HoverParams, HoverProviderCapability, InitializeParams, InitializeResult, + InitializedParams, MarkupContent, MarkupKind, MessageType, Position, SaveOptions, + SemanticToken, SemanticTokenType, SemanticTokens, SemanticTokensFullOptions, + SemanticTokensLegend, SemanticTokensOptions, SemanticTokensParams, SemanticTokensResult, + SemanticTokensServerCapabilities, ServerCapabilities, TextDocumentSyncCapability, + TextDocumentSyncKind, TextDocumentSyncOptions, TextDocumentSyncSaveOptions, Url, +}; +use tower_lsp::{Client, LanguageServer}; + +use crate::hover::HOVER_MAP; +use crate::parser::{self, ParsedDoc}; + +// ── Semantic token legend ───────────────────────────────────────────────────── + +/// Token type names registered in `initialize`. +/// +/// Index mapping: +/// - 0 `type` — enum-variant / struct names +/// - 1 `property` — struct field keys +/// - 2 `string` — string literals +/// - 3 `number` — numeric literals +/// - 4 `keyword` — `true`, `false`, `None`, `Some` +pub const TOKEN_TYPES: &[&str] = &["type", "property", "string", "number", "keyword"]; + +// ── Backend ─────────────────────────────────────────────────────────────────── + +/// LSP backend state. +pub struct Backend { + client: Client, + /// Stores rope-encoded document text, keyed by URI string. + document_map: DashMap, + /// Stores the parsed CST for each open document. + ast_map: DashMap, +} + +impl Backend { + /// Creates a new backend bound to `client`. + pub fn new(client: Client) -> Self { + Self { + client, + document_map: DashMap::new(), + ast_map: DashMap::new(), + } + } + + /// Re-parse the document whenever it is opened or changed. + fn on_change(&self, uri: &Url, text: &str) { + let rope = Rope::from(text); + let doc = parser::parse(text); + self.document_map.insert(uri.to_string(), rope); + self.ast_map.insert(uri.to_string(), doc); + } +} + +// ── LanguageServer impl ─────────────────────────────────────────────────────── + +#[tower_lsp::async_trait] +impl LanguageServer for Backend { + async fn initialize(&self, _: InitializeParams) -> Result { + Ok(InitializeResult { + capabilities: ServerCapabilities { + text_document_sync: Some(TextDocumentSyncCapability::Options( + TextDocumentSyncOptions { + open_close: Some(true), + change: Some(TextDocumentSyncKind::FULL), + save: Some(TextDocumentSyncSaveOptions::SaveOptions(SaveOptions { + include_text: Some(true), + })), + ..Default::default() + }, + )), + semantic_tokens_provider: Some( + SemanticTokensServerCapabilities::SemanticTokensOptions( + SemanticTokensOptions { + legend: SemanticTokensLegend { + token_types: TOKEN_TYPES + .iter() + .map(|s| SemanticTokenType::new(s)) + .collect(), + token_modifiers: vec![], + }, + full: Some(SemanticTokensFullOptions::Bool(true)), + range: None, + ..Default::default() + }, + ), + ), + folding_range_provider: Some(FoldingRangeProviderCapability::Simple(true)), + hover_provider: Some(HoverProviderCapability::Simple(true)), + ..Default::default() + }, + ..Default::default() + }) + } + + async fn initialized(&self, _: InitializedParams) { + self.client + .log_message(MessageType::INFO, "mongosql-lsp initialized") + .await; + } + + async fn shutdown(&self) -> Result<()> { + Ok(()) + } + + async fn did_open(&self, params: DidOpenTextDocumentParams) { + self.on_change(¶ms.text_document.uri, ¶ms.text_document.text); + } + + async fn did_change(&self, params: DidChangeTextDocumentParams) { + // Full-sync mode: there is always exactly one content change. + if let Some(change) = params.content_changes.into_iter().next() { + self.on_change(¶ms.text_document.uri, &change.text); + } + } + + async fn did_save(&self, params: DidSaveTextDocumentParams) { + if let Some(text) = params.text { + self.on_change(¶ms.text_document.uri, &text); + } + } + + async fn did_close(&self, params: DidCloseTextDocumentParams) { + let uri = params.text_document.uri.to_string(); + self.document_map.remove(&uri); + self.ast_map.remove(&uri); + } + + async fn semantic_tokens_full( + &self, + params: SemanticTokensParams, + ) -> Result> { + let uri = params.text_document.uri.to_string(); + let tokens = (|| -> Option> { + let rope = self.document_map.get(&uri)?; + let doc = self.ast_map.get(&uri)?; + let spans = parser::collect_tokens(&doc.root, &doc.text); + Some(encode_tokens(&spans, &rope, &doc.text)) + })(); + + Ok(tokens.map(|data| { + SemanticTokensResult::Tokens(SemanticTokens { + result_id: None, + data, + }) + })) + } + + async fn folding_range(&self, params: FoldingRangeParams) -> Result>> { + let uri = params.text_document.uri.to_string(); + let ranges = (|| -> Option> { + let rope = self.document_map.get(&uri)?; + let doc = self.ast_map.get(&uri)?; + let spans = parser::collect_fold_ranges(&doc.root, &doc.text); + let mut result = Vec::new(); + for (start_byte, end_byte) in spans { + let start_pos = byte_to_position(start_byte, &rope, &doc.text); + let end_pos = byte_to_position(end_byte, &rope, &doc.text); + // Only emit folding ranges that span at least two lines. + if end_pos.line > start_pos.line { + result.push(FoldingRange { + start_line: start_pos.line, + start_character: Some(start_pos.character), + end_line: end_pos.line, + end_character: Some(end_pos.character), + kind: None, + collapsed_text: None, + }); + } + } + Some(result) + })(); + Ok(ranges) + } + + async fn hover(&self, params: HoverParams) -> Result> { + let uri = params + .text_document_position_params + .text_document + .uri + .to_string(); + let pos = params.text_document_position_params.position; + + // Inline IIFE so `?` works cleanly inside. + let hover = (|| -> Option { + let rope = self.document_map.get(&uri)?; + let doc = self.ast_map.get(&uri)?; + let offset = position_to_byte_offset(pos, &rope, &doc.text); + let name = doc.node_name_at(offset)?; + let desc = HOVER_MAP.get(name)?; + Some(Hover { + contents: HoverContents::Markup(MarkupContent { + kind: MarkupKind::Markdown, + value: (*desc).to_owned(), + }), + range: None, + }) + })(); + Ok(hover) + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +/// Convert a byte offset into a `(line, character)` LSP `Position`. +#[expect( + clippy::cast_possible_truncation, + reason = "LSP positions are u32; documents in practice never approach 4 GiB" +)] +fn byte_to_position(offset: usize, rope: &Rope, src: &str) -> Position { + let offset = offset.min(src.len()); + let text_before = &src[..offset]; + let line = text_before.bytes().filter(|&b| b == b'\n').count() as u32; + let last_nl = text_before.rfind('\n').map_or(0, |i| i + 1); + let col_str = &text_before[last_nl..]; + // LSP character offset is in UTF-16 code units. + let character = col_str.encode_utf16().count() as u32; + // suppress unused warning for rope param (kept for API symmetry) + let _ = rope; + Position { line, character } +} + +/// Convert an LSP `Position` into a byte offset in the source text. +#[expect( + clippy::cast_possible_truncation, + reason = "LSP positions are u32; documents in practice never approach 4 GiB" +)] +fn position_to_byte_offset(pos: Position, _rope: &Rope, src: &str) -> usize { + let mut line = 0u32; + let mut byte = 0usize; + for ch in src.chars() { + if line == pos.line { + break; + } + if ch == '\n' { + line += 1; + } + byte += ch.len_utf8(); + } + // Now advance `character` UTF-16 code units within the line. + let mut col = 0u32; + for ch in src[byte..].chars() { + if col >= pos.character || ch == '\n' { + break; + } + col += ch.len_utf16() as u32; + byte += ch.len_utf8(); + } + byte.min(src.len()) +} + +/// Encode semantic token spans as LSP delta-encoded `SemanticToken` entries. +#[expect( + clippy::cast_possible_truncation, + reason = "LSP positions are u32; documents in practice never approach 4 GiB" +)] +fn encode_tokens( + spans: &[(std::ops::Range, u32)], + rope: &Rope, + src: &str, +) -> Vec { + let mut out = Vec::with_capacity(spans.len()); + let mut prev_line = 0u32; + let mut prev_start = 0u32; + + for (range, token_type) in spans { + if range.is_empty() { + continue; + } + let start_pos = byte_to_position(range.start, rope, src); + let end_offset = range.end.min(src.len()); + let length = src[range.start..end_offset].encode_utf16().count() as u32; + + let delta_line = start_pos.line - prev_line; + let delta_start = if delta_line == 0 { + start_pos.character - prev_start + } else { + start_pos.character + }; + + out.push(SemanticToken { + delta_line, + delta_start, + length, + token_type: *token_type, + token_modifiers_bitset: 0, + }); + + prev_line = start_pos.line; + prev_start = start_pos.character; + } + + out +} diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 000000000..8b262167a --- /dev/null +++ b/package-lock.json @@ -0,0 +1,391 @@ +{ + "name": "mongosql-lsp", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "mongosql-lsp", + "version": "0.1.0", + "hasInstallScript": true, + "devDependencies": { + "rolldown": "^1.0.0-beta.7", + "typescript": "^5.4.0" + }, + "engines": { + "vscode": "^1.75.0" + } + }, + "node_modules/@emnapi/wasi-threads": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz", + "integrity": "sha512-uTII7OYF+/Mes/MrcIOYp5yOtSMLBWSIoLPpcgwipoiKbli6k322tcoFsxoIIxPDqW01SQGAgko4EzZi2BNv2w==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/@napi-rs/wasm-runtime": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-1.1.4.tgz", + "integrity": "sha512-3NQNNgA1YSlJb/kMH1ildASP9HW7/7kYnRI2szWJaofaS1hWmbGI4H+d3+22aGzXXN9IJ+n+GiFVcGipJP18ow==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@tybys/wasm-util": "^0.10.1" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + }, + "peerDependencies": { + "@emnapi/core": "^1.7.1", + "@emnapi/runtime": "^1.7.1" + } + }, + "node_modules/@oxc-project/types": { + "version": "0.129.0", + "resolved": "https://registry.npmjs.org/@oxc-project/types/-/types-0.129.0.tgz", + "integrity": "sha512-3oz8m3FGdr2nDXVqmFUw7jolKliC4MoyXYIG2c7gpjBnzUWQpUGIYcXYKxTdTi+N2jusvt610ckTMkxdwHkYEg==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/Boshen" + } + }, + "node_modules/@rolldown/binding-android-arm64": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-android-arm64/-/binding-android-arm64-1.0.0.tgz", + "integrity": "sha512-TWMZnRLMe63C2Lhyicviu7ZHaU4kxa6PS3rofvc9GmcvptzNN11BcfQ4Sl7MwTOsisQoa2keB/EBdNCAnUo8vA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-darwin-arm64": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-darwin-arm64/-/binding-darwin-arm64-1.0.0.tgz", + "integrity": "sha512-6XcD+8k0gPVItNagEw78/qqcBDwKcwDYS8V2hRmVsfUSIrd8cWe/CBvRDI5toqFyPfj+FJr6t8U6Xj2P2prEew==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-darwin-x64": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-darwin-x64/-/binding-darwin-x64-1.0.0.tgz", + "integrity": "sha512-iN/tWVXRQDWvmZlKdceP1Dwug9GDpEymhb9p4xnEe6zvCg5lFmzVljl+1qR1NVx3yfGpr2Na+CuLmv5IU8uzfQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-freebsd-x64": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-freebsd-x64/-/binding-freebsd-x64-1.0.0.tgz", + "integrity": "sha512-jjQMDvvwSOuhOwMszD/klSOjyWMM3zI64hWTj9KT5x4MxRbZAf+7vLQ6qouRhtsLVFHr3f0ILaJAfgENPiQdAQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-arm-gnueabihf": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm-gnueabihf/-/binding-linux-arm-gnueabihf-1.0.0.tgz", + "integrity": "sha512-d//Dtg2x6/m3mbV64yUGNnDGNZaDGRpDLLNGerHQUVObuNaIQaaDp25yUiqGXtHEXX+NP2d0wAlmKgpYgIAJ2A==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-arm64-gnu": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm64-gnu/-/binding-linux-arm64-gnu-1.0.0.tgz", + "integrity": "sha512-n7Ofp0mx+aB2cC+Sdy5YtMnXtY9lchnHbY+3Yt0uq9JsWQExf4f5Whu0tK0R8Jdc9S6RchTHjIFY7uc92puOVQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-arm64-musl": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-arm64-musl/-/binding-linux-arm64-musl-1.0.0.tgz", + "integrity": "sha512-EIVjy2cgd7uuMMo94FVkBp7F6DhcZAUwNURkSG3RwUmvAXR6s0ISxM81U+IydcZByPG0pZIHsf1b6kTxoFDgJA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-ppc64-gnu": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-ppc64-gnu/-/binding-linux-ppc64-gnu-1.0.0.tgz", + "integrity": "sha512-JEwwOPcwTLAcpDQlqSmjEmfs63xJnSiUNIGvLcDLUHCWK4XowpS/7c7tUsUH6uT/ct6bMUTdXKfI8967FYj6mg==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-s390x-gnu": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-s390x-gnu/-/binding-linux-s390x-gnu-1.0.0.tgz", + "integrity": "sha512-0wjCFhLrihtAubnT9iA0N++0pSV0z5Hg7tNGdNJ4RFaINceHadoF+kiFGyY1qSSNVIAZtLotG8Ju1bgDPkjnFA==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-x64-gnu": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-x64-gnu/-/binding-linux-x64-gnu-1.0.0.tgz", + "integrity": "sha512-Dfn7iak9BcMMePxcoJfpSbWqnEyrp/dRF63/8qW/eHBdOZov6x5aShLLEYGYdIeSJ6vMLK/XCVB+lGIxm41bQA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-linux-x64-musl": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-linux-x64-musl/-/binding-linux-x64-musl-1.0.0.tgz", + "integrity": "sha512-5/utzzDmD/pD/bmuaUcbTf/sZYy0aztwIVlfpoW1fTjCZ0BaPOMVWGZL1zvgxyi7ZIVYWlxKONHmSbHuiOh8Jw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-openharmony-arm64": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-openharmony-arm64/-/binding-openharmony-arm64-1.0.0.tgz", + "integrity": "sha512-ouJs8VcUomfLfpbUECqFMRqdV4x6aeAK3MA4m6vTrJJjKyWTV5KnxZx7Jd9G+GlDaQQxubcba00x16OyJ1meig==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-wasm32-wasi": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-wasm32-wasi/-/binding-wasm32-wasi-1.0.0.tgz", + "integrity": "sha512-E+oHKGiDA+lsKMmFtffDDw91EryDT7uJocrIuCHqhm6bCTM6xFK+3gaCkYOHfPwQr0cCNarSM2xaELoQDz9jJg==", + "cpu": [ + "wasm32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@emnapi/core": "1.10.0", + "@emnapi/runtime": "1.10.0", + "@napi-rs/wasm-runtime": "^1.1.4" + }, + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-win32-arm64-msvc": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-win32-arm64-msvc/-/binding-win32-arm64-msvc-1.0.0.tgz", + "integrity": "sha512-yYK02n8Rngo+gbm1y6G0+7jk1sJ/2Wt7K0me0Y7k/ErBpyf+LJ2gFpqWVTcRV1rUepBlQRmpgWkTQCiiwrK0Ow==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/binding-win32-x64-msvc": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/binding-win32-x64-msvc/-/binding-win32-x64-msvc-1.0.0.tgz", + "integrity": "sha512-14bpChMahXRRXiTwahSl+zzHPW6qQTXtkMuJBFlbo+pqSAews2d4BdCSHfrJ/MBsCZtpmTafsY+1QhBzitcmdg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^20.19.0 || >=22.12.0" + } + }, + "node_modules/@rolldown/pluginutils": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0.tgz", + "integrity": "sha512-aKs/3GSWyV0mrhNmt/96/Z3yczC3yvrzYATCiCXQebBsGyYzjNdUphRVLeJQ67ySKVXRfMxt2lm12pmXvbPFQQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@tybys/wasm-util": { + "version": "0.10.2", + "resolved": "https://registry.npmjs.org/@tybys/wasm-util/-/wasm-util-0.10.2.tgz", + "integrity": "sha512-RoBvJ2X0wuKlWFIjrwffGw1IqZHKQqzIchKaadZZfnNpsAYp2mM0h36JtPCjNDAHGgYez/15uMBpfGwchhiMgg==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/rolldown": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/rolldown/-/rolldown-1.0.0.tgz", + "integrity": "sha512-yD986aXDESFGS95spT1LAv0jssywP4npMEjmMHyN2/5+eE8qQJUype2AaKkRiLgBgyD0LFlubwAht7VmY8rGoA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@oxc-project/types": "=0.129.0", + "@rolldown/pluginutils": "1.0.0" + }, + "bin": { + "rolldown": "bin/cli.mjs" + }, + "engines": { + "node": "^20.19.0 || >=22.12.0" + }, + "optionalDependencies": { + "@rolldown/binding-android-arm64": "1.0.0", + "@rolldown/binding-darwin-arm64": "1.0.0", + "@rolldown/binding-darwin-x64": "1.0.0", + "@rolldown/binding-freebsd-x64": "1.0.0", + "@rolldown/binding-linux-arm-gnueabihf": "1.0.0", + "@rolldown/binding-linux-arm64-gnu": "1.0.0", + "@rolldown/binding-linux-arm64-musl": "1.0.0", + "@rolldown/binding-linux-ppc64-gnu": "1.0.0", + "@rolldown/binding-linux-s390x-gnu": "1.0.0", + "@rolldown/binding-linux-x64-gnu": "1.0.0", + "@rolldown/binding-linux-x64-musl": "1.0.0", + "@rolldown/binding-openharmony-arm64": "1.0.0", + "@rolldown/binding-wasm32-wasi": "1.0.0", + "@rolldown/binding-win32-arm64-msvc": "1.0.0", + "@rolldown/binding-win32-x64-msvc": "1.0.0" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD", + "optional": true + }, + "node_modules/typescript": { + "version": "5.9.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", + "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + } + } +} diff --git a/package.json b/package.json new file mode 100644 index 000000000..40688bfb8 --- /dev/null +++ b/package.json @@ -0,0 +1,34 @@ +{ + "name": "mongosql-lsp", + "displayName": "MongoSQL LSP", + "description": "Language server for MongoSQL MIR and AIR debug-tree files", + "version": "0.1.0", + "engines": { "vscode": "^1.75.0" }, + "main": "./dist/extension.js", + "activationEvents": [ + "onLanguage:mongosql-mir", + "onLanguage:mongosql-air" + ], + "contributes": { + "languages": [ + { + "id": "mongosql-mir", + "extensions": [".mir"], + "aliases": ["MongoSQL MIR"] + }, + { + "id": "mongosql-air", + "extensions": [".air"], + "aliases": ["MongoSQL AIR"] + } + ] + }, + "scripts": { + "postinstall": "cd client && npm i", + "build": "rolldown -c" + }, + "devDependencies": { + "rolldown": "^1.0.0-beta.7", + "typescript": "^5.4.0" + } +} diff --git a/rolldown.config.mjs b/rolldown.config.mjs new file mode 100644 index 000000000..f6c7f3740 --- /dev/null +++ b/rolldown.config.mjs @@ -0,0 +1,10 @@ +import { defineConfig } from "rolldown"; + +export default defineConfig({ + input: "client/src/extension.ts", + external: ["vscode"], + output: { + file: "dist/extension.js", + format: "cjs", + }, +}); diff --git a/test_output.mir b/test_output.mir new file mode 100644 index 000000000..fedc97f2a --- /dev/null +++ b/test_output.mir @@ -0,0 +1,69 @@ +Project( + Project { + source: Project( + Project { + source: Collection( + Collection { + db: "foo", + collection: "bar", + cache: , + }, + ), + expression: BindingTuple( + { + Key { + datasource: Named( + "bar", + ), + scope: 0, + }: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "bar", + ), + scope: 0, + }, + }, + ), + }, + ), + is_add_fields: false, + cache: , + }, + ), + expression: BindingTuple( + { + Key { + datasource: Bottom, + scope: 0, + }: Document( + DocumentExpr { + document: UniqueLinkedHashMap( + { + "a": FieldAccess( + FieldAccess { + expr: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "bar", + ), + scope: 0, + }, + }, + ), + field: "a", + is_nullable: true, + }, + ), + }, + ), + }, + ), + }, + ), + is_add_fields: false, + cache: , + }, +) diff --git a/test_where.mir b/test_where.mir new file mode 100644 index 000000000..78a3e5b22 --- /dev/null +++ b/test_where.mir @@ -0,0 +1,144 @@ +[WARNING] No schema information was found for the requested collections `["users"]` in database `mydb`. Either the collections don't exist in `mydb` or they don't have a schema. For now, they will be assigned empty schemas. Hint: You either need to generate schemas for your collections or correct your query. +Project( + Project { + source: Project( + Project { + source: Filter( + Filter { + source: Collection( + Collection { + db: "mydb", + collection: "users", + cache: , + }, + ), + condition: ScalarFunction( + ScalarFunctionApplication { + function: And, + args: [ + MqlIntrinsicFieldExistence( + FieldAccess { + expr: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "users", + ), + scope: 0, + }, + }, + ), + field: "age", + is_nullable: true, + }, + ), + ScalarFunction( + ScalarFunctionApplication { + function: Gt, + args: [ + FieldAccess( + FieldAccess { + expr: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "users", + ), + scope: 0, + }, + }, + ), + field: "age", + is_nullable: false, + }, + ), + Literal( + Integer( + 30, + ), + ), + ], + is_nullable: false, + }, + ), + ], + is_nullable: false, + }, + ), + cache: , + }, + ), + expression: BindingTuple( + { + Key { + datasource: Named( + "users", + ), + scope: 0, + }: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "users", + ), + scope: 0, + }, + }, + ), + }, + ), + is_add_fields: false, + cache: , + }, + ), + expression: BindingTuple( + { + Key { + datasource: Bottom, + scope: 0, + }: Document( + DocumentExpr { + document: UniqueLinkedHashMap( + { + "name": FieldAccess( + FieldAccess { + expr: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "users", + ), + scope: 0, + }, + }, + ), + field: "name", + is_nullable: true, + }, + ), + "age": FieldAccess( + FieldAccess { + expr: Reference( + ReferenceExpr { + key: Key { + datasource: Named( + "users", + ), + scope: 0, + }, + }, + ), + field: "age", + is_nullable: true, + }, + ), + }, + ), + }, + ), + }, + ), + is_add_fields: false, + cache: , + }, +) diff --git a/tsconfig.json b/tsconfig.json new file mode 100644 index 000000000..9c73b5b1d --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "ES2020", + "lib": ["ES2020"], + "strict": true, + "outDir": "dist", + "rootDir": "client/src", + "esModuleInterop": true, + "skipLibCheck": true + }, + "include": ["client/src/**/*.ts"], + "exclude": ["node_modules", "dist"] +} From 514197f2c4d31a6c0b55fffeae69e24359e3cdc1 Mon Sep 17 00:00:00 2001 From: Jonathan Powell Date: Tue, 12 May 2026 11:39:56 -0400 Subject: [PATCH 4/5] Add a --audit-trail flag --- Cargo.lock | 257 ++++++++++++++++++++++++++ Cargo.toml | 1 + README.md | 40 +++- mongosql-cli/Cargo.toml | 4 + mongosql-cli/src/main.rs | 383 ++++++++++++++++++++++++++++++++++++++- 5 files changed, 679 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 487910444..aefe86952 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,17 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures 0.2.17", +] + [[package]] name = "agg-ast" version = "0.0.0" @@ -171,6 +182,15 @@ dependencies = [ "object 0.32.2", ] +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -398,12 +418,37 @@ version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "bzip2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" +dependencies = [ + "bzip2-sys", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.13+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "cast" version = "0.3.0" @@ -417,6 +462,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -483,6 +530,16 @@ dependencies = [ "half 2.7.1", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "2.34.0" @@ -580,6 +637,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "convert_case" version = "0.10.0" @@ -641,6 +704,30 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "217698eaf96b4a3f0bc4f3662aaa55bdf913cd54d7204591faa790070c6d0853" + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "create-tasks" version = "0.1.0" @@ -889,6 +976,12 @@ dependencies = [ "uuid 1.19.0", ] +[[package]] +name = "deflate64" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac6b926516df9c60bfa16e107b21086399f8285a44ca9711344b9e553c5146e2" + [[package]] name = "deranged" version = "0.5.5" @@ -931,6 +1024,17 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "derive_more" version = "2.1.1" @@ -1147,6 +1251,16 @@ version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1865,6 +1979,15 @@ dependencies = [ "str_stack", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.13" @@ -1989,6 +2112,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.85" @@ -2107,6 +2240,27 @@ dependencies = [ "url", ] +[[package]] +name = "lzma-rs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297e814c836ae64db86b36cf2a557ba54368d03f6afcd7d947c266692f71115e" +dependencies = [ + "byteorder", + "crc", +] + +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "macro_magic" version = "0.5.1" @@ -2226,6 +2380,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -2391,6 +2546,8 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "tempfile", + "zip", ] [[package]] @@ -2739,6 +2896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" dependencies = [ "digest", + "hmac", ] [[package]] @@ -3835,6 +3993,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + [[package]] name = "siphasher" version = "1.0.1" @@ -5468,6 +5632,15 @@ dependencies = [ "tap", ] +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + [[package]] name = "yoke" version = "0.8.1" @@ -5537,6 +5710,20 @@ name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] [[package]] name = "zerotrie" @@ -5571,8 +5758,78 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "zip" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50" +dependencies = [ + "aes", + "arbitrary", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "deflate64", + "displaydoc", + "flate2", + "getrandom 0.3.4", + "hmac", + "indexmap 2.13.0", + "lzma-rs", + "memchr", + "pbkdf2", + "sha1", + "thiserror 2.0.17", + "time", + "xz2", + "zeroize", + "zopfli", + "zstd", +] + [[package]] name = "zmij" version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" + +[[package]] +name = "zopfli" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05cd8797d63865425ff89b5c4a48804f35ba0ce8d125800027ad6017d2b5249" +dependencies = [ + "bumpalo", + "crc32fast", + "log", + "simd-adler32", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 3efc0cedb..354037f94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ schema_derivation = { path = "agg-ast/schema_derivation" } base64 = "0.22.1" bson = "2" tailcall = "1" +zip = "2" [profile.release] lto = "fat" diff --git a/README.md b/README.md index 54bf1f304..c97428c8e 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,39 @@ Available stages (in pipeline order): > **Note:** `--execute` is only valid with `--stage mql` or when `--stage` is omitted. +**Capture an audit trail of all translation stages:** + +Use `--audit-trail` to write every intermediate representation produced during translation into `audit_trail.zip` in the current working directory. Extracting the zip yields an `audit_trail/` folder: + +| File | Contents | +|------|----------| +| `initial_query.sql` | The original SQL query, verbatim. Always present. | +| `query.ast` | Rewritten AST (present when stage ≥ `ast`). | +| `query.mir` | Optimized MIR tree (present when stage ≥ `mir`). | +| `query.air` | Desugared AIR tree (present when stage ≥ `air`). | +| `pipeline.js` | Generated MQL aggregation pipeline as JSON (present when stage = `mql`). | + +The set of files included depends on `--stage` (defaults to all stages). Normal stdout output is unchanged — the CLI still prints the stage representation to stdout as usual. + +```bash +# Capture all stages (default) using a local schema file +./target/debug/mongosql-cli --db mydb --schema-file schema.yaml --audit-trail "SELECT name FROM users" + +# Capture only up to MIR +./target/debug/mongosql-cli --db mydb --schema-file schema.yaml --audit-trail --stage mir "SELECT name FROM users" +``` + +After running, extract the zip to inspect the intermediate representations: + +```bash +unzip audit_trail.zip +# → audit_trail/initial_query.sql +# → audit_trail/query.ast +# → audit_trail/query.mir +# → audit_trail/query.air +# → audit_trail/pipeline.js +``` + ### Schema Files When `--schema-file` is provided, the CLI reads collection schemas from a local file. @@ -164,7 +197,12 @@ npm run build ### Launching in VS Code 1. Build the server: `cargo build -p mongosql-lsp` -2. Press **F5** in VS Code — this runs the **Launch Client** configuration in `.vscode/launch.json`, which opens an Extension Development Host window with `SERVER_PATH` pointed at `target/debug/mongosql-lsp`. +2. Install extension dependencies and build the bundle: + ```bash + npm install + npm run build + ``` +3. Press **F5** in VS Code — this runs the **Launch Client** configuration in `.vscode/launch.json`, which opens an Extension Development Host window with `SERVER_PATH` pointed at `target/debug/mongosql-lsp`. 3. Open any `.mir` or `.air` file to get: - **Syntax highlighting** — enum variant names, struct field keys, string/number literals, and keywords each in a distinct colour. - **Code folding** — fold struct bodies and arrays with the editor's fold shortcut. diff --git a/mongosql-cli/Cargo.toml b/mongosql-cli/Cargo.toml index 9b88b35af..0299ff41a 100644 --- a/mongosql-cli/Cargo.toml +++ b/mongosql-cli/Cargo.toml @@ -12,3 +12,7 @@ serde_json = { "workspace" = true } serde_yaml = { "workspace" = true } serde = { "workspace" = true, features = ["derive"] } clap = { "workspace" = true } +zip = { workspace = true } + +[dev-dependencies] +tempfile = "3" diff --git a/mongosql-cli/src/main.rs b/mongosql-cli/src/main.rs index 278323e4b..f82b032aa 100644 --- a/mongosql-cli/src/main.rs +++ b/mongosql-cli/src/main.rs @@ -40,7 +40,7 @@ enum TranslationCheckpoint { } #[derive(Parser, Debug)] -#[command(version, about, long_about=None)] +#[command(version, about, long_about = None, arg_required_else_help = true)] struct Cli { #[arg( short, @@ -82,6 +82,17 @@ struct Cli { help = "Stop at the specified compilation stage and print its intermediate representation. When omitted the CLI behaves as normal." )] stage: Option, + #[arg( + long, + default_value_t = false, + help = "Write intermediate representations for each completed translation stage \ + into audit_trail.zip in the current working directory. \ + Extracting the zip produces an audit_trail/ folder containing: \ + initial_query.sql, and whichever of query.ast, query.mir, query.air, \ + pipeline.js were reached. Which stages are included depends on --stage \ + (defaults to all stages). Normal stdout output is unchanged." + )] + audit_trail: bool, } #[derive(Debug, Serialize, Deserialize)] @@ -90,6 +101,18 @@ pub struct SchemaFile { pub schemas: BTreeMap>, } +/// Collected intermediate representations for the audit trail. +/// +/// Fields are `None` for stages not reached given the requested checkpoint. +struct TranslationStages { + ast_repr: Option, + mir_repr: Option, + air_repr: Option, + pipeline_json: Option, + /// Only `Some` when stage is `Mql`; needed for the `--execute` path. + translation: Option, +} + fn parse_query_from_args( query: Option, sql_file: Option, @@ -124,10 +147,10 @@ fn build_catalog( let extension = path .extension() .and_then(|ext| ext.to_str()) - .map(|ext| ext.to_lowercase()); + .map(str::to_lowercase); let catalog: SchemaFile = match extension.as_deref() { - Some("yaml") | Some("yml") => serde_yaml::from_str(&contents)?, + Some("yaml" | "yml") => serde_yaml::from_str(&contents)?, Some("json") => serde_json::from_str(&contents)?, _ => { return Err(CliError(format!( @@ -141,6 +164,187 @@ fn build_catalog( } } +/// Runs translation stages sequentially up to `stage`, collecting each +/// intermediate representation. +/// +/// The catalog is only built when `stage` is `Mir`, `Air`, or `Mql`. +/// +/// # Errors +/// Returns `CliError` if any translation stage or catalog build fails. +fn run_stages_up_to( + stage: TranslationCheckpoint, + current_db: &str, + query: &str, + uri: &str, + schema_file: Option, +) -> Result { + let ast_repr = Some(mongosql::translate_sql_to_ast_repr(query)?); + + if matches!(stage, TranslationCheckpoint::Ast) { + return Ok(TranslationStages { + ast_repr, + mir_repr: None, + air_repr: None, + pipeline_json: None, + translation: None, + }); + } + + let namespaces = mongosql::get_namespaces(current_db, query)?; + let catalog = build_catalog(uri, current_db, namespaces, schema_file)?; + let options = mongosql::options::SqlOptions { + allow_order_by_missing_columns: true, + ..Default::default() + }; + + let mir_repr = Some(mongosql::translate_sql_to_mir_repr( + current_db, query, &catalog, options, + )?); + + if matches!(stage, TranslationCheckpoint::Mir) { + return Ok(TranslationStages { + ast_repr, + mir_repr, + air_repr: None, + pipeline_json: None, + translation: None, + }); + } + + let air_repr = Some(mongosql::translate_sql_to_air_repr( + current_db, query, &catalog, options, + )?); + + if matches!(stage, TranslationCheckpoint::Air) { + return Ok(TranslationStages { + ast_repr, + mir_repr, + air_repr, + pipeline_json: None, + translation: None, + }); + } + + // Mql + let translation = mongosql::translate_sql(current_db, query, &catalog, options)?; + let pipeline_json = + serde_json::to_string_pretty(&translation.pipeline).map_err(|e| CliError(e.to_string()))?; + + Ok(TranslationStages { + ast_repr, + mir_repr, + air_repr, + pipeline_json: Some(pipeline_json), + translation: Some(translation), + }) +} + +/// Writes intermediate translation representations into `audit_trail.zip` +/// in the current working directory, overwriting any existing file. +/// +/// `initial_query.sql` is always written. Remaining files are written only +/// for stages that were reached (i.e. their field is `Some`). +/// +/// # Errors +/// Returns `CliError` if the zip file cannot be created or written. +fn write_audit_trail(query: &str, stages: &TranslationStages) -> Result<(), CliError> { + use std::io::Write as _; + use zip::write::FileOptions; + use zip::CompressionMethod; + + let file = std::fs::File::create("audit_trail.zip")?; + let mut zip = zip::write::ZipWriter::new(file); + let options = FileOptions::<()>::default().compression_method(CompressionMethod::Deflated); + + zip.start_file("audit_trail/initial_query.sql", options)?; + zip.write_all(query.as_bytes())?; + + if let Some(ast) = &stages.ast_repr { + zip.start_file("audit_trail/query.ast", options)?; + zip.write_all(ast.as_bytes())?; + } + if let Some(mir) = &stages.mir_repr { + zip.start_file("audit_trail/query.mir", options)?; + zip.write_all(mir.as_bytes())?; + } + if let Some(air) = &stages.air_repr { + zip.start_file("audit_trail/query.air", options)?; + zip.write_all(air.as_bytes())?; + } + if let Some(pipeline) = &stages.pipeline_json { + zip.start_file("audit_trail/pipeline.js", options)?; + zip.write_all(pipeline.as_bytes())?; + } + + zip.finish()?; + Ok(()) +} + +/// Runs all translation stages up to `stage`, writes `audit_trail.zip`, prints +/// the stage output to stdout, and optionally executes the query. +/// +/// Stdout output is identical to running `--stage` alone. +/// +/// # Errors +/// Returns `CliError` if any translation stage, zip write, or query execution fails. +fn handle_audit_trail( + stage: TranslationCheckpoint, + current_db: &str, + query: &str, + uri: &str, + schema_file: Option, + execute: bool, +) -> Result<(), CliError> { + let stages = run_stages_up_to(stage, current_db, query, uri, schema_file)?; + write_audit_trail(query, &stages)?; + eprintln!("audit_trail.zip written to current directory."); + + match stage { + TranslationCheckpoint::Ast => { + println!("{}", stages.ast_repr.expect("ast computed for Ast stage")); + } + TranslationCheckpoint::Mir => { + println!("{}", stages.mir_repr.expect("mir computed for Mir stage")); + } + TranslationCheckpoint::Air => { + println!("{}", stages.air_repr.expect("air computed for Air stage")); + } + TranslationCheckpoint::Mql => { + let translation = stages + .translation + .as_ref() + .expect("translation computed for Mql stage"); + let schema = serde_json::to_string_pretty(&translation.result_set_schema) + .map_err(|e| CliError(e.to_string()))?; + println!( + "target_db: {},\ntarget_collection: {:?},\nresult set schema:\n{}\npipeline:\n[", + translation.target_db, translation.target_collection, schema + ); + let bson::Bson::Array(ref pipeline) = translation.pipeline else { + return Err(CliError("pipeline is not an array".to_string())); + }; + for doc in pipeline { + println!(" {doc},"); + } + println!("]"); + } + } + + if execute { + run_query_and_display_results( + uri, + stages + .translation + .expect("translation computed for Mql stage"), + )?; + } + Ok(()) +} + +#[expect( + clippy::too_many_lines, + reason = "each TranslationCheckpoint arm repeats catalog/options setup; extracting further would obscure the control flow" +)] fn main() -> Result<(), CliError> { let args = Cli::parse(); @@ -155,6 +359,17 @@ fn main() -> Result<(), CliError> { )); } + if args.audit_trail { + return handle_audit_trail( + stage, + current_db.as_str(), + query.as_str(), + uri.as_str(), + args.schema_file, + args.execute, + ); + } + match stage { TranslationCheckpoint::Ast => { let output = mongosql::translate_sql_to_ast_repr(query.as_str())?; @@ -258,7 +473,7 @@ fn run_query_and_display_results( }; let pipeline = pipeline .into_iter() - .map(|doc| doc.as_document().map(|doc| doc.to_owned())) + .map(|doc| doc.as_document().map(std::borrow::ToOwned::to_owned)) .collect::>>() .ok_or_else(|| CliError("Pipeline contains non-Document!".to_string()))?; let results = if let Some(target_collection) = translation.target_collection { @@ -277,6 +492,10 @@ fn run_query_and_display_results( Ok(()) } +#[expect( + clippy::needless_pass_by_value, + reason = "changing to &BTreeSet would require updating build_catalog and all callers" +)] fn get_schema_catalog( uri: &str, current_db: &str, @@ -377,7 +596,7 @@ fn get_schema_catalog( schema_catalog_doc_vec.push(schema_catalog_doc); } - let mut schema_catalog_doc = schema_catalog_doc_vec[0].to_owned(); + let mut schema_catalog_doc = schema_catalog_doc_vec[0].clone(); let collections_schema_doc = schema_catalog_doc.get_document_mut(current_db)?; @@ -407,6 +626,9 @@ fn get_schema_catalog( #[cfg(test)] mod tests { use super::*; + + // Serializes tests that mutate the process-wide current directory. + static CWD_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); #[test] fn it_parses_query_correctly() { let query = "SELECT * FROM users".to_string(); @@ -458,4 +680,155 @@ mod tests { let parse_result = parse_query_from_args(query, sql_file); assert!(parse_result.is_err()); } + + fn make_test_stages_mql() -> TranslationStages { + TranslationStages { + ast_repr: Some("ast output".to_string()), + mir_repr: Some("mir output".to_string()), + air_repr: Some("air output".to_string()), + pipeline_json: Some(r#"[{"$match":{}}]"#.to_string()), + translation: None, + } + } + + #[test] + fn write_audit_trail_creates_correct_entries_for_mql_stage() { + // Arrange + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let stages = make_test_stages_mql(); + + // Act + write_audit_trail("SELECT 1", &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + + // Assert + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let names: Vec = (0..archive.len()) + .map(|i| archive.by_index(i).unwrap().name().to_string()) + .collect(); + + assert!(names.contains(&"audit_trail/initial_query.sql".to_string())); + assert!(names.contains(&"audit_trail/query.ast".to_string())); + assert!(names.contains(&"audit_trail/query.mir".to_string())); + assert!(names.contains(&"audit_trail/query.air".to_string())); + assert!(names.contains(&"audit_trail/pipeline.js".to_string())); + assert_eq!(names.len(), 5); + } + + #[test] + fn write_audit_trail_creates_correct_entries_for_ast_stage() { + // Arrange + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let stages = TranslationStages { + ast_repr: Some("ast only".to_string()), + mir_repr: None, + air_repr: None, + pipeline_json: None, + translation: None, + }; + + // Act + write_audit_trail("SELECT 1", &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + + // Assert + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let names: Vec = (0..archive.len()) + .map(|i| archive.by_index(i).unwrap().name().to_string()) + .collect(); + + assert_eq!(names.len(), 2); + assert!(names.contains(&"audit_trail/initial_query.sql".to_string())); + assert!(names.contains(&"audit_trail/query.ast".to_string())); + } + + #[test] + fn write_audit_trail_sql_content_is_verbatim() { + // Arrange + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let query = "SELECT\n á\tFROM t"; + let stages = make_test_stages_mql(); + + // Act + write_audit_trail(query, &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + + // Assert + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let mut entry = archive.by_name("audit_trail/initial_query.sql").unwrap(); + let mut content = Vec::new(); + std::io::Read::read_to_end(&mut entry, &mut content).unwrap(); + + assert_eq!(content, query.as_bytes()); + } + + #[test] + fn write_audit_trail_pipeline_json_is_valid_json() { + // Arrange + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let stages = TranslationStages { + ast_repr: Some("ast".to_string()), + mir_repr: Some("mir".to_string()), + air_repr: Some("air".to_string()), + pipeline_json: Some(r#"[{"$match": {"x": 1}}]"#.to_string()), + translation: None, + }; + + // Act + write_audit_trail("SELECT 1", &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + + // Assert + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let mut entry = archive.by_name("audit_trail/pipeline.js").unwrap(); + let mut content = String::new(); + std::io::Read::read_to_string(&mut entry, &mut content).unwrap(); + + assert!(serde_json::from_str::(&content).is_ok()); + } + + #[test] + fn write_audit_trail_overwrites_existing_zip() { + // Arrange + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + // Pre-create a dummy zip + std::fs::write("audit_trail.zip", b"not a real zip").unwrap(); + let stages = make_test_stages_mql(); + + // Act + let result = write_audit_trail("SELECT 1", &stages); + std::env::set_current_dir(original).unwrap(); + + // Assert + assert!(result.is_ok()); + let file = std::fs::File::open(&zip_path).unwrap(); + let archive_result = zip::ZipArchive::new(file); + assert!(archive_result.is_ok()); + let mut archive = archive_result.unwrap(); + assert!(archive.by_name("audit_trail/initial_query.sql").is_ok()); + } } From 3edd3a7d0d96cc4691e81027fd2b468fc2acfbeb Mon Sep 17 00:00:00 2001 From: Jonathan Powell Date: Tue, 12 May 2026 14:15:19 -0400 Subject: [PATCH 5/5] Refactor audit_tail code to its own file --- mongosql-cli/src/audit_trail.rs | 340 +++++++++++++++++++ mongosql-cli/src/catalog.rs | 171 ++++++++++ mongosql-cli/src/main.rs | 581 +------------------------------- 3 files changed, 528 insertions(+), 564 deletions(-) create mode 100644 mongosql-cli/src/audit_trail.rs create mode 100644 mongosql-cli/src/catalog.rs diff --git a/mongosql-cli/src/audit_trail.rs b/mongosql-cli/src/audit_trail.rs new file mode 100644 index 000000000..9c7c44a12 --- /dev/null +++ b/mongosql-cli/src/audit_trail.rs @@ -0,0 +1,340 @@ +//! Audit-trail and translation-stage helpers for the mongosql CLI. +//! +//! This module owns the `TranslationCheckpoint` enum, the intermediate-representation +//! collection struct, and the logic for running each compilation stage, writing +//! `audit_trail.zip`, and dispatching the combined audit-trail flow. + +use crate::{run_query_and_display_results, CliError}; +use mongosql::catalog::Catalog; + +/// Compilation stages at which translation can be halted and inspected. +#[derive(clap::ValueEnum, Debug, Clone, Copy)] +pub(crate) enum TranslationCheckpoint { + /// Stop after SQL parsing and AST rewrites; print the AST. + Ast, + /// Stop after algebrizing to MIR and running optimizer passes; print the MIR tree. + Mir, + /// AIR pretty-printing is not yet implemented. + Air, + /// Full translation to MQL; print the generated pipeline (default when --stage is omitted). + Mql, +} + +/// Collected intermediate representations for the audit trail. +/// +/// Fields are `None` for stages not reached given the requested checkpoint. +struct TranslationStages { + ast_repr: Option, + mir_repr: Option, + air_repr: Option, + pipeline_json: Option, + /// Only `Some` when stage is `Mql`; needed for the `--execute` path. + translation: Option, +} + +/// Runs translation stages sequentially up to `stage`, collecting each +/// intermediate representation. +/// +/// The catalog is only built when `stage` is `Mir`, `Air`, or `Mql`. +/// +/// # Errors +/// Returns `CliError` if any translation stage or catalog build fails. +fn run_stages_up_to( + stage: TranslationCheckpoint, + current_db: &str, + query: &str, + catalog: &Catalog, +) -> Result { + let ast_repr = Some(mongosql::translate_sql_to_ast_repr(query)?); + + if matches!(stage, TranslationCheckpoint::Ast) { + return Ok(TranslationStages { + ast_repr, + mir_repr: None, + air_repr: None, + pipeline_json: None, + translation: None, + }); + } + + let options = mongosql::options::SqlOptions { + allow_order_by_missing_columns: true, + ..Default::default() + }; + + let mir_repr = Some(mongosql::translate_sql_to_mir_repr( + current_db, query, &catalog, options, + )?); + + if matches!(stage, TranslationCheckpoint::Mir) { + return Ok(TranslationStages { + ast_repr, + mir_repr, + air_repr: None, + pipeline_json: None, + translation: None, + }); + } + + let air_repr = Some(mongosql::translate_sql_to_air_repr( + current_db, query, &catalog, options, + )?); + + if matches!(stage, TranslationCheckpoint::Air) { + return Ok(TranslationStages { + ast_repr, + mir_repr, + air_repr, + pipeline_json: None, + translation: None, + }); + } + + // Mql + let translation = mongosql::translate_sql(current_db, query, &catalog, options)?; + let pipeline_json = + serde_json::to_string_pretty(&translation.pipeline).map_err(|e| CliError(e.to_string()))?; + + Ok(TranslationStages { + ast_repr, + mir_repr, + air_repr, + pipeline_json: Some(pipeline_json), + translation: Some(translation), + }) +} + +/// Writes intermediate translation representations into `audit_trail.zip` +/// in the current working directory, overwriting any existing file. +/// +/// `initial_query.sql` is always written. Remaining files are written only +/// for stages that were reached (i.e. their field is `Some`). +/// +/// # Errors +/// Returns `CliError` if the zip file cannot be created or written. +fn write_audit_trail(query: &str, stages: &TranslationStages) -> Result<(), CliError> { + use std::io::Write as _; + use zip::write::FileOptions; + use zip::CompressionMethod; + + let file = std::fs::File::create("audit_trail.zip")?; + let mut zip = zip::write::ZipWriter::new(file); + let options = FileOptions::<()>::default().compression_method(CompressionMethod::Deflated); + + zip.start_file("audit_trail/initial_query.sql", options)?; + zip.write_all(query.as_bytes())?; + + if let Some(ast) = &stages.ast_repr { + zip.start_file("audit_trail/query.ast", options)?; + zip.write_all(ast.as_bytes())?; + } + if let Some(mir) = &stages.mir_repr { + zip.start_file("audit_trail/query.mir", options)?; + zip.write_all(mir.as_bytes())?; + } + if let Some(air) = &stages.air_repr { + zip.start_file("audit_trail/query.air", options)?; + zip.write_all(air.as_bytes())?; + } + if let Some(pipeline) = &stages.pipeline_json { + zip.start_file("audit_trail/pipeline.js", options)?; + zip.write_all(pipeline.as_bytes())?; + } + + zip.finish()?; + Ok(()) +} + +/// Runs all translation stages up to `stage`, writes `audit_trail.zip`, prints +/// the stage output to stdout, and optionally executes the query. +/// +/// Stdout output is identical to running `--stage` alone. +/// +/// # Errors +/// Returns `CliError` if any translation stage, zip write, or query execution fails. +pub(crate) fn handle_audit_trail( + stage: TranslationCheckpoint, + current_db: &str, + query: &str, + uri: &str, + execute: bool, + catalog: &Catalog, +) -> Result<(), CliError> { + if execute && !matches!(stage, TranslationCheckpoint::Mql) { + return Err(CliError( + "--execute is only valid with --stage mql or without --stage".to_string(), + )); + } + let stages = run_stages_up_to(stage, current_db, query, catalog)?; + write_audit_trail(query, &stages)?; + eprintln!("[INFO] audit_trail.zip written to current working directory."); + + match stage { + TranslationCheckpoint::Ast => { + println!("{}", stages.ast_repr.expect("ast computed for Ast stage")); + } + TranslationCheckpoint::Mir => { + println!("{}", stages.mir_repr.expect("mir computed for Mir stage")); + } + TranslationCheckpoint::Air => { + println!("{}", stages.air_repr.expect("air computed for Air stage")); + } + TranslationCheckpoint::Mql => { + let translation = stages + .translation + .as_ref() + .expect("translation computed for Mql stage"); + let schema = serde_json::to_string_pretty(&translation.result_set_schema) + .map_err(|e| CliError(e.to_string()))?; + println!( + "target_db: {},\ntarget_collection: {:?},\nresult set schema:\n{}\npipeline:\n[", + translation.target_db, translation.target_collection, schema + ); + let bson::Bson::Array(ref pipeline) = translation.pipeline else { + return Err(CliError("pipeline is not an array".to_string())); + }; + for doc in pipeline { + println!(" {doc},"); + } + println!("]"); + } + } + + if execute { + run_query_and_display_results( + uri, + stages + .translation + .expect("translation computed for Mql stage"), + )?; + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + static CWD_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + fn make_test_stages_mql() -> TranslationStages { + TranslationStages { + ast_repr: Some("ast output".to_string()), + mir_repr: Some("mir output".to_string()), + air_repr: Some("air output".to_string()), + pipeline_json: Some(r#"[{"$match":{}}]"#.to_string()), + translation: None, + } + } + + #[test] + fn write_audit_trail_creates_correct_entries_for_mql_stage() { + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let stages = make_test_stages_mql(); + write_audit_trail("SELECT 1", &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let names: Vec = (0..archive.len()) + .map(|i| archive.by_index(i).unwrap().name().to_string()) + .collect(); + assert!(names.contains(&"audit_trail/initial_query.sql".to_string())); + assert!(names.contains(&"audit_trail/query.ast".to_string())); + assert!(names.contains(&"audit_trail/query.mir".to_string())); + assert!(names.contains(&"audit_trail/query.air".to_string())); + assert!(names.contains(&"audit_trail/pipeline.js".to_string())); + assert_eq!(names.len(), 5); + } + + #[test] + fn write_audit_trail_creates_correct_entries_for_ast_stage() { + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let stages = TranslationStages { + ast_repr: Some("ast only".to_string()), + mir_repr: None, + air_repr: None, + pipeline_json: None, + translation: None, + }; + write_audit_trail("SELECT 1", &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let names: Vec = (0..archive.len()) + .map(|i| archive.by_index(i).unwrap().name().to_string()) + .collect(); + assert_eq!(names.len(), 2); + assert!(names.contains(&"audit_trail/initial_query.sql".to_string())); + assert!(names.contains(&"audit_trail/query.ast".to_string())); + } + + #[test] + fn write_audit_trail_sql_content_is_verbatim() { + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let query = "SELECT\n á\tFROM t"; + let stages = make_test_stages_mql(); + write_audit_trail(query, &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let mut entry = archive.by_name("audit_trail/initial_query.sql").unwrap(); + let mut content = Vec::new(); + std::io::Read::read_to_end(&mut entry, &mut content).unwrap(); + assert_eq!(content, query.as_bytes()); + } + + #[test] + fn write_audit_trail_pipeline_json_is_valid_json() { + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + let stages = TranslationStages { + ast_repr: Some("ast".to_string()), + mir_repr: Some("mir".to_string()), + air_repr: Some("air".to_string()), + pipeline_json: Some(r#"[{"$match": {"x": 1}}]"#.to_string()), + translation: None, + }; + write_audit_trail("SELECT 1", &stages).unwrap(); + std::env::set_current_dir(original).unwrap(); + let file = std::fs::File::open(&zip_path).unwrap(); + let mut archive = zip::ZipArchive::new(file).unwrap(); + let mut entry = archive.by_name("audit_trail/pipeline.js").unwrap(); + let mut content = String::new(); + std::io::Read::read_to_string(&mut entry, &mut content).unwrap(); + assert!(serde_json::from_str::(&content).is_ok()); + } + + #[test] + fn write_audit_trail_overwrites_existing_zip() { + let _guard = CWD_LOCK.lock().unwrap(); + let dir = tempfile::TempDir::new().unwrap(); + let zip_path = dir.path().join("audit_trail.zip"); + let original = std::env::current_dir().unwrap(); + std::env::set_current_dir(dir.path()).unwrap(); + std::fs::write("audit_trail.zip", b"not a real zip").unwrap(); + let stages = make_test_stages_mql(); + let result = write_audit_trail("SELECT 1", &stages); + std::env::set_current_dir(original).unwrap(); + assert!(result.is_ok()); + let file = std::fs::File::open(&zip_path).unwrap(); + let archive_result = zip::ZipArchive::new(file); + assert!(archive_result.is_ok()); + let mut archive = archive_result.unwrap(); + assert!(archive.by_name("audit_trail/initial_query.sql").is_ok()); + } +} diff --git a/mongosql-cli/src/catalog.rs b/mongosql-cli/src/catalog.rs new file mode 100644 index 000000000..fd94d3219 --- /dev/null +++ b/mongosql-cli/src/catalog.rs @@ -0,0 +1,171 @@ +use agg_ast::definitions::Namespace; +use bson::{doc, Document}; +use mongodb::sync::Client; +use mongosql::{build_catalog_from_catalog_schema, catalog::Catalog, json_schema::Schema}; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, BTreeSet}; + +use crate::CliError; + +const SQL_SCHEMAS_COLLECTION: &str = "__sql_schemas"; + +#[derive(Debug, Serialize, Deserialize)] +struct SchemaFile { + #[serde(flatten)] + schemas: BTreeMap>, +} + +pub(crate) fn build_catalog( + uri: &str, + current_db: &str, + namespaces: BTreeSet, + schema_file: Option, +) -> Result { + if let Some(schema_file) = schema_file { + let contents = std::fs::read_to_string(&schema_file)?; + let path = std::path::Path::new(&schema_file); + let extension = path + .extension() + .and_then(|ext| ext.to_str()) + .map(str::to_lowercase); + + let catalog: SchemaFile = match extension.as_deref() { + Some("yaml" | "yml") => serde_yaml::from_str(&contents)?, + Some("json") => serde_json::from_str(&contents)?, + _ => { + return Err(CliError(format!( + "Unsupported schema file extension: {extension:?}. Supported formats are .yml, .yaml, .json" + ))) + } + }; + Ok(build_catalog_from_catalog_schema(catalog.schemas)?) + } else { + get_schema_catalog(uri, current_db, namespaces) + } +} + +#[expect( + clippy::needless_pass_by_value, + reason = "changing to &BTreeSet would require updating build_catalog and all callers" +)] +fn get_schema_catalog( + uri: &str, + current_db: &str, + namespaces: BTreeSet, +) -> Result { + // If there are no namespaces (e.g. queries with only array datasources), assign + // an empty schema to `current_db` + if namespaces.is_empty() { + let schema_catalog_doc = doc! { + current_db: doc! {}, + }; + + return Ok(mongosql::build_catalog_from_catalog_schema( + serde_json::from_str::>>( + &schema_catalog_doc.to_string(), + )?, + )?); + } + + // Otherwise, fetch the schema information for the specified collections. + let client = Client::with_uri_str(uri)?; + let db = client.database(current_db); + let schema_collection = db.collection::(SQL_SCHEMAS_COLLECTION); + + let collection_names = namespaces + .iter() + .map(|namespace| namespace.collection.as_str()) + .collect::>(); + + let schema_catalog_aggregation_pipeline = vec![ + doc! {"$match": { + "_id": { + "$in": &collection_names + } + } + }, + doc! {"$project":{ + "_id": 1, + "schema": 1 + } + }, + doc! {"$group": { + "_id": null, + "collections": { + "$push": { + "collectionName": "$_id", + "schema": "$schema" + } + } + } + }, + doc! {"$project": { + "_id": 0, + current_db: { + "$arrayToObject": [{ + "$map": { + "input": "$collections", + "as": "coll", + "in": { + "k": "$$coll.collectionName", + "v": "$$coll.schema" + } + } + }] + } + } + }, + ]; + + let mut schema_catalog_doc_vec: Vec = schema_collection + .aggregate(schema_catalog_aggregation_pipeline) + .run()? + .collect::, _>>()?; + + if schema_catalog_doc_vec.len() > 1 { + return Err(CliError("Multiple Schema Documents Returned".to_string())); + } + + if schema_catalog_doc_vec.is_empty() { + println!("[WARNING] No schema information was found for the requested collections `{collection_names:?}` in database `{current_db}`. Either the collections don't exist \ + in `{current_db}` or they don't have a schema. For now, they will be assigned empty schemas. Hint: You either need to generate schemas for your collections \ + or correct your query."); + + let mut collections_schema_doc = doc! {}; + + for collection in collection_names { + collections_schema_doc.insert(collection, doc! {}); + } + + let schema_catalog_doc = doc! { + current_db: collections_schema_doc, + }; + + schema_catalog_doc_vec.push(schema_catalog_doc); + } + + let mut schema_catalog_doc = schema_catalog_doc_vec[0].clone(); + + let collections_schema_doc = schema_catalog_doc.get_document_mut(current_db)?; + + if namespaces.len() != collections_schema_doc.len() { + let missing_collections: Vec = namespaces + .iter() + .map(|namespace| namespace.collection.clone()) + .filter(|collection| !collections_schema_doc.contains_key(collection.as_str())) + .collect(); + + println!("[WARNING] No schema was found for the following collections: {missing_collections:?}. These collections will be assigned empty schemas. \ + Hint: Generate schemas for your collections."); + + for collection in missing_collections { + collections_schema_doc.insert(collection, doc! {}); + } + } + + Ok(mongosql::build_catalog_from_catalog_schema( + serde_json::from_str::>>( + &schema_catalog_doc.to_string(), + )?, + )?) +} diff --git a/mongosql-cli/src/main.rs b/mongosql-cli/src/main.rs index f82b032aa..80e4a651f 100644 --- a/mongosql-cli/src/main.rs +++ b/mongosql-cli/src/main.rs @@ -1,16 +1,15 @@ -use agg_ast::definitions::Namespace; -use bson::{doc, Document}; +mod audit_trail; +mod catalog; + +use audit_trail::{handle_audit_trail, TranslationCheckpoint}; +use bson::Document; +use catalog::build_catalog; use clap::Parser; use mongodb::sync::{Client, Collection}; -use mongosql::{build_catalog_from_catalog_schema, catalog::Catalog, json_schema::Schema}; -use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet}; use std::path::PathBuf; -const SQL_SCHEMAS_COLLECTION: &str = "__sql_schemas"; - #[derive(Debug)] -struct CliError(String); +pub(crate) struct CliError(String); impl std::fmt::Display for CliError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -27,18 +26,6 @@ where } } -#[derive(clap::ValueEnum, Debug, Clone, Copy)] -enum TranslationCheckpoint { - /// Stop after SQL parsing and AST rewrites; print the AST. - Ast, - /// Stop after algebrizing to MIR and running optimizer passes; print the MIR tree. - Mir, - /// AIR pretty-printing is not yet implemented. - Air, - /// Full translation to MQL; print the generated pipeline (default when --stage is omitted). - Mql, -} - #[derive(Parser, Debug)] #[command(version, about, long_about = None, arg_required_else_help = true)] struct Cli { @@ -95,24 +82,6 @@ struct Cli { audit_trail: bool, } -#[derive(Debug, Serialize, Deserialize)] -pub struct SchemaFile { - #[serde(flatten)] - pub schemas: BTreeMap>, -} - -/// Collected intermediate representations for the audit trail. -/// -/// Fields are `None` for stages not reached given the requested checkpoint. -struct TranslationStages { - ast_repr: Option, - mir_repr: Option, - air_repr: Option, - pipeline_json: Option, - /// Only `Some` when stage is `Mql`; needed for the `--execute` path. - translation: Option, -} - fn parse_query_from_args( query: Option, sql_file: Option, @@ -135,216 +104,6 @@ fn parse_query_from_args( } } -fn build_catalog( - uri: &str, - current_db: &str, - namespaces: std::collections::BTreeSet, - schema_file: Option, -) -> Result { - if let Some(schema_file) = schema_file { - let contents = std::fs::read_to_string(&schema_file)?; - let path = std::path::Path::new(&schema_file); - let extension = path - .extension() - .and_then(|ext| ext.to_str()) - .map(str::to_lowercase); - - let catalog: SchemaFile = match extension.as_deref() { - Some("yaml" | "yml") => serde_yaml::from_str(&contents)?, - Some("json") => serde_json::from_str(&contents)?, - _ => { - return Err(CliError(format!( - "Unsupported schema file extension: {extension:?}. Supported formats are .yml, .yaml, .json" - ))) - } - }; - Ok(build_catalog_from_catalog_schema(catalog.schemas)?) - } else { - get_schema_catalog(uri, current_db, namespaces) - } -} - -/// Runs translation stages sequentially up to `stage`, collecting each -/// intermediate representation. -/// -/// The catalog is only built when `stage` is `Mir`, `Air`, or `Mql`. -/// -/// # Errors -/// Returns `CliError` if any translation stage or catalog build fails. -fn run_stages_up_to( - stage: TranslationCheckpoint, - current_db: &str, - query: &str, - uri: &str, - schema_file: Option, -) -> Result { - let ast_repr = Some(mongosql::translate_sql_to_ast_repr(query)?); - - if matches!(stage, TranslationCheckpoint::Ast) { - return Ok(TranslationStages { - ast_repr, - mir_repr: None, - air_repr: None, - pipeline_json: None, - translation: None, - }); - } - - let namespaces = mongosql::get_namespaces(current_db, query)?; - let catalog = build_catalog(uri, current_db, namespaces, schema_file)?; - let options = mongosql::options::SqlOptions { - allow_order_by_missing_columns: true, - ..Default::default() - }; - - let mir_repr = Some(mongosql::translate_sql_to_mir_repr( - current_db, query, &catalog, options, - )?); - - if matches!(stage, TranslationCheckpoint::Mir) { - return Ok(TranslationStages { - ast_repr, - mir_repr, - air_repr: None, - pipeline_json: None, - translation: None, - }); - } - - let air_repr = Some(mongosql::translate_sql_to_air_repr( - current_db, query, &catalog, options, - )?); - - if matches!(stage, TranslationCheckpoint::Air) { - return Ok(TranslationStages { - ast_repr, - mir_repr, - air_repr, - pipeline_json: None, - translation: None, - }); - } - - // Mql - let translation = mongosql::translate_sql(current_db, query, &catalog, options)?; - let pipeline_json = - serde_json::to_string_pretty(&translation.pipeline).map_err(|e| CliError(e.to_string()))?; - - Ok(TranslationStages { - ast_repr, - mir_repr, - air_repr, - pipeline_json: Some(pipeline_json), - translation: Some(translation), - }) -} - -/// Writes intermediate translation representations into `audit_trail.zip` -/// in the current working directory, overwriting any existing file. -/// -/// `initial_query.sql` is always written. Remaining files are written only -/// for stages that were reached (i.e. their field is `Some`). -/// -/// # Errors -/// Returns `CliError` if the zip file cannot be created or written. -fn write_audit_trail(query: &str, stages: &TranslationStages) -> Result<(), CliError> { - use std::io::Write as _; - use zip::write::FileOptions; - use zip::CompressionMethod; - - let file = std::fs::File::create("audit_trail.zip")?; - let mut zip = zip::write::ZipWriter::new(file); - let options = FileOptions::<()>::default().compression_method(CompressionMethod::Deflated); - - zip.start_file("audit_trail/initial_query.sql", options)?; - zip.write_all(query.as_bytes())?; - - if let Some(ast) = &stages.ast_repr { - zip.start_file("audit_trail/query.ast", options)?; - zip.write_all(ast.as_bytes())?; - } - if let Some(mir) = &stages.mir_repr { - zip.start_file("audit_trail/query.mir", options)?; - zip.write_all(mir.as_bytes())?; - } - if let Some(air) = &stages.air_repr { - zip.start_file("audit_trail/query.air", options)?; - zip.write_all(air.as_bytes())?; - } - if let Some(pipeline) = &stages.pipeline_json { - zip.start_file("audit_trail/pipeline.js", options)?; - zip.write_all(pipeline.as_bytes())?; - } - - zip.finish()?; - Ok(()) -} - -/// Runs all translation stages up to `stage`, writes `audit_trail.zip`, prints -/// the stage output to stdout, and optionally executes the query. -/// -/// Stdout output is identical to running `--stage` alone. -/// -/// # Errors -/// Returns `CliError` if any translation stage, zip write, or query execution fails. -fn handle_audit_trail( - stage: TranslationCheckpoint, - current_db: &str, - query: &str, - uri: &str, - schema_file: Option, - execute: bool, -) -> Result<(), CliError> { - let stages = run_stages_up_to(stage, current_db, query, uri, schema_file)?; - write_audit_trail(query, &stages)?; - eprintln!("audit_trail.zip written to current directory."); - - match stage { - TranslationCheckpoint::Ast => { - println!("{}", stages.ast_repr.expect("ast computed for Ast stage")); - } - TranslationCheckpoint::Mir => { - println!("{}", stages.mir_repr.expect("mir computed for Mir stage")); - } - TranslationCheckpoint::Air => { - println!("{}", stages.air_repr.expect("air computed for Air stage")); - } - TranslationCheckpoint::Mql => { - let translation = stages - .translation - .as_ref() - .expect("translation computed for Mql stage"); - let schema = serde_json::to_string_pretty(&translation.result_set_schema) - .map_err(|e| CliError(e.to_string()))?; - println!( - "target_db: {},\ntarget_collection: {:?},\nresult set schema:\n{}\npipeline:\n[", - translation.target_db, translation.target_collection, schema - ); - let bson::Bson::Array(ref pipeline) = translation.pipeline else { - return Err(CliError("pipeline is not an array".to_string())); - }; - for doc in pipeline { - println!(" {doc},"); - } - println!("]"); - } - } - - if execute { - run_query_and_display_results( - uri, - stages - .translation - .expect("translation computed for Mql stage"), - )?; - } - Ok(()) -} - -#[expect( - clippy::too_many_lines, - reason = "each TranslationCheckpoint arm repeats catalog/options setup; extracting further would obscure the control flow" -)] fn main() -> Result<(), CliError> { let args = Cli::parse(); @@ -359,14 +118,22 @@ fn main() -> Result<(), CliError> { )); } + let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; + let catalog = build_catalog( + uri.as_str(), + current_db.as_str(), + namespaces, + args.schema_file, + )?; + if args.audit_trail { return handle_audit_trail( stage, current_db.as_str(), query.as_str(), uri.as_str(), - args.schema_file, args.execute, + &catalog, ); } @@ -377,13 +144,6 @@ fn main() -> Result<(), CliError> { return Ok(()); } TranslationCheckpoint::Mir => { - let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; - let catalog = build_catalog( - uri.as_str(), - current_db.as_str(), - namespaces, - args.schema_file, - )?; let options = mongosql::options::SqlOptions { allow_order_by_missing_columns: true, ..Default::default() @@ -398,13 +158,6 @@ fn main() -> Result<(), CliError> { return Ok(()); } TranslationCheckpoint::Air => { - let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; - let catalog = build_catalog( - uri.as_str(), - current_db.as_str(), - namespaces, - args.schema_file, - )?; let options = mongosql::options::SqlOptions { allow_order_by_missing_columns: true, ..Default::default() @@ -421,13 +174,6 @@ fn main() -> Result<(), CliError> { TranslationCheckpoint::Mql => {} } - let namespaces = mongosql::get_namespaces(current_db.as_str(), query.as_str())?; - let catalog = build_catalog( - uri.as_str(), - current_db.as_str(), - namespaces, - args.schema_file, - )?; let options = mongosql::options::SqlOptions { allow_order_by_missing_columns: true, ..Default::default() @@ -462,7 +208,7 @@ fn main() -> Result<(), CliError> { run_query_and_display_results(uri.as_str(), translation) } -fn run_query_and_display_results( +pub(crate) fn run_query_and_display_results( uri: &str, translation: mongosql::Translation, ) -> Result<(), CliError> { @@ -492,150 +238,15 @@ fn run_query_and_display_results( Ok(()) } -#[expect( - clippy::needless_pass_by_value, - reason = "changing to &BTreeSet would require updating build_catalog and all callers" -)] -fn get_schema_catalog( - uri: &str, - current_db: &str, - namespaces: BTreeSet, -) -> Result { - // If there are no namespaces (e.g. queries with only array datasources), assign - // an empty schema to `current_db` - if namespaces.is_empty() { - let schema_catalog_doc = doc! { - current_db: doc! {}, - }; - - return Ok(mongosql::build_catalog_from_catalog_schema( - serde_json::from_str::>>( - &schema_catalog_doc.to_string(), - )?, - )?); - } - - // Otherwise, fetch the schema information for the specified collections. - let client = Client::with_uri_str(uri)?; - let db = client.database(current_db); - let schema_collection = db.collection::(SQL_SCHEMAS_COLLECTION); - - let collection_names = namespaces - .iter() - .map(|namespace| namespace.collection.as_str()) - .collect::>(); - - // Create an aggregation pipeline to fetch the schema information for the specified collections. - // The pipeline uses $in to query all the specified collections and projects them into the desired format: - // "dbName": { "collection1" : "Schema1", "collection2" : "Schema2", ... } - let schema_catalog_aggregation_pipeline = vec![ - doc! {"$match": { - "_id": { - "$in": &collection_names - } - } - }, - doc! {"$project":{ - "_id": 1, - "schema": 1 - } - }, - doc! {"$group": { - "_id": null, - "collections": { - "$push": { - "collectionName": "$_id", - "schema": "$schema" - } - } - } - }, - doc! {"$project": { - "_id": 0, - current_db: { - "$arrayToObject": [{ - "$map": { - "input": "$collections", - "as": "coll", - "in": { - "k": "$$coll.collectionName", - "v": "$$coll.schema" - } - } - }] - } - } - }, - ]; - - // create the schema_catalog document - let mut schema_catalog_doc_vec: Vec = schema_collection - .aggregate(schema_catalog_aggregation_pipeline) - .run()? - .collect::, _>>()?; - - if schema_catalog_doc_vec.len() > 1 { - return Err(CliError("Multiple Schema Documents Returned".to_string())); - } - - if schema_catalog_doc_vec.is_empty() { - println!("[WARNING] No schema information was found for the requested collections `{collection_names:?}` in database `{current_db}`. Either the collections don't exist \ - in `{current_db}` or they don't have a schema. For now, they will be assigned empty schemas. Hint: You either need to generate schemas for your collections \ - or correct your query."); - - let mut collections_schema_doc = doc! {}; - - for collection in collection_names { - collections_schema_doc.insert(collection, doc! {}); - } - - let schema_catalog_doc = doc! { - current_db: collections_schema_doc, - }; - - schema_catalog_doc_vec.push(schema_catalog_doc); - } - - let mut schema_catalog_doc = schema_catalog_doc_vec[0].clone(); - - let collections_schema_doc = schema_catalog_doc.get_document_mut(current_db)?; - - // If there are collections with no schema available, assign them empty schemas. - if namespaces.len() != collections_schema_doc.len() { - let missing_collections: Vec = namespaces - .iter() - .map(|namespace| namespace.collection.clone()) - .filter(|collection| !collections_schema_doc.contains_key(collection.as_str())) - .collect(); - - println!("[WARNING] No schema was found for the following collections: {missing_collections:?}. These collections will be assigned empty schemas. \ - Hint: Generate schemas for your collections."); - - for collection in missing_collections { - collections_schema_doc.insert(collection, doc! {}); - } - } - - Ok(mongosql::build_catalog_from_catalog_schema( - serde_json::from_str::>>( - &schema_catalog_doc.to_string(), - )?, - )?) -} - #[cfg(test)] mod tests { use super::*; - // Serializes tests that mutate the process-wide current directory. - static CWD_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); #[test] fn it_parses_query_correctly() { let query = "SELECT * FROM users".to_string(); let sql_file = None; - let parse_result = parse_query_from_args(Some(query), sql_file); - assert!(parse_result.is_ok()); assert_eq!(parse_result.unwrap(), "SELECT * FROM users".to_string()); } @@ -644,9 +255,7 @@ mod tests { fn it_parses_sql_file_correctly() { let query: Option = None; let sql_file = Some(PathBuf::from("./test/sample_query.sql")); - let parse_result = parse_query_from_args(query, sql_file); - assert!(parse_result.is_ok()); assert_eq!(parse_result.unwrap(), "SELECT customerAge, COUNT(*) FROM sample_supplies.sales GROUP BY customer.age AS customerAge limit 10;".trim().to_string()); } @@ -655,9 +264,7 @@ mod tests { fn sql_file_takes_precedence_over_query() { let query = "SELECT * FROM users".to_string(); let sql_file = Some(PathBuf::from("./test/sample_query.sql")); - let parse_result = parse_query_from_args(Some(query), sql_file); - assert!(parse_result.is_ok()); assert_eq!(parse_result.unwrap(), "SELECT customerAge, COUNT(*) FROM sample_supplies.sales GROUP BY customer.age AS customerAge limit 10;".trim().to_string()); } @@ -666,9 +273,7 @@ mod tests { fn no_query_provided_returns_error() { let query: Option = None; let sql_file: Option = None; - let parse_result = parse_query_from_args(query, sql_file); - assert!(parse_result.is_err()); } @@ -676,159 +281,7 @@ mod tests { fn empty_sql_file_returns_error() { let query: Option = None; let sql_file = Some(PathBuf::from("./test/empty_query.sql")); - let parse_result = parse_query_from_args(query, sql_file); assert!(parse_result.is_err()); } - - fn make_test_stages_mql() -> TranslationStages { - TranslationStages { - ast_repr: Some("ast output".to_string()), - mir_repr: Some("mir output".to_string()), - air_repr: Some("air output".to_string()), - pipeline_json: Some(r#"[{"$match":{}}]"#.to_string()), - translation: None, - } - } - - #[test] - fn write_audit_trail_creates_correct_entries_for_mql_stage() { - // Arrange - let _guard = CWD_LOCK.lock().unwrap(); - let dir = tempfile::TempDir::new().unwrap(); - let zip_path = dir.path().join("audit_trail.zip"); - let original = std::env::current_dir().unwrap(); - std::env::set_current_dir(dir.path()).unwrap(); - let stages = make_test_stages_mql(); - - // Act - write_audit_trail("SELECT 1", &stages).unwrap(); - std::env::set_current_dir(original).unwrap(); - - // Assert - let file = std::fs::File::open(&zip_path).unwrap(); - let mut archive = zip::ZipArchive::new(file).unwrap(); - let names: Vec = (0..archive.len()) - .map(|i| archive.by_index(i).unwrap().name().to_string()) - .collect(); - - assert!(names.contains(&"audit_trail/initial_query.sql".to_string())); - assert!(names.contains(&"audit_trail/query.ast".to_string())); - assert!(names.contains(&"audit_trail/query.mir".to_string())); - assert!(names.contains(&"audit_trail/query.air".to_string())); - assert!(names.contains(&"audit_trail/pipeline.js".to_string())); - assert_eq!(names.len(), 5); - } - - #[test] - fn write_audit_trail_creates_correct_entries_for_ast_stage() { - // Arrange - let _guard = CWD_LOCK.lock().unwrap(); - let dir = tempfile::TempDir::new().unwrap(); - let zip_path = dir.path().join("audit_trail.zip"); - let original = std::env::current_dir().unwrap(); - std::env::set_current_dir(dir.path()).unwrap(); - let stages = TranslationStages { - ast_repr: Some("ast only".to_string()), - mir_repr: None, - air_repr: None, - pipeline_json: None, - translation: None, - }; - - // Act - write_audit_trail("SELECT 1", &stages).unwrap(); - std::env::set_current_dir(original).unwrap(); - - // Assert - let file = std::fs::File::open(&zip_path).unwrap(); - let mut archive = zip::ZipArchive::new(file).unwrap(); - let names: Vec = (0..archive.len()) - .map(|i| archive.by_index(i).unwrap().name().to_string()) - .collect(); - - assert_eq!(names.len(), 2); - assert!(names.contains(&"audit_trail/initial_query.sql".to_string())); - assert!(names.contains(&"audit_trail/query.ast".to_string())); - } - - #[test] - fn write_audit_trail_sql_content_is_verbatim() { - // Arrange - let _guard = CWD_LOCK.lock().unwrap(); - let dir = tempfile::TempDir::new().unwrap(); - let zip_path = dir.path().join("audit_trail.zip"); - let original = std::env::current_dir().unwrap(); - std::env::set_current_dir(dir.path()).unwrap(); - let query = "SELECT\n á\tFROM t"; - let stages = make_test_stages_mql(); - - // Act - write_audit_trail(query, &stages).unwrap(); - std::env::set_current_dir(original).unwrap(); - - // Assert - let file = std::fs::File::open(&zip_path).unwrap(); - let mut archive = zip::ZipArchive::new(file).unwrap(); - let mut entry = archive.by_name("audit_trail/initial_query.sql").unwrap(); - let mut content = Vec::new(); - std::io::Read::read_to_end(&mut entry, &mut content).unwrap(); - - assert_eq!(content, query.as_bytes()); - } - - #[test] - fn write_audit_trail_pipeline_json_is_valid_json() { - // Arrange - let _guard = CWD_LOCK.lock().unwrap(); - let dir = tempfile::TempDir::new().unwrap(); - let zip_path = dir.path().join("audit_trail.zip"); - let original = std::env::current_dir().unwrap(); - std::env::set_current_dir(dir.path()).unwrap(); - let stages = TranslationStages { - ast_repr: Some("ast".to_string()), - mir_repr: Some("mir".to_string()), - air_repr: Some("air".to_string()), - pipeline_json: Some(r#"[{"$match": {"x": 1}}]"#.to_string()), - translation: None, - }; - - // Act - write_audit_trail("SELECT 1", &stages).unwrap(); - std::env::set_current_dir(original).unwrap(); - - // Assert - let file = std::fs::File::open(&zip_path).unwrap(); - let mut archive = zip::ZipArchive::new(file).unwrap(); - let mut entry = archive.by_name("audit_trail/pipeline.js").unwrap(); - let mut content = String::new(); - std::io::Read::read_to_string(&mut entry, &mut content).unwrap(); - - assert!(serde_json::from_str::(&content).is_ok()); - } - - #[test] - fn write_audit_trail_overwrites_existing_zip() { - // Arrange - let _guard = CWD_LOCK.lock().unwrap(); - let dir = tempfile::TempDir::new().unwrap(); - let zip_path = dir.path().join("audit_trail.zip"); - let original = std::env::current_dir().unwrap(); - std::env::set_current_dir(dir.path()).unwrap(); - // Pre-create a dummy zip - std::fs::write("audit_trail.zip", b"not a real zip").unwrap(); - let stages = make_test_stages_mql(); - - // Act - let result = write_audit_trail("SELECT 1", &stages); - std::env::set_current_dir(original).unwrap(); - - // Assert - assert!(result.is_ok()); - let file = std::fs::File::open(&zip_path).unwrap(); - let archive_result = zip::ZipArchive::new(file); - assert!(archive_result.is_ok()); - let mut archive = archive_result.unwrap(); - assert!(archive.by_name("audit_trail/initial_query.sql").is_ok()); - } }