diff --git a/crates/lance-graph/src/ast.rs b/crates/lance-graph/src/ast.rs index 2e770ab8..06b27f14 100644 --- a/crates/lance-graph/src/ast.rs +++ b/crates/lance-graph/src/ast.rs @@ -13,14 +13,14 @@ use std::collections::HashMap; /// A complete Cypher query #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct CypherQuery { - /// MATCH clauses - pub match_clauses: Vec, + /// READING clauses (MATCH, UNWIND, etc.) + pub reading_clauses: Vec, /// WHERE clause (optional, before WITH if present) pub where_clause: Option, /// WITH clause (optional) - intermediate projection/aggregation pub with_clause: Option, - /// MATCH clauses after WITH (optional) - query chaining - pub post_with_match_clauses: Vec, + /// Post-WITH READING clauses + pub post_with_reading_clauses: Vec, /// WHERE clause after WITH (optional) - filters the WITH results pub post_with_where_clause: Option, /// RETURN clause @@ -37,28 +37,31 @@ impl CypherQuery { /// Extract all node labels referenced in the query pub fn get_node_labels(&self) -> Vec { let mut labels = Vec::new(); - for match_clause in &self.match_clauses { - for pattern in &match_clause.patterns { - match pattern { - GraphPattern::Node(node) => { - for label in &node.labels { - if !labels.contains(label) { - labels.push(label.clone()); - } - } - } - GraphPattern::Path(path) => { - for label in &path.start_node.labels { - if !labels.contains(label) { - labels.push(label.clone()); + // Iterate all match clauses directly + for clause in &self.reading_clauses { + if let ReadingClause::Match(match_clause) = clause { + for pattern in &match_clause.patterns { + match pattern { + GraphPattern::Node(node) => { + for label in &node.labels { + if !labels.contains(label) { + labels.push(label.clone()); + } } } - for segment in &path.segments { - for label in &segment.end_node.labels { + GraphPattern::Path(path) => { + for label in &path.start_node.labels { if !labels.contains(label) { labels.push(label.clone()); } } + for segment in &path.segments { + for label in &segment.end_node.labels { + if !labels.contains(label) { + labels.push(label.clone()); + } + } + } } } } @@ -70,21 +73,38 @@ impl CypherQuery { /// Extract all relationship types referenced in the query pub fn get_relationship_types(&self) -> Vec { let mut types = Vec::new(); - for match_clause in &self.match_clauses { - for pattern in &match_clause.patterns { - if let GraphPattern::Path(path) = pattern { - for segment in &path.segments { - for rel_type in &segment.relationship.types { - if !types.contains(rel_type) { - types.push(rel_type.clone()); - } - } - } + for clause in &self.reading_clauses { + if let ReadingClause::Match(match_clause) = clause { + for pattern in &match_clause.patterns { + self.collect_relationship_types_from_pattern(pattern, &mut types); } } } types } + + fn collect_relationship_types_from_pattern( + &self, + pattern: &GraphPattern, + types: &mut Vec, + ) { + if let GraphPattern::Path(path) = pattern { + for segment in &path.segments { + for rel_type in &segment.relationship.types { + if !types.contains(rel_type) { + types.push(rel_type.clone()); + } + } + } + } + } +} + +/// A clause that reads from the graph (MATCH, UNWIND) +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ReadingClause { + Match(MatchClause), + Unwind(UnwindClause), } /// A MATCH clause containing graph patterns @@ -94,6 +114,15 @@ pub struct MatchClause { pub patterns: Vec, } +/// An UNWIND clause +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct UnwindClause { + /// Expression to unwind + pub expression: ValueExpression, + /// Alias for the unwound values + pub alias: String, +} + /// A graph pattern (nodes and relationships) #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum GraphPattern { diff --git a/crates/lance-graph/src/datafusion_planner/analysis.rs b/crates/lance-graph/src/datafusion_planner/analysis.rs index bbe0943e..113072c2 100644 --- a/crates/lance-graph/src/datafusion_planner/analysis.rs +++ b/crates/lance-graph/src/datafusion_planner/analysis.rs @@ -206,6 +206,11 @@ fn analyze_operator( analyze_operator(left, analysis, rel_counter)?; analyze_operator(right, analysis, rel_counter)?; } + LogicalOperator::Unwind { input, .. } => { + if let Some(op) = input { + analyze_operator(op, analysis, rel_counter)?; + } + } } Ok(()) } diff --git a/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs b/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs index 3df5fc75..59f5807a 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/basic_ops.rs @@ -141,6 +141,45 @@ impl DataFusionPlanner { .build() .map_err(|e| self.plan_error("Failed to build plan", e)) } + + pub(crate) fn build_unwind( + &self, + ctx: &mut PlanningContext, + input: &Option>, + expression: &crate::ast::ValueExpression, + alias: &str, + ) -> Result { + let input_plan = if let Some(input_op) = input { + self.build_operator(ctx, input_op)? + } else { + // Create an empty relation that produces one row for UNWIND [..] + LogicalPlanBuilder::empty(true) + .build() + .map_err(|e| self.plan_error("Failed to create empty relation", e))? + }; + + // Convert expression to DataFusion Expr + let df_expr = super::super::expression::to_df_value_expr(expression); + + // We project the list expression first (aliased as the target alias temporarily) + // DataFusion unnest takes a column name. + let builder = LogicalPlanBuilder::from(input_plan); + + let builder = builder + .project(vec![ + datafusion::logical_expr::wildcard(), + datafusion::logical_expr::select_expr::SelectExpr::Expression(df_expr.alias(alias)), + ]) + .map_err(|e| self.plan_error("Failed to project unwind expression", e))?; + + let builder = builder + .unnest_column(alias) + .map_err(|e| self.plan_error("Failed to build unnest", e))?; + + builder + .build() + .map_err(|e| self.plan_error("Failed to build unwind plan", e)) + } } #[cfg(test)] diff --git a/crates/lance-graph/src/datafusion_planner/builder/helpers.rs b/crates/lance-graph/src/datafusion_planner/builder/helpers.rs index d3b60235..9d7db39a 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/helpers.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/helpers.rs @@ -85,6 +85,12 @@ impl DataFusionPlanner { Self::collect_variables(left, vars); Self::collect_variables(right, vars); } + LogicalOperator::Unwind { input, alias, .. } => { + if let Some(op) = input { + Self::collect_variables(op, vars); + } + vars.push(alias.clone()); + } } } } diff --git a/crates/lance-graph/src/datafusion_planner/builder/mod.rs b/crates/lance-graph/src/datafusion_planner/builder/mod.rs index f34b648b..aaf23872 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/mod.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/mod.rs @@ -96,6 +96,11 @@ impl DataFusionPlanner { right, join_type, } => self.build_join(ctx, left, right, join_type), + LogicalOperator::Unwind { + input, + expression, + alias, + } => self.build_unwind(ctx, input, expression, alias), } } } diff --git a/crates/lance-graph/src/logical_plan.rs b/crates/lance-graph/src/logical_plan.rs index 392feb64..9fae28c6 100644 --- a/crates/lance-graph/src/logical_plan.rs +++ b/crates/lance-graph/src/logical_plan.rs @@ -23,6 +23,16 @@ pub enum LogicalOperator { properties: HashMap, }, + /// Unwind a list into a sequence of rows + Unwind { + /// The input operator + input: Option>, + /// The expression to unwind (must yield a list) + expression: ValueExpression, + /// The alias for the unwound element + alias: String, + }, + /// Apply a filter predicate (WHERE clause) Filter { input: Box, @@ -153,8 +163,8 @@ impl LogicalPlanner { /// Convert a Cypher AST to a logical plan pub fn plan(&mut self, query: &CypherQuery) -> Result { - // Start with the MATCH clause(s) - let mut plan = self.plan_match_clauses(&query.match_clauses)?; + // Plan main MATCH clauses + let mut plan = self.plan_reading_clauses(None, &query.reading_clauses)?; // Apply WHERE clause if present (before WITH) if let Some(where_clause) = &query.where_clause { @@ -169,9 +179,9 @@ impl LogicalPlanner { plan = self.plan_with_clause(with_clause, plan)?; } - // Apply post-WITH MATCH clauses if present (query chaining) - for match_clause in &query.post_with_match_clauses { - plan = self.plan_match_clause_with_base(Some(plan), match_clause)?; + // Plan post-WITH MATCH clauses + if !query.post_with_reading_clauses.is_empty() { + plan = self.plan_reading_clauses(Some(plan), &query.post_with_reading_clauses)?; } // Apply post-WITH WHERE clause if present @@ -219,25 +229,63 @@ impl LogicalPlanner { Ok(plan) } - /// Plan MATCH clauses - the core graph pattern matching - fn plan_match_clauses(&mut self, match_clauses: &[MatchClause]) -> Result { - if match_clauses.is_empty() { + fn plan_reading_clauses( + &mut self, + base_plan: Option, + reading_clauses: &[ReadingClause], + ) -> Result { + let mut plan = base_plan; + + if reading_clauses.is_empty() && plan.is_none() { return Err(GraphError::PlanError { - message: "Query must have at least one MATCH clause".to_string(), + message: "Query must have at least one MATCH or UNWIND clause".to_string(), location: snafu::Location::new(file!(), line!(), column!()), }); } - let plan = match_clauses.iter().try_fold(None, |plan, clause| { - self.plan_match_clause_with_base(plan, clause).map(Some) - })?; + for clause in reading_clauses { + plan = Some(self.plan_reading_clause_with_base(plan, clause)?); + } plan.ok_or_else(|| GraphError::PlanError { - message: "Failed to plan MATCH clauses".to_string(), + message: "Failed to plan clauses".to_string(), location: snafu::Location::new(file!(), line!(), column!()), }) } + /// Plan a single READING clause, optionally starting from an existing base plan + fn plan_reading_clause_with_base( + &mut self, + base: Option, + clause: &ReadingClause, + ) -> Result { + match clause { + ReadingClause::Match(match_clause) => { + self.plan_match_clause_with_base(base, match_clause) + } + ReadingClause::Unwind(unwind_clause) => { + self.plan_unwind_clause_with_base(base, unwind_clause) + } + } + } + + /// Plan an UNWIND clause + fn plan_unwind_clause_with_base( + &mut self, + base: Option, + unwind_clause: &UnwindClause, + ) -> Result { + // Register the alias variable + self.variables + .insert(unwind_clause.alias.clone(), "Unwound".to_string()); + + Ok(LogicalOperator::Unwind { + input: base.map(Box::new), + expression: unwind_clause.expression.clone(), + alias: unwind_clause.alias.clone(), + }) + } + /// Plan a single MATCH clause, optionally starting from an existing base plan fn plan_match_clause_with_base( &mut self, @@ -398,6 +446,7 @@ impl LogicalPlanner { fn extract_variable_from_plan(&self, plan: &LogicalOperator) -> Result { match plan { LogicalOperator::ScanByLabel { variable, .. } => Ok(variable.clone()), + LogicalOperator::Unwind { alias, .. } => Ok(alias.clone()), LogicalOperator::Expand { target_variable, .. } => Ok(target_variable.clone()), diff --git a/crates/lance-graph/src/parser.rs b/crates/lance-graph/src/parser.rs index b458dd1b..79a89916 100644 --- a/crates/lance-graph/src/parser.rs +++ b/crates/lance-graph/src/parser.rs @@ -41,17 +41,17 @@ pub fn parse_cypher_query(input: &str) -> Result { // Top-level parser for a complete Cypher query fn cypher_query(input: &str) -> IResult<&str, CypherQuery> { let (input, _) = multispace0(input)?; - let (input, match_clauses) = many0(match_clause)(input)?; + let (input, reading_clauses) = many0(reading_clause)(input)?; let (input, pre_with_where) = opt(where_clause)(input)?; // Optional WITH clause with optional post-WITH MATCH and WHERE let (input, with_result) = opt(with_clause)(input)?; // Only try to parse post-WITH clauses if we have a WITH clause - let (input, post_with_matches, post_with_where) = match with_result { + let (input, post_with_reading_clauses, post_with_where) = match with_result { Some(_) => { - let (input, matches) = many0(match_clause)(input)?; + let (input, readings) = many0(reading_clause)(input)?; let (input, where_cl) = opt(where_clause)(input)?; - (input, matches, where_cl) + (input, readings, where_cl) } None => (input, vec![], None), }; @@ -64,10 +64,10 @@ fn cypher_query(input: &str) -> IResult<&str, CypherQuery> { Ok(( input, CypherQuery { - match_clauses, + reading_clauses, where_clause: pre_with_where, with_clause: with_result, - post_with_match_clauses: post_with_matches, + post_with_reading_clauses, post_with_where_clause: post_with_where, return_clause, limit, @@ -77,6 +77,14 @@ fn cypher_query(input: &str) -> IResult<&str, CypherQuery> { )) } +// Parse a reading clause (MATCH or UNWIND) +fn reading_clause(input: &str) -> IResult<&str, ReadingClause> { + alt(( + map(match_clause, ReadingClause::Match), + map(unwind_clause, ReadingClause::Unwind), + ))(input) +} + // Parse a MATCH clause fn match_clause(input: &str) -> IResult<&str, MatchClause> { let (input, _) = multispace0(input)?; @@ -87,6 +95,26 @@ fn match_clause(input: &str) -> IResult<&str, MatchClause> { Ok((input, MatchClause { patterns })) } +// Parse an UNWIND clause +fn unwind_clause(input: &str) -> IResult<&str, UnwindClause> { + let (input, _) = multispace0(input)?; + let (input, _) = tag_no_case("UNWIND")(input)?; + let (input, _) = multispace1(input)?; + let (input, expression) = value_expression(input)?; + let (input, _) = multispace1(input)?; + let (input, _) = tag_no_case("AS")(input)?; + let (input, _) = multispace1(input)?; + let (input, alias) = identifier(input)?; + + Ok(( + input, + UnwindClause { + expression, + alias: alias.to_string(), + }, + )) +} + // Parse a graph pattern (node or path) fn graph_pattern(input: &str) -> IResult<&str, GraphPattern> { alt(( @@ -967,7 +995,7 @@ mod tests { let query = "MATCH (n:Person) RETURN n.name"; let result = parse_cypher_query(query).unwrap(); - assert_eq!(result.match_clauses.len(), 1); + assert_eq!(result.reading_clauses.len(), 1); assert_eq!(result.return_clause.items.len(), 1); } @@ -976,11 +1004,15 @@ mod tests { let query = r#"MATCH (n:Person {name: "John", age: 30}) RETURN n"#; let result = parse_cypher_query(query).unwrap(); - if let GraphPattern::Node(node) = &result.match_clauses[0].patterns[0] { - assert_eq!(node.labels, vec!["Person"]); - assert_eq!(node.properties.len(), 2); + if let ReadingClause::Match(match_clause) = &result.reading_clauses[0] { + if let GraphPattern::Node(node) = &match_clause.patterns[0] { + assert_eq!(node.labels, vec!["Person"]); + assert_eq!(node.properties.len(), 2); + } else { + panic!("Expected node pattern"); + } } else { - panic!("Expected node pattern"); + panic!("Expected match clause"); } } @@ -989,14 +1021,18 @@ mod tests { let query = "MATCH (a:Person)-[r:KNOWS]->(b:Person) RETURN a.name, b.name"; let result = parse_cypher_query(query).unwrap(); - assert_eq!(result.match_clauses.len(), 1); + assert_eq!(result.reading_clauses.len(), 1); assert_eq!(result.return_clause.items.len(), 2); - if let GraphPattern::Path(path) = &result.match_clauses[0].patterns[0] { - assert_eq!(path.segments.len(), 1); - assert_eq!(path.segments[0].relationship.types, vec!["KNOWS"]); + if let ReadingClause::Match(match_clause) = &result.reading_clauses[0] { + if let GraphPattern::Path(path) = &match_clause.patterns[0] { + assert_eq!(path.segments.len(), 1); + assert_eq!(path.segments[0].relationship.types, vec!["KNOWS"]); + } else { + panic!("Expected path pattern"); + } } else { - panic!("Expected path pattern"); + panic!("Expected match clause"); } } @@ -1005,17 +1041,21 @@ mod tests { let query = "MATCH (a:Person)-[:FRIEND_OF*1..2]-(b:Person) RETURN a.name, b.name"; let result = parse_cypher_query(query).unwrap(); - assert_eq!(result.match_clauses.len(), 1); + assert_eq!(result.reading_clauses.len(), 1); - if let GraphPattern::Path(path) = &result.match_clauses[0].patterns[0] { - assert_eq!(path.segments.len(), 1); - assert_eq!(path.segments[0].relationship.types, vec!["FRIEND_OF"]); + if let ReadingClause::Match(match_clause) = &result.reading_clauses[0] { + if let GraphPattern::Path(path) = &match_clause.patterns[0] { + assert_eq!(path.segments.len(), 1); + assert_eq!(path.segments[0].relationship.types, vec!["FRIEND_OF"]); - let length = path.segments[0].relationship.length.as_ref().unwrap(); - assert_eq!(length.min, Some(1)); - assert_eq!(length.max, Some(2)); + let length = path.segments[0].relationship.length.as_ref().unwrap(); + assert_eq!(length.min, Some(1)); + assert_eq!(length.max, Some(2)); + } else { + panic!("Expected path pattern"); + } } else { - panic!("Expected path pattern"); + panic!("Expected match clause"); } } @@ -1757,4 +1797,19 @@ mod tests { _ => panic!("Expected comparison"), } } + + // UNWIND parser tests + #[test] + fn test_parse_unwind_simple() { + let query = "UNWIND [1, 2, 3] AS num RETURN num"; + let ast = parse_cypher_query(query); + assert!(ast.is_ok(), "Failed to parse simple UNWIND query"); + } + + #[test] + fn test_parse_unwind_after_match() { + let query = "MATCH (n) UNWIND n.list AS item RETURN item"; + let ast = parse_cypher_query(query); + assert!(ast.is_ok(), "Failed to parse UNWIND after MATCH"); + } } diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index 52d5840e..281a9fa0 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -4,6 +4,7 @@ //! High-level Cypher query interface for Lance datasets use crate::ast::CypherQuery as CypherAST; +use crate::ast::ReadingClause; use crate::config::GraphConfig; use crate::error::{GraphError, Result}; use crate::logical_plan::LogicalPlanner; @@ -1016,9 +1017,11 @@ impl CypherQuery { pub fn referenced_node_labels(&self) -> Vec { let mut labels = Vec::new(); - for match_clause in &self.ast.match_clauses { - for pattern in &match_clause.patterns { - self.collect_node_labels_from_pattern(pattern, &mut labels); + for clause in &self.ast.reading_clauses { + if let ReadingClause::Match(match_clause) = clause { + for pattern in &match_clause.patterns { + self.collect_node_labels_from_pattern(pattern, &mut labels); + } } } @@ -1031,9 +1034,11 @@ impl CypherQuery { pub fn referenced_relationship_types(&self) -> Vec { let mut types = Vec::new(); - for match_clause in &self.ast.match_clauses { - for pattern in &match_clause.patterns { - self.collect_relationship_types_from_pattern(pattern, &mut types); + for clause in &self.ast.reading_clauses { + if let ReadingClause::Match(match_clause) = clause { + for pattern in &match_clause.patterns { + self.collect_relationship_types_from_pattern(pattern, &mut types); + } } } @@ -1046,9 +1051,16 @@ impl CypherQuery { pub fn variables(&self) -> Vec { let mut variables = Vec::new(); - for match_clause in &self.ast.match_clauses { - for pattern in &match_clause.patterns { - self.collect_variables_from_pattern(pattern, &mut variables); + for clause in &self.ast.reading_clauses { + match clause { + ReadingClause::Match(match_clause) => { + for pattern in &match_clause.patterns { + self.collect_variables_from_pattern(pattern, &mut variables); + } + } + ReadingClause::Unwind(unwind_clause) => { + variables.push(unwind_clause.alias.clone()); + } } } @@ -1166,10 +1178,12 @@ impl CypherQuery { ctx: &datafusion::prelude::SessionContext, ) -> Result> { use crate::ast::GraphPattern; - let [mc] = self.ast.match_clauses.as_slice() else { - return Ok(None); + // Only support single MATCH clause for path execution + + let match_clause = match self.ast.reading_clauses.as_slice() { + [ReadingClause::Match(mc)] => mc, + _ => return Ok(None), }; - let match_clause = mc; let path = match match_clause.patterns.as_slice() { [GraphPattern::Path(p)] if !p.segments.is_empty() => p, _ => return Ok(None), @@ -1340,12 +1354,16 @@ impl CypherQueryBuilder { } let ast = crate::ast::CypherQuery { - match_clauses: self.match_clauses, + reading_clauses: self + .match_clauses + .into_iter() + .map(crate::ast::ReadingClause::Match) + .collect(), where_clause: self .where_expression .map(|expr| crate::ast::WhereClause { expression: expr }), with_clause: None, // WITH not supported via builder yet - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: crate::ast::ReturnClause { distinct: self.distinct, diff --git a/crates/lance-graph/src/semantic.rs b/crates/lance-graph/src/semantic.rs index dfe44537..926fc0c9 100644 --- a/crates/lance-graph/src/semantic.rs +++ b/crates/lance-graph/src/semantic.rs @@ -73,11 +73,20 @@ impl SemanticAnalyzer { let mut errors = Vec::new(); let mut warnings = Vec::new(); - // Phase 1: Variable discovery in MATCH clauses + // Phase 1: Variable discovery in READING clauses (MATCH/UNWIND) self.current_scope = ScopeType::Match; - for match_clause in &query.match_clauses { - if let Err(e) = self.analyze_match_clause(match_clause) { - errors.push(format!("MATCH clause error: {}", e)); + for clause in &query.reading_clauses { + match clause { + ReadingClause::Match(match_clause) => { + if let Err(e) = self.analyze_match_clause(match_clause) { + errors.push(format!("MATCH clause error: {}", e)); + } + } + ReadingClause::Unwind(unwind_clause) => { + if let Err(e) = self.analyze_unwind_clause(unwind_clause) { + errors.push(format!("UNWIND clause error: {}", e)); + } + } } } @@ -141,6 +150,35 @@ impl SemanticAnalyzer { Ok(()) } + /// Analyze UNWIND clause and register variables + fn analyze_unwind_clause(&mut self, unwind_clause: &UnwindClause) -> Result<()> { + self.analyze_value_expression(&unwind_clause.expression)?; + + // Register the aliased variable + let var_name = &unwind_clause.alias; + if let Some(existing) = self.variables.get_mut(var_name) { + // Shadowing or redefinition - in Cypher variables can be bound multiple times in some contexts + // But here we enforce uniqueness of types mostly. + // For now, treat UNWIND alias as a Property type variable. + if existing.variable_type != VariableType::Property { + return Err(GraphError::PlanError { + message: format!("Variable '{}' redefined with different type", var_name), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } else { + let var_info = VariableInfo { + name: var_name.clone(), + variable_type: VariableType::Property, + labels: vec![], + properties: HashSet::new(), + defined_in: self.current_scope.clone(), + }; + self.variables.insert(var_name.clone(), var_info); + } + Ok(()) + } + /// Analyze a graph pattern and register variables fn analyze_graph_pattern(&mut self, pattern: &GraphPattern) -> Result<()> { match pattern { @@ -589,10 +627,10 @@ mod tests { // Helper: analyze a query that only has a single RETURN expression fn analyze_return_expr(expr: ValueExpression) -> Result { let query = CypherQuery { - match_clauses: vec![], + reading_clauses: vec![], where_clause: None, with_clause: None, - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: ReturnClause { distinct: false, @@ -617,12 +655,12 @@ mod tests { ) -> Result { let node = NodePattern::new(Some(var.to_string())).with_label(label); let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Node(node)], - }], + })], where_clause: None, with_clause: None, - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: ReturnClause { distinct: false, @@ -650,12 +688,12 @@ mod tests { .with_property("dept", PropertyValue::String("X".to_string())); let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Node(node1), GraphPattern::Node(node2)], - }], + })], where_clause: None, with_clause: None, - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: ReturnClause { distinct: false, @@ -699,12 +737,12 @@ mod tests { }; let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Path(path)], - }], + })], where_clause: None, with_clause: None, - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: ReturnClause { distinct: false, @@ -731,12 +769,12 @@ mod tests { expression: BooleanExpression::Exists(PropertyRef::new("m", "name")), }; let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Node(node)], - }], + })], where_clause: Some(where_clause), with_clause: None, - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: ReturnClause { distinct: false, @@ -773,12 +811,12 @@ mod tests { }; let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Path(path)], - }], + })], where_clause: None, with_clause: None, - post_with_match_clauses: vec![], + post_with_reading_clauses: vec![], post_with_where_clause: None, return_clause: ReturnClause { distinct: false, @@ -802,13 +840,13 @@ mod tests { // MATCH (x:Unknown) let node = NodePattern::new(Some("x".to_string())).with_label("Unknown"); let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Node(node)], - }], + })], + post_with_reading_clauses: vec![], + post_with_where_clause: None, where_clause: None, with_clause: None, - post_with_match_clauses: vec![], - post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -842,13 +880,13 @@ mod tests { .with_label("Person") .with_property("age", PropertyValue::Integer(30)); let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Node(node)], - }], + })], + post_with_reading_clauses: vec![], + post_with_where_clause: None, where_clause: None, with_clause: None, - post_with_match_clauses: vec![], - post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -887,13 +925,13 @@ mod tests { }; let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Path(path)], - }], + })], + post_with_reading_clauses: vec![], + post_with_where_clause: None, where_clause: None, with_clause: None, - post_with_match_clauses: vec![], - post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], @@ -954,13 +992,13 @@ mod tests { .unwrap(); let query = CypherQuery { - match_clauses: vec![MatchClause { + reading_clauses: vec![ReadingClause::Match(MatchClause { patterns: vec![GraphPattern::Path(path)], - }], + })], + post_with_reading_clauses: vec![], + post_with_where_clause: None, where_clause: None, with_clause: None, - post_with_match_clauses: vec![], - post_with_where_clause: None, return_clause: ReturnClause { distinct: false, items: vec![], diff --git a/crates/lance-graph/tests/test_datafusion_pipeline.rs b/crates/lance-graph/tests/test_datafusion_pipeline.rs index 7f7d4ec6..b9725752 100644 --- a/crates/lance-graph/tests/test_datafusion_pipeline.rs +++ b/crates/lance-graph/tests/test_datafusion_pipeline.rs @@ -4506,3 +4506,186 @@ async fn test_with_post_match_chaining() { assert!(result.column_by_name("city").is_some()); assert!(result.column_by_name("cnt").is_some()); } + +// ============================================================================ +// UNWIND Tests +// ============================================================================ + +#[tokio::test] +async fn test_unwind_simple_list() { + let config = create_graph_config(); + // We need at least one dataset to initialize the catalog/context even if not used in query + let person_batch = create_person_dataset(); + + let query = CypherQuery::new("UNWIND [1, 2, 3] AS x RETURN x") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + assert_eq!(result.num_rows(), 3); + assert_eq!(result.num_columns(), 1); + + // Check the column type - UNWIND returns different types based on implementation + let col = result.column(0); + let data_type = col.data_type(); + + // Try to get values as Int64, Float32, or Float64 + if let Some(int_values) = col.as_any().downcast_ref::() { + let result_values: Vec = (0..result.num_rows()) + .map(|i| int_values.value(i)) + .collect(); + assert_eq!(result_values, vec![1, 2, 3]); + } else if let Some(float_values) = col.as_any().downcast_ref::() { + let result_values: Vec = (0..result.num_rows()) + .map(|i| float_values.value(i)) + .collect(); + assert_eq!(result_values, vec![1.0, 2.0, 3.0]); + } else if let Some(float_values) = col.as_any().downcast_ref::() { + let result_values: Vec = (0..result.num_rows()) + .map(|i| float_values.value(i)) + .collect(); + assert_eq!(result_values, vec![1.0, 2.0, 3.0]); + } else { + panic!("Unexpected column type: {:?}", data_type); + } +} + +#[tokio::test] +async fn test_unwind_after_match() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + // Alice, Bob + // UNWIND [10, 20] -> Cartesian product: (Alice, 10), (Alice, 20), (Bob, 10), (Bob, 20) + let query = CypherQuery::new("MATCH (p:Person) UNWIND [10, 20] AS x RETURN p.name, x") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + assert_eq!(result.num_rows(), 10); // 5 people * 2 values = 10 rows + assert_eq!(result.num_columns(), 2); + + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Try Int64, Float32, or Float64 for the unwound values + let mut rows: Vec<(String, i32)> = if let Some(int_values) = + result.column(1).as_any().downcast_ref::() + { + (0..result.num_rows()) + .map(|i| (names.value(i).to_string(), int_values.value(i) as i32)) + .collect() + } else if let Some(float_values) = result + .column(1) + .as_any() + .downcast_ref::() + { + (0..result.num_rows()) + .map(|i| (names.value(i).to_string(), float_values.value(i) as i32)) + .collect() + } else if let Some(float_values) = result.column(1).as_any().downcast_ref::() { + (0..result.num_rows()) + .map(|i| (names.value(i).to_string(), float_values.value(i) as i32)) + .collect() + } else { + panic!( + "Unexpected column type for unwound values: {:?}", + result.column(1).data_type() + ); + }; + + rows.sort(); + + let expected = vec![ + ("Alice".to_string(), 10), + ("Alice".to_string(), 20), + ("Bob".to_string(), 10), + ("Bob".to_string(), 20), + ("Charlie".to_string(), 10), + ("Charlie".to_string(), 20), + ("David".to_string(), 10), + ("David".to_string(), 20), + ("Eve".to_string(), 10), + ("Eve".to_string(), 20), + ]; + + assert_eq!(rows, expected); +} + +#[tokio::test] +async fn test_unwind_then_match() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + // UNWIND [1, 2] usually yields Float64Array in current parser/planner pipeline + // p.id is Int64 + // We expect DataFusion to handle comparison (maybe with cast) + let query = CypherQuery::new("UNWIND [1, 2] AS target_id MATCH (p:Person) WHERE p.id = target_id RETURN p.name, target_id") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // We expect 2 matches: Alice (id=1) and Bob (id=2) + assert_eq!(result.num_rows(), 2); + + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // target_id from Unwind might be Int64, Float32, or Float64 + let mut rows: Vec<(String, i32)> = + if let Some(int_ids) = result.column(1).as_any().downcast_ref::() { + (0..result.num_rows()) + .map(|i| (names.value(i).to_string(), int_ids.value(i) as i32)) + .collect() + } else if let Some(float_ids) = result + .column(1) + .as_any() + .downcast_ref::() + { + (0..result.num_rows()) + .map(|i| (names.value(i).to_string(), float_ids.value(i) as i32)) + .collect() + } else if let Some(float_ids) = result.column(1).as_any().downcast_ref::() { + (0..result.num_rows()) + .map(|i| (names.value(i).to_string(), float_ids.value(i) as i32)) + .collect() + } else { + panic!( + "Unexpected column type for target_id: {:?}", + result.column(1).data_type() + ); + }; + + rows.sort(); + + let expected = vec![("Alice".to_string(), 1), ("Bob".to_string(), 2)]; + + assert_eq!(rows, expected); +}