From 2d408be03791a154e059ec6b9f71f3400c99bd3e Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Tue, 11 Nov 2025 19:27:55 -0800 Subject: [PATCH 1/2] feat: validation of properties for aggregation --- rust/lance-graph/src/semantic.rs | 144 +++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/rust/lance-graph/src/semantic.rs b/rust/lance-graph/src/semantic.rs index b613fa69..80c592fd 100644 --- a/rust/lance-graph/src/semantic.rs +++ b/rust/lance-graph/src/semantic.rs @@ -288,6 +288,20 @@ impl SemanticAnalyzer { location: snafu::Location::new(file!(), line!(), column!()), }); } + + // Additional validation for SUM, AVG, MIN, MAX: they require properties, not bare variables + // Only COUNT allows bare variables (COUNT(*) or COUNT(p)) + if matches!(name.to_lowercase().as_str(), "sum" | "avg" | "min" | "max") { + if let Some(ValueExpression::Variable(v)) = args.first() { + return Err(GraphError::PlanError { + message: format!( + "{}({}) is invalid - {} requires a property like {}({}.property). 
You cannot {} a node/entity.", + name.to_uppercase(), v, name.to_uppercase(), name.to_uppercase(), v, name.to_lowercase() + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } } _ => { // Other functions - no validation yet @@ -957,6 +971,136 @@ mod tests { ); } + #[test] + fn test_sum_with_variable_fails_validation() { + let expr = ValueExpression::Function { + name: "sum".to_string(), + args: vec![ValueExpression::Variable("n".to_string())], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + !result.errors.is_empty(), + "Expected SUM(variable) to produce validation errors" + ); + let has_sum_error = result + .errors + .iter() + .any(|e| e.contains("SUM(n) is invalid") && e.contains("requires a property")); + assert!( + has_sum_error, + "Expected error about SUM requiring property, got: {:?}", + result.errors + ); + } + + #[test] + fn test_avg_with_variable_fails_validation() { + let expr = ValueExpression::Function { + name: "avg".to_string(), + args: vec![ValueExpression::Variable("n".to_string())], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + !result.errors.is_empty(), + "Expected AVG(variable) to produce validation errors" + ); + let has_avg_error = result + .errors + .iter() + .any(|e| e.contains("AVG(n) is invalid") && e.contains("requires a property")); + assert!( + has_avg_error, + "Expected error about AVG requiring property, got: {:?}", + result.errors + ); + } + + #[test] + fn test_sum_with_property_passes_validation() { + let expr = ValueExpression::Function { + name: "sum".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result.errors.is_empty(), + "SUM with property should pass validation, got errors: {:?}", + result.errors + ); + } + + #[test] + fn test_min_with_variable_fails_validation() { + let expr = 
ValueExpression::Function { + name: "min".to_string(), + args: vec![ValueExpression::Variable("n".to_string())], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + !result.errors.is_empty(), + "Expected MIN(variable) to produce validation errors" + ); + let has_min_error = result + .errors + .iter() + .any(|e| e.contains("MIN(n) is invalid") && e.contains("requires a property")); + assert!( + has_min_error, + "Expected error about MIN requiring property, got: {:?}", + result.errors + ); + } + + #[test] + fn test_max_with_variable_fails_validation() { + let expr = ValueExpression::Function { + name: "max".to_string(), + args: vec![ValueExpression::Variable("n".to_string())], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + !result.errors.is_empty(), + "Expected MAX(variable) to produce validation errors" + ); + let has_max_error = result + .errors + .iter() + .any(|e| e.contains("MAX(n) is invalid") && e.contains("requires a property")); + assert!( + has_max_error, + "Expected error about MAX requiring property, got: {:?}", + result.errors + ); + } + + #[test] + fn test_min_with_property_passes_validation() { + let expr = ValueExpression::Function { + name: "min".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result.errors.is_empty(), + "MIN with property should pass validation, got errors: {:?}", + result.errors + ); + } + + #[test] + fn test_max_with_property_passes_validation() { + let expr = ValueExpression::Function { + name: "max".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result.errors.is_empty(), + "MAX with property should pass validation, got errors: {:?}", + result.errors + ); + } + #[test] fn 
test_arithmetic_with_non_numeric_literal_error() { // RETURN "x" + 1 From 821f98eeca4c6391084b94a8d67cc5b245e829ac Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Tue, 11 Nov 2025 19:28:29 -0800 Subject: [PATCH 2/2] feat: support AVG aggregation of the datafusion planner --- .../src/datafusion_planner/expression.rs | 39 +++- .../tests/test_datafusion_pipeline.rs | 167 +++++++++++++++++- 2 files changed, 202 insertions(+), 4 deletions(-) diff --git a/rust/lance-graph/src/datafusion_planner/expression.rs b/rust/lance-graph/src/datafusion_planner/expression.rs index 5672136e..21b28208 100644 --- a/rust/lance-graph/src/datafusion_planner/expression.rs +++ b/rust/lance-graph/src/datafusion_planner/expression.rs @@ -7,6 +7,7 @@ use crate::ast::{BooleanExpression, PropertyValue, ValueExpression}; use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator}; +use datafusion_functions_aggregate::average::avg; use datafusion_functions_aggregate::count::count; use datafusion_functions_aggregate::sum::sum; @@ -87,14 +88,18 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { match name.to_lowercase().as_str() { "count" => { if args.len() == 1 { - // Check for COUNT(*) let arg_expr = if let VE::Variable(v) = &args[0] { if v == "*" { + // COUNT(*) - count all rows including NULLs lit(1) } else { - to_df_value_expr(&args[0]) + // COUNT(p) - count non-NULL rows by using a representative column + // Use __id as a null-sensitive column + // This ensures optional matches with NULL variables are not counted + col(format!("{}__id", v)) } } else { + // COUNT(p.property) - count non-null values of that property to_df_value_expr(&args[0]) }; @@ -107,14 +112,23 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "sum" => { if args.len() == 1 { + // Note: SUM(variable) is rejected by semantic validation + // So we only handle valid cases here let arg_expr = to_df_value_expr(&args[0]); - // Use DataFusion's sum helper function 
sum(arg_expr) } else { // Invalid argument count - return placeholder lit(0) } } + "avg" => { + if args.len() == 1 { + let arg_expr = to_df_value_expr(&args[0]); + avg(arg_expr) + } else { + lit(0) + } + } _ => { // Unsupported function - return placeholder for now lit(0) @@ -520,6 +534,25 @@ mod tests { assert!(s.contains("p__amount"), "Should contain column reference"); } + #[test] + fn test_value_expr_function_avg() { + let expr = ValueExpression::Function { + name: "AVG".into(), + args: vec![ValueExpression::Property(PropertyRef { + variable: "p".into(), + property: "amount".into(), + })], + }; + + let df_expr = to_df_value_expr(&expr); + let s = format!("{:?}", df_expr); + assert!( + s.contains("avg") || s.contains("Avg"), + "Should be AVG function" + ); + assert!(s.contains("p__amount"), "Should contain column reference"); + } + // ======================================================================== // Unit tests for contains_aggregate() // ======================================================================== diff --git a/rust/lance-graph/tests/test_datafusion_pipeline.rs b/rust/lance-graph/tests/test_datafusion_pipeline.rs index c0cba570..177afa66 100644 --- a/rust/lance-graph/tests/test_datafusion_pipeline.rs +++ b/rust/lance-graph/tests/test_datafusion_pipeline.rs @@ -1,4 +1,4 @@ -use arrow_array::{Array, Int64Array, RecordBatch, StringArray}; +use arrow_array::{Array, Float64Array, Int64Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use lance_graph::config::GraphConfig; use lance_graph::query::CypherQuery; @@ -2446,6 +2446,34 @@ async fn test_count_star_all_nodes() { assert_eq!(count_col.value(0), 5); } +#[tokio::test] +async fn test_count_variable() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN count(p) AS total") + .unwrap() + .with_config(config); + + let mut 
datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + let count_col = result + .column_by_name("total") + .unwrap() + .as_any() + .downcast_ref::<Int64Array>() + .unwrap(); + // count(p) counts rows where p is non-NULL; all 5 Person rows match, so it equals count(*) + assert_eq!(count_col.value(0), 5); +} + #[tokio::test] async fn test_count_with_filter() { let person_batch = create_person_dataset(); @@ -2747,6 +2775,143 @@ async fn test_sum_without_alias_has_descriptive_name() { ); } +#[tokio::test] +async fn test_avg_property() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN avg(p.age) AS average_age") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + let avg_col = result + .column_by_name("average_age") + .unwrap() + .as_any() + .downcast_ref::<Float64Array>() + .unwrap(); + // Average of ages: (25 + 35 + 30 + 40 + 28) / 5 = 158 / 5 = 31.6 + assert_eq!(avg_col.value(0), 31.6); +} + +#[tokio::test] +async fn test_avg_with_filter() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = + CypherQuery::new("MATCH (p:Person) WHERE p.age >= 30 RETURN avg(p.age) AS average_age") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 1); + let avg_col = result + .column_by_name("average_age") + .unwrap() + .as_any() + .downcast_ref::<Float64Array>() + .unwrap(); + // Average of ages >= 30: (35 + 30 
+ 40) / 3 = 105 / 3 = 35.0 + assert_eq!(avg_col.value(0), 35.0); +} + +#[tokio::test] +async fn test_avg_with_grouping() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new( + "MATCH (p:Person) RETURN p.city, avg(p.age) AS average_age ORDER BY p.city", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + // Should have 5 groups: NULL, Chicago, New York, San Francisco, Seattle + assert_eq!(result.num_rows(), 5); + + let city_col = result + .column_by_name("p.city") + .unwrap() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + + let avg_col = result + .column_by_name("average_age") + .unwrap() + .as_any() + .downcast_ref::<Float64Array>() + .unwrap(); + + // Verify grouping results (ordered by city, NULL comes first) + assert!(city_col.is_null(0)); // David: 40 (NULL city) + assert_eq!(avg_col.value(0), 40.0); + + assert_eq!(city_col.value(1), "Chicago"); // Charlie: 30 + assert_eq!(avg_col.value(1), 30.0); + + assert_eq!(city_col.value(2), "New York"); // Alice: 25 + assert_eq!(avg_col.value(2), 25.0); + + assert_eq!(city_col.value(3), "San Francisco"); // Bob: 35 + assert_eq!(avg_col.value(3), 35.0); + + assert_eq!(city_col.value(4), "Seattle"); // Eve: 28 + assert_eq!(avg_col.value(4), 28.0); +} + +#[tokio::test] +async fn test_avg_without_alias_has_descriptive_name() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN avg(p.age)") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query.execute_datafusion(datasets).await.unwrap(); + + assert_eq!(result.num_rows(), 
1); + // Should have column named "avg(p.age)" not "expr" + let avg_col = result.column_by_name("avg(p.age)"); + assert!( + avg_col.is_some(), + "Expected column named 'avg(p.age)' but schema is: {:?}", + result.schema() + ); +} + // ============================================================================ // Disconnected Pattern (Join) Tests // ============================================================================