diff --git a/crates/lance-graph/src/ast.rs b/crates/lance-graph/src/ast.rs index 06b27f14..86e72dee 100644 --- a/crates/lance-graph/src/ast.rs +++ b/crates/lance-graph/src/ast.rs @@ -316,11 +316,20 @@ pub enum ValueExpression { Property(PropertyRef), /// Literal value Literal(PropertyValue), - /// Function call - Function { + /// Scalar function call (toLower, upper, etc.) + /// These are row-level functions that operate on individual values + ScalarFunction { name: String, args: Vec, }, + /// Aggregate function call (COUNT, SUM, AVG, MIN, MAX, COLLECT) + /// These functions operate across multiple rows and support DISTINCT + AggregateFunction { + name: String, + args: Vec, + /// Whether DISTINCT keyword was specified (e.g., COUNT(DISTINCT x)) + distinct: bool, + }, /// Arithmetic operation Arithmetic { left: Box, @@ -348,6 +357,27 @@ pub enum ValueExpression { VectorLiteral(Vec), } +/// Function type classification +#[derive(Debug, Clone, PartialEq)] +pub enum FunctionType { + /// Aggregate function (operates across multiple rows) + Aggregate, + /// Scalar function (operates on individual values) + Scalar, + /// Unknown function type + Unknown, +} + +/// Classify a function by name +pub fn classify_function(name: &str) -> FunctionType { + match name.to_lowercase().as_str() { + "count" | "sum" | "avg" | "min" | "max" | "collect" => FunctionType::Aggregate, + "tolower" | "lower" | "toupper" | "upper" => FunctionType::Scalar, + // Vector functions are handled separately as special variants + _ => FunctionType::Unknown, + } +} + /// Arithmetic operators #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ArithmeticOperator { diff --git a/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs b/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs index 224a2459..6feb2c2e 100644 --- a/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs +++ b/crates/lance-graph/src/datafusion_planner/builder/aggregate_ops.rs @@ -88,9 +88,10 @@ mod tests { let project = LogicalOperator::Project { input: Box::new(scan), projections: vec![ProjectionItem { - expression: ValueExpression::Function { + expression: ValueExpression::AggregateFunction { name: "count".to_string(), args: vec![ValueExpression::Variable("*".to_string())], + distinct: false, }, alias: Some("total".to_string()), }], @@ -114,9 +115,10 @@ mod tests { let project = LogicalOperator::Project { input: Box::new(scan), projections: vec![ProjectionItem { - expression: ValueExpression::Function { + expression: ValueExpression::AggregateFunction { name: "count".to_string(), args: vec![ValueExpression::Variable("*".to_string())], + distinct: false, }, alias: None, }], diff --git a/crates/lance-graph/src/datafusion_planner/expression.rs b/crates/lance-graph/src/datafusion_planner/expression.rs index 9372afed..2ba65e37 100644 --- a/crates/lance-graph/src/datafusion_planner/expression.rs +++ b/crates/lance-graph/src/datafusion_planner/expression.rs @@ -13,6 +13,7 @@ use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator}; use datafusion_functions_aggregate::array_agg::array_agg; use datafusion_functions_aggregate::average::avg; use datafusion_functions_aggregate::count::count; +use datafusion_functions_aggregate::count::count_distinct; use datafusion_functions_aggregate::min_max::max; use datafusion_functions_aggregate::min_max::min; use datafusion_functions_aggregate::sum::sum; @@ -127,8 +128,37 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { let qualified_name = format!("{}__{}", prop.variable, prop.property); col(&qualified_name) } - VE::Function { name, args } => { - // Handle aggregation functions + VE::ScalarFunction { name, args } => { + match name.to_lowercase().as_str() { + "tolower" | "lower" => { + if args.len() == 1 { + let arg_expr = to_df_value_expr(&args[0]); + lower().call(vec![arg_expr]) + } else { + // Invalid argument count - return NULL + Expr::Literal(datafusion::scalar::ScalarValue::Null, None) + } + } + "toupper" | "upper" => { + if args.len() == 1 { + let arg_expr = to_df_value_expr(&args[0]); + upper().call(vec![arg_expr]) + } else { + // Invalid argument count - return NULL + Expr::Literal(datafusion::scalar::ScalarValue::Null, None) + } + } + _ => { + // Unknown scalar function - return NULL + Expr::Literal(datafusion::scalar::ScalarValue::Null, None) + } + } + } + VE::AggregateFunction { + name, + args, + distinct, + } => { match name.to_lowercase().as_str() { "count" => { if args.len() == 1 { @@ -140,6 +170,14 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { // COUNT(p) - count non-NULL rows by using a representative column // Use __id as a null-sensitive column // This ensures optional matches with NULL variables are not counted + // + // LIMITATION: This assumes all bindings expose a __id column. + // This is true for ScanByLabel and Expand, but NOT for: + // - UNWIND-created variables (e.g., UNWIND [1,2,3] AS x) + // - WITH-projected aliases (e.g., WITH 1 AS x) + // Those cases will produce a runtime "column not found" error. + // TODO: Consider using COUNT(1) for non-node variables, or add + // semantic validation to reject COUNT(variable) for non-node types. col(format!("{}__id", v)) } } else { @@ -147,8 +185,12 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { to_df_value_expr(&args[0]) }; - // Use DataFusion's count helper function - count(arg_expr) + // Use DataFusion's count or count_distinct + if *distinct { + count_distinct(arg_expr) + } else { + count(arg_expr) + } } else { // Invalid argument count - return placeholder lit(0) @@ -156,12 +198,9 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } "sum" => { if args.len() == 1 { - // Note: SUM(variable) is rejected by semantic validation - // So we only handle valid cases here let arg_expr = to_df_value_expr(&args[0]); sum(arg_expr) } else { - // Invalid argument count - return placeholder lit(0) } } @@ -189,7 +228,6 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { lit(0) } } - // COLLECT aggregation - collects values into an array "collect" => { if args.len() == 1 { let arg_expr = to_df_value_expr(&args[0]); @@ -198,27 +236,8 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { lit(0) } } - // String functions - "tolower" | "lower" => { - if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); - lower().call(vec![arg_expr]) - } else { - // Invalid argument count - return NULL - Expr::Literal(datafusion::scalar::ScalarValue::Null, None) - } - } - "toupper" | "upper" => { - if args.len() == 1 { - let arg_expr = to_df_value_expr(&args[0]); - upper().call(vec![arg_expr]) - } else { - // Invalid argument count - return NULL - Expr::Literal(datafusion::scalar::ScalarValue::Null, None) - } - } _ => { - // Unsupported function - return NULL which coerces to any type + // Unsupported aggregate function - return NULL which coerces to any type // This prevents type coercion errors in both string and numeric contexts // // TODO(#107): Now that semantic analysis rejects unknown functions, consider @@ -317,15 +336,8 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { pub(crate) fn contains_aggregate(expr: &ValueExpression) -> bool { use crate::ast::ValueExpression as VE; match expr { - VE::Function { name, args } => { - // Check if this is an aggregate function - let is_aggregate = matches!( - name.to_lowercase().as_str(), - "count" | "sum" | "avg" | "min" | "max" | "collect" - ); - // Also check arguments recursively - is_aggregate || args.iter().any(contains_aggregate) - } + VE::AggregateFunction { .. } => true, + VE::ScalarFunction { args, .. } => args.iter().any(contains_aggregate), VE::Arithmetic { left, right, .. } => contains_aggregate(left) || contains_aggregate(right), VE::VectorDistance { left, right, .. } => { contains_aggregate(left) || contains_aggregate(right) @@ -357,15 +369,20 @@ pub(crate) fn to_cypher_column_name(expr: &ValueExpression) -> String { // Handle nested property references format!("{}.{}", prop.variable, prop.property) } - VE::Function { name, args } => { - // Generate descriptive function name: count(*), count(p.name), etc. + VE::ScalarFunction { name, args } | VE::AggregateFunction { name, args, .. } => { + let distinct_str = if let VE::AggregateFunction { distinct: true, .. } = expr { + "DISTINCT " + } else { + "" + }; + if args.len() == 1 { let arg_repr = match &args[0] { VE::Variable(v) => v.clone(), VE::Property(prop) => format!("{}.{}", prop.variable, prop.property), _ => "expr".to_string(), }; - format!("{}({})", name.to_lowercase(), arg_repr) + format!("{}({}{})", name.to_lowercase(), distinct_str, arg_repr) } else if args.is_empty() { format!("{}()", name.to_lowercase()) } else { @@ -960,9 +977,10 @@ mod tests { #[test] fn test_value_expr_function_count_star() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Literal(PropertyValue::String("*".into()))], + distinct: false, }; let df_expr = to_df_value_expr(&expr); @@ -975,12 +993,13 @@ mod tests { #[test] fn test_value_expr_function_count_property() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "id".into(), })], + distinct: false, }; let df_expr = to_df_value_expr(&expr); @@ -994,12 +1013,13 @@ mod tests { #[test] fn test_value_expr_function_sum() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "SUM".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "amount".into(), })], + distinct: false, }; let df_expr = to_df_value_expr(&expr); @@ -1013,12 +1033,13 @@ mod tests { #[test] fn test_value_expr_function_avg() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "AVG".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "amount".into(), })], + distinct: false, }; let df_expr = to_df_value_expr(&expr); @@ -1032,12 +1053,13 @@ mod tests { #[test] fn test_value_expr_function_min() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "MIN".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "amount".into(), })], + distinct: false, }; let df_expr = to_df_value_expr(&expr); @@ -1051,12 +1073,13 @@ mod tests { #[test] fn test_value_expr_function_max() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "MAX".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "amount".into(), })], + distinct: false, }; let df_expr = to_df_value_expr(&expr); @@ -1070,7 +1093,7 @@ mod tests { #[test] fn test_value_expr_function_tolower() { - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "toLower".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), @@ -1091,7 +1114,7 @@ mod tests { #[test] fn test_value_expr_function_toupper() { - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "toUpper".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), @@ -1113,7 +1136,7 @@ mod tests { #[test] fn test_value_expr_function_lower_alias() { // Test that 'lower' also works (SQL-style alias) - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "lower".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), @@ -1133,7 +1156,7 @@ mod tests { #[test] fn test_value_expr_function_upper_alias() { // Test that 'upper' also works (SQL-style alias) - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "upper".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), @@ -1154,7 +1177,7 @@ mod tests { fn test_tolower_with_contains_produces_valid_like() { // This is the bug scenario: toLower(s.name) CONTAINS 'offer' // Previously returned lit(0) which caused type coercion error - let tolower_expr = ValueExpression::Function { + let tolower_expr = ValueExpression::ScalarFunction { name: "toLower".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "s".into(), @@ -1194,9 +1217,10 @@ mod tests { #[test] fn test_contains_aggregate_count() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Literal(PropertyValue::String("*".into()))], + distinct: false, }; assert!( @@ -1207,12 +1231,13 @@ mod tests { #[test] fn test_contains_aggregate_sum() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "SUM".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "value".into(), })], + distinct: false, }; assert!( @@ -1223,12 +1248,13 @@ mod tests { #[test] fn test_contains_aggregate_min() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "MIN".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "value".into(), })], + distinct: false, }; assert!( @@ -1239,12 +1265,13 @@ mod tests { #[test] fn test_contains_aggregate_max() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "MAX".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "value".into(), })], + distinct: false, }; assert!( @@ -1279,9 +1306,10 @@ mod tests { fn test_contains_aggregate_arithmetic_with_aggregate() { use crate::ast::ArithmeticOperator; let expr = ValueExpression::Arithmetic { - left: Box::new(ValueExpression::Function { + left: Box::new(ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Literal(PropertyValue::String("*".into()))], + distinct: false, }), operator: ArithmeticOperator::Multiply, right: Box::new(ValueExpression::Literal(PropertyValue::Integer(2))), @@ -1313,11 +1341,12 @@ mod tests { #[test] fn test_contains_aggregate_nested_function() { - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "UPPER".into(), - args: vec![ValueExpression::Function { + args: vec![ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Literal(PropertyValue::String("*".into()))], + distinct: false, }], }; @@ -1344,9 +1373,10 @@ mod tests { #[test] fn test_cypher_column_name_function_count_star() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Literal(PropertyValue::String("*".into()))], + distinct: false, }; let name = to_cypher_column_name(&expr); @@ -1360,12 +1390,13 @@ mod tests { #[test] fn test_cypher_column_name_function_count_property() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "COUNT".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "p".into(), property: "id".into(), })], + distinct: false, }; let name = to_cypher_column_name(&expr); @@ -1374,12 +1405,13 @@ mod tests { #[test] fn test_cypher_column_name_function_sum() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "SUM".into(), args: vec![ValueExpression::Property(PropertyRef { variable: "order".into(), property: "amount".into(), })], + distinct: false, }; let name = to_cypher_column_name(&expr); diff --git a/crates/lance-graph/src/parser.rs b/crates/lance-graph/src/parser.rs index 79a89916..8887ce23 100644 --- a/crates/lance-graph/src/parser.rs +++ b/crates/lance-graph/src/parser.rs @@ -601,6 +601,15 @@ fn function_call(input: &str) -> IResult<&str, ValueExpression> { let (input, _) = char('(')(input)?; let (input, _) = multispace0(input)?; + // Parse optional DISTINCT keyword + let (input, distinct) = opt(tag_no_case("DISTINCT"))(input)?; + let distinct = distinct.is_some(); + let (input, _) = if distinct { + multispace1(input)? + } else { + (input, "") + }; + // Handle COUNT(*) special case - only allow * for COUNT function if let Ok((input_after_star, _)) = char::<_, nom::error::Error<&str>>('*')(input) { // Validate that this is COUNT function @@ -609,9 +618,10 @@ fn function_call(input: &str) -> IResult<&str, ValueExpression> { let (input, _) = char(')')(input)?; return Ok(( input, - ValueExpression::Function { + ValueExpression::AggregateFunction { name: name.to_string(), args: vec![ValueExpression::Variable("*".to_string())], + distinct, }, )); } else { @@ -628,13 +638,51 @@ fn function_call(input: &str) -> IResult<&str, ValueExpression> { let (input, _) = multispace0(input)?; let (input, _) = char(')')(input)?; - Ok(( - input, - ValueExpression::Function { - name: name.to_string(), - args, - }, - )) + // Route based on function type + use crate::ast::{classify_function, FunctionType}; + match classify_function(name) { + FunctionType::Aggregate => Ok(( + input, + ValueExpression::AggregateFunction { + name: name.to_string(), + args, + distinct, + }, + )), + FunctionType::Scalar => { + // Validate: reject DISTINCT on scalar functions at parse time + if distinct { + return Err(nom::Err::Failure(nom::error::Error::new( + input, + nom::error::ErrorKind::Verify, + ))); + } + Ok(( + input, + ValueExpression::ScalarFunction { + name: name.to_string(), + args, + }, + )) + } + FunctionType::Unknown => { + // Default to ScalarFunction for unknown functions + // They'll be handled as NULL in expression conversion + if distinct { + return Err(nom::Err::Failure(nom::error::Error::new( + input, + nom::error::ErrorKind::Verify, + ))); + } + Ok(( + input, + ValueExpression::ScalarFunction { + name: name.to_string(), + args, + }, + )) + } + } } fn value_expression_list(input: &str) -> IResult<&str, Vec> { @@ -1241,7 +1289,7 @@ mod tests { assert_eq!(item.alias, Some("total".to_string())); match &item.expression { - ValueExpression::Function { name, args } => { + ValueExpression::AggregateFunction { name, args, .. } => { assert_eq!(name, "count"); assert_eq!(args.len(), 1); match &args[0] { @@ -1249,7 +1297,7 @@ mod tests { _ => panic!("Expected Variable(*) in count(*)"), } } - _ => panic!("Expected Function expression"), + _ => panic!("Expected AggregateFunction expression"), } } @@ -1262,7 +1310,7 @@ mod tests { let item = &result.return_clause.items[0]; match &item.expression { - ValueExpression::Function { name, args } => { + ValueExpression::AggregateFunction { name, args, .. } => { assert_eq!(name, "count"); assert_eq!(args.len(), 1); match &args[0] { @@ -1273,7 +1321,7 @@ mod tests { _ => panic!("Expected Property in count(n.age)"), } } - _ => panic!("Expected Function expression"), + _ => panic!("Expected AggregateFunction expression"), } } @@ -1299,14 +1347,32 @@ mod tests { // Verify the AST structure let ast = result.unwrap(); match &ast.return_clause.items[0].expression { - ValueExpression::Function { name, args } => { + ValueExpression::AggregateFunction { name, args, .. } => { assert_eq!(name, "count"); assert_eq!(args.len(), 2); } - _ => panic!("Expected Function expression"), + _ => panic!("Expected AggregateFunction expression"), } } + #[test] + fn test_parser_rejects_distinct_on_scalar() { + // Parser should reject DISTINCT on scalar functions at parse time + let query = "RETURN toLower(DISTINCT p.name)"; + let result = parse_cypher_query(query); + assert!( + result.is_err(), + "Parser should reject DISTINCT on scalar functions" + ); + + let query2 = "RETURN upper(DISTINCT p.name)"; + let result2 = parse_cypher_query(query2); + assert!( + result2.is_err(), + "Parser should reject DISTINCT on scalar functions" + ); + } + #[test] fn test_parse_like_pattern() { let query = "MATCH (n:Person) WHERE n.name LIKE 'A%' RETURN n.name"; diff --git a/crates/lance-graph/src/semantic.rs b/crates/lance-graph/src/semantic.rs index edea213f..70ac9dc4 100644 --- a/crates/lance-graph/src/semantic.rs +++ b/crates/lance-graph/src/semantic.rs @@ -357,13 +357,76 @@ impl SemanticAnalyzer { }); } } - ValueExpression::Function { name, args } => { + ValueExpression::ScalarFunction { name, args } => { let function_name = name.to_lowercase(); + // Validate arity and known functions + match function_name.as_str() { + "tolower" | "lower" | "toupper" | "upper" => { + if args.len() != 1 { + return Err(GraphError::PlanError { + message: format!( + "{} requires exactly 1 argument, got {}", + name.to_uppercase(), + args.len() + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } + _ => { + // Unknown scalar function - reject early with helpful error + return Err(GraphError::UnsupportedFeature { + feature: format!( + "Cypher function '{}' is not implemented. Supported scalar functions: toLower, lower, toUpper, upper. Supported aggregate functions: COUNT, SUM, AVG, MIN, MAX, COLLECT.", + name + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } - // Validate function-specific arity and signature rules + // Validate arguments recursively + for arg in args { + self.analyze_value_expression(arg)?; + } + } + ValueExpression::AggregateFunction { + name, + args, + distinct, + } => { + let function_name = name.to_lowercase(); + // Validate known aggregate functions match function_name.as_str() { - // Aggregations "count" | "sum" | "avg" | "min" | "max" | "collect" => { + // DISTINCT is only supported for COUNT + // Other aggregates silently ignore it in execution, so reject early + if *distinct && function_name != "count" { + return Err(GraphError::UnsupportedFeature { + feature: format!( + "DISTINCT is only supported with COUNT, not {}", + function_name.to_uppercase() + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + + // COUNT(DISTINCT *) is semantically meaningless + // It would count distinct values of lit(1) which is always 1 + if *distinct && function_name == "count" { + if let Some(ValueExpression::Variable(v)) = args.first() { + if v == "*" { + return Err(GraphError::PlanError { + message: "COUNT(DISTINCT *) is not supported. \ + Use COUNT(*) to count all rows, or \ + COUNT(DISTINCT property) to count distinct values." + .to_string(), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } + } + // All aggregates require exactly 1 argument if args.len() != 1 { return Err(GraphError::PlanError { message: format!( @@ -389,25 +452,12 @@ impl SemanticAnalyzer { } } } - // Scalar string functions - "tolower" | "lower" | "toupper" | "upper" => { - if args.len() != 1 { - return Err(GraphError::PlanError { - message: format!( - "{} requires exactly 1 argument, got {}", - function_name.to_uppercase(), - args.len() - ), - location: snafu::Location::new(file!(), line!(), column!()), - }); - } - } - // Unknown/unimplemented scalar function _ => { + // Unknown aggregate function - reject early return Err(GraphError::UnsupportedFeature { feature: format!( - "Cypher function '{}' is not implemented", - function_name + "Cypher aggregate function '{}' is not implemented. Supported aggregate functions: COUNT, SUM, AVG, MIN, MAX, COLLECT.", + name ), location: snafu::Location::new(file!(), line!(), column!()), }); @@ -1087,7 +1137,7 @@ mod tests { #[test] fn test_function_argument_undefined_variable_in_return() { // RETURN toUpper(m.name) - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "toUpper".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("m", "name"))], }; @@ -1101,7 +1151,7 @@ mod tests { #[test] fn test_function_argument_valid_variable_ok() { // MATCH (n:Person) RETURN toUpper(n.name) - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "toUpper".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("n", "name"))], }; @@ -1142,12 +1192,13 @@ mod tests { #[test] fn test_count_with_multiple_args_fails_validation() { // COUNT(n.age, n.name) should fail semantic validation - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "count".to_string(), args: vec![ ValueExpression::Property(PropertyRef::new("n", "age")), ValueExpression::Property(PropertyRef::new("n", "name")), ], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1163,9 +1214,10 @@ mod tests { #[test] fn test_count_with_zero_args_fails_validation() { // COUNT() with no arguments should fail - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "count".to_string(), args: vec![], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1181,9 +1233,10 @@ mod tests { #[test] fn test_count_with_one_arg_passes_validation() { // COUNT(n.age) should pass validation - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "count".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1199,9 +1252,10 @@ mod tests { #[test] fn test_count_star_passes_validation() { // COUNT(*) should be allowed (special-cased in semantic analysis) - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "count".to_string(), args: vec![ValueExpression::Variable("*".to_string())], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1213,11 +1267,12 @@ mod tests { #[test] fn test_unimplemented_scalar_function_fails_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "replace".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("n", "name"))], }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); + // ScalarFunction with unknown name collects an error assert!( result .errors @@ -1230,9 +1285,10 @@ mod tests { #[test] fn test_sum_with_variable_fails_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "sum".to_string(), args: vec![ValueExpression::Variable("n".to_string())], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1252,9 +1308,10 @@ mod tests { #[test] fn test_avg_with_variable_fails_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "avg".to_string(), args: vec![ValueExpression::Variable("n".to_string())], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1274,9 +1331,10 @@ mod tests { #[test] fn test_sum_with_property_passes_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "sum".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1288,9 +1346,10 @@ mod tests { #[test] fn test_min_with_variable_fails_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "min".to_string(), args: vec![ValueExpression::Variable("n".to_string())], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1310,9 +1369,10 @@ mod tests { #[test] fn test_max_with_variable_fails_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "max".to_string(), args: vec![ValueExpression::Variable("n".to_string())], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1332,9 +1392,10 @@ mod tests { #[test] fn test_min_with_property_passes_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "min".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1346,9 +1407,10 @@ mod tests { #[test] fn test_max_with_property_passes_validation() { - let expr = ValueExpression::Function { + let expr = ValueExpression::AggregateFunction { name: "max".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + distinct: false, }; let result = analyze_return_with_match("n", "Person", expr).unwrap(); assert!( @@ -1358,6 +1420,60 @@ mod tests { ); } + #[test] + fn test_distinct_only_supported_on_count() { + // SUM(DISTINCT n.age) should fail - DISTINCT only supported for COUNT + let expr = ValueExpression::AggregateFunction { + name: "sum".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + distinct: true, + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result + .errors + .iter() + .any(|e| e.contains("DISTINCT is only supported with COUNT")), + "Expected error about DISTINCT only for COUNT, got: {:?}", + result.errors + ); + } + + #[test] + fn test_count_distinct_star_rejected() { + // COUNT(DISTINCT *) is semantically meaningless - should be rejected + let expr = ValueExpression::AggregateFunction { + name: "count".to_string(), + args: vec![ValueExpression::Variable("*".to_string())], + distinct: true, + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result + .errors + .iter() + .any(|e| e.contains("COUNT(DISTINCT *)")), + "Expected error about COUNT(DISTINCT *), got: {:?}", + result.errors + ); + } + + #[test] + fn test_count_distinct_passes_validation() { + // COUNT(DISTINCT n.age) should pass + let expr = ValueExpression::AggregateFunction { + name: "count".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "age"))], + distinct: true, + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result.errors.is_empty(), + "COUNT(DISTINCT) should pass validation, got errors: {:?}", + result.errors + ); + } + #[test] fn test_arithmetic_with_non_numeric_literal_error() { // RETURN "x" + 1 diff --git a/crates/lance-graph/src/simple_executor/expr.rs b/crates/lance-graph/src/simple_executor/expr.rs index bcdd2915..21772dbe 100644 --- a/crates/lance-graph/src/simple_executor/expr.rs +++ b/crates/lance-graph/src/simple_executor/expr.rs @@ -169,7 +169,7 @@ pub(crate) fn to_df_value_expr_simple( VE::Property(prop) => col(&prop.property), VE::Variable(v) => col(v), VE::Literal(v) => to_df_literal(v), - VE::Function { name, args } => match name.to_lowercase().as_str() { + VE::ScalarFunction { name, args } => match name.to_lowercase().as_str() { "tolower" | "lower" => { if args.len() == 1 { lower().call(vec![to_df_value_expr_simple(&args[0])]) @@ -186,6 +186,10 @@ pub(crate) fn to_df_value_expr_simple( } _ => Expr::Literal(datafusion::scalar::ScalarValue::Null, None), }, + VE::AggregateFunction { .. } => { + // Aggregates not supported in simple executor + Expr::Literal(datafusion::scalar::ScalarValue::Null, None) + } VE::Arithmetic { left, operator, @@ -223,7 +227,7 @@ mod tests { #[test] fn test_simple_expr_unknown_function_returns_null() { - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "replace".to_string(), args: vec![ValueExpression::Property(PropertyRef::new("p", "name"))], }; @@ -233,7 +237,7 @@ mod tests { #[test] fn test_simple_expr_lower_wrong_arity_returns_null() { - let expr = ValueExpression::Function { + let expr = ValueExpression::ScalarFunction { name: "lower".to_string(), args: vec![ ValueExpression::Property(PropertyRef::new("p", "name")), diff --git a/crates/lance-graph/tests/test_datafusion_pipeline.rs b/crates/lance-graph/tests/test_datafusion_pipeline.rs index ae439332..243db94b 100644 --- a/crates/lance-graph/tests/test_datafusion_pipeline.rs +++ b/crates/lance-graph/tests/test_datafusion_pipeline.rs @@ -2615,6 +2615,40 @@ async fn test_count_star_all_nodes() { assert_eq!(count_col.value(0), 5); } +/// COUNT(DISTINCT *) is rejected at semantic validation because it's semantically meaningless. +/// Without this check, it would return 1 (count of distinct lit(1) values) which is misleading. +#[tokio::test] +async fn test_count_distinct_star_rejected() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new("MATCH (a:Person) RETURN count(DISTINCT *) AS total") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await; + + // COUNT(DISTINCT *) should be rejected with a helpful error message + assert!( + result.is_err(), + "COUNT(DISTINCT *) should be rejected at semantic validation" + ); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("COUNT(DISTINCT *)") || err_msg.contains("not supported"), + "Error should mention COUNT(DISTINCT *), got: {}", + err_msg + ); +} + #[tokio::test] async fn test_count_variable() { let person_batch = create_person_dataset(); @@ -2826,6 +2860,169 @@ async fn test_count_property_without_alias_has_descriptive_name() { ); } +#[tokio::test] +async fn test_count_distinct_basic() { + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .with_relationship("KNOWS", "src_person_id", "dst_person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_dataset()); + datasets.insert("KNOWS".to_string(), create_knows_dataset()); + + // Test COUNT(DISTINCT source.id) - count unique people who know others + // KNOWS relationships: 1->2, 2->3, 3->4, 4->5, 1->3 + // Unique source persons: 1, 2, 3, 4 (4 distinct) + let query = CypherQuery::new( + "MATCH (source:Person)-[:KNOWS]->(target:Person) + RETURN COUNT(DISTINCT source.id) AS num_people", + ) + .unwrap() + .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Verify results + assert_eq!(result.num_rows(), 1); + + let num_people = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(num_people.value(0), 4); // 4 distinct source persons +} + +#[tokio::test] +async fn test_count_vs_count_distinct() { + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .with_relationship("KNOWS", "src_person_id", "dst_person_id") + .build() + .unwrap(); + + // Test COUNT (non-distinct) - count all KNOWS relationships from person 1 + // Person 1 has 2 KNOWS relationships: 1->2 and 1->3 + { + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_dataset()); + datasets.insert("KNOWS".to_string(), create_knows_dataset()); + + let query = CypherQuery::new( + "MATCH (source:Person)-[:KNOWS]->(target:Person) + WHERE source.id = 1 + RETURN COUNT(target.id) AS total_connections", + ) + .unwrap() + .with_config(config.clone()); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + let count = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count.value(0), 2); // 2 connections (person 1 knows 2 people) + } + + // Test COUNT(DISTINCT) - should count unique target persons (same as COUNT in this case) + { + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_dataset()); + datasets.insert("KNOWS".to_string(), create_knows_dataset()); + + let query = CypherQuery::new( + "MATCH (source:Person)-[:KNOWS]->(target:Person) + WHERE source.id = 1 + RETURN COUNT(DISTINCT target.id) AS unique_connections", + ) + .unwrap() + .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + let count = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count.value(0), 2); // 2 unique target persons (2 and 3) + } +} + +#[tokio::test] +async fn test_count_distinct_with_grouping() { + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .with_relationship("KNOWS", "src_person_id", "dst_person_id") + .build() + .unwrap(); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), create_person_dataset()); + datasets.insert("KNOWS".to_string(), create_knows_dataset()); + + // Group by target person and count distinct sources who know them + // KNOWS relationships: 1->2, 2->3, 3->4, 4->5, 1->3 + // Person 2 is known by: 1 (1 distinct) + // Person 3 is known by: 1, 2 (2 distinct) + // Person 4 is known by: 3 (1 distinct) + // Person 5 is known by: 4 (1 distinct) + let query = CypherQuery::new( + "MATCH (source:Person)-[:KNOWS]->(target:Person) + RETURN target.id AS target_id, COUNT(DISTINCT source.id) AS num_sources + ORDER BY target_id", + ) + .unwrap() + .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + assert_eq!(result.num_rows(), 4); // 4 different targets (2, 3, 4, 5) + + let target_ids = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let counts = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + // Person 2 is known by 1 person (person 1) + assert_eq!(target_ids.value(0), 2); + assert_eq!(counts.value(0), 1); + + // Person 3 is known by 2 people (persons 1 and 2) + assert_eq!(target_ids.value(1), 3); + assert_eq!(counts.value(1), 2); + + // Person 4 is known by 1 person (person 3) + assert_eq!(target_ids.value(2), 4); + assert_eq!(counts.value(2), 1); + + // Person 5 is known by 1 person (person 4) + assert_eq!(target_ids.value(3), 5); + assert_eq!(counts.value(3), 1); +} + #[tokio::test] async fn test_sum_property() { let person_batch = create_person_dataset(); @@ -4593,6 +4790,35 @@ async fn test_unwind_simple_list() { } } +/// Test that COUNT(x) where x comes from UNWIND fails gracefully +/// This documents a known limitation: COUNT(variable) assumes variable__id exists, +/// but UNWIND-created variables don't have __id columns. +#[tokio::test] +async fn test_count_unwind_variable_known_limitation() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + // COUNT(x) where x comes from UNWIND - x has no __id column + let query = CypherQuery::new("UNWIND [1, 2, 3] AS x RETURN COUNT(x)") + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await; + + // This currently fails because COUNT(x) looks for x__id which doesn't exist + // If this starts passing in the future, update the test to verify correct count (3) + assert!( + result.is_err(), + "COUNT(unwind_variable) should fail - no __id column. If this passes, \ + the limitation has been fixed and test should verify COUNT returns 3" + ); +} + #[tokio::test] async fn test_unwind_after_match() { let config = create_graph_config();