diff --git a/crates/lance-graph/src/datafusion_planner/expression.rs b/crates/lance-graph/src/datafusion_planner/expression.rs index 954b5465..9372afed 100644 --- a/crates/lance-graph/src/datafusion_planner/expression.rs +++ b/crates/lance-graph/src/datafusion_planner/expression.rs @@ -220,6 +220,10 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { _ => { // Unsupported function - return NULL which coerces to any type // This prevents type coercion errors in both string and numeric contexts + // + // TODO(#107): Now that semantic analysis rejects unknown functions, consider + // upgrading this to a hard internal error (e.g. `unreachable!()` or returning + // a planner/execution error) to catch validator regressions early. Expr::Literal(datafusion::scalar::ScalarValue::Null, None) } } diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index 281a9fa0..a7e15ae1 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -721,7 +721,13 @@ impl CypherQuery { // Phase 1: Semantic Analysis let mut analyzer = SemanticAnalyzer::new(config.clone()); - analyzer.analyze(&self.ast)?; + let semantic = analyzer.analyze(&self.ast)?; + if !semantic.errors.is_empty() { + return Err(GraphError::PlanError { + message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } // Phase 2: Graph Logical Plan let mut logical_planner = LogicalPlanner::new(); @@ -878,13 +884,24 @@ impl CypherQuery { &self, datasets: HashMap<String, RecordBatch>, ) -> Result<RecordBatch> { + use crate::semantic::SemanticAnalyzer; use arrow::compute::concat_batches; use datafusion::datasource::MemTable; use datafusion::prelude::*; use std::sync::Arc; // Require a config for now, even if we don't fully exploit it yet - let _config = self.require_config()?; + let config = self.require_config()?.clone(); + + // Ensure we don't silently ignore unsupported features (e.g. scalar functions).
+ let mut analyzer = SemanticAnalyzer::new(config); + let semantic = analyzer.analyze(&self.ast)?; + if !semantic.errors.is_empty() { + return Err(GraphError::PlanError { + message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } if datasets.is_empty() { return Err(GraphError::PlanError { @@ -958,7 +975,14 @@ impl CypherQuery { .return_clause .items .iter() - .map(|item| to_df_value_expr_simple(&item.expression)) + .map(|item| { + let expr = to_df_value_expr_simple(&item.expression); + if let Some(alias) = &item.alias { + expr.alias(alias) + } else { + expr + } + }) .collect(); if !proj_exprs.is_empty() { df = df.select(proj_exprs).map_err(|e| GraphError::PlanError { @@ -2249,4 +2273,41 @@ mod tests { "expected unsupported feature error, got {err:?}" ); } + + #[tokio::test] + async fn test_execute_fails_on_semantic_error() { + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use std::collections::HashMap; + use std::sync::Arc; + + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + let batch = RecordBatch::new_empty(schema); + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), batch); + + let cfg = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + // Query referencing undefined variable 'x' + let query = CypherQuery::new("MATCH (n:Person) RETURN x.name") + .unwrap() + .with_config(cfg); + + let result = query.execute_simple(datasets).await; + + assert!(result.is_err()); + match result { + Err(GraphError::PlanError { message, .. 
}) => { + assert!(message.contains("Semantic analysis failed")); + assert!(message.contains("Undefined variable: 'x'")); + } + _ => panic!( + "Expected PlanError with semantic failure message, got {:?}", + result + ), + } + } } diff --git a/crates/lance-graph/src/semantic.rs b/crates/lance-graph/src/semantic.rs index 926fc0c9..edea213f 100644 --- a/crates/lance-graph/src/semantic.rs +++ b/crates/lance-graph/src/semantic.rs @@ -53,7 +53,6 @@ pub enum ScopeType { /// Semantic analysis result with validated and enriched AST #[derive(Debug, Clone)] pub struct SemanticResult { - pub query: CypherQuery, pub variables: HashMap<String, VariableInfo>, pub errors: Vec<String>, pub warnings: Vec<String>, @@ -106,6 +105,23 @@ impl SemanticAnalyzer { } } + // Phase 4: Variable discovery in post-WITH READING clauses (query chaining) + self.current_scope = ScopeType::Match; + for clause in &query.post_with_reading_clauses { + match clause { + ReadingClause::Match(match_clause) => { + if let Err(e) = self.analyze_match_clause(match_clause) { + errors.push(format!("Post-WITH MATCH clause error: {}", e)); + } + } + ReadingClause::Unwind(unwind_clause) => { + if let Err(e) = self.analyze_unwind_clause(unwind_clause) { + errors.push(format!("Post-WITH UNWIND clause error: {}", e)); + } + } + } + } + // Phase 4: Validate post-WITH WHERE clause if present if let Some(post_where) = &query.post_with_where_clause { self.current_scope = ScopeType::PostWithWhere; @@ -135,7 +151,6 @@ impl SemanticAnalyzer { self.validate_types(&mut errors); Ok(SemanticResult { - query: query.clone(), variables: self.variables.clone(), errors, warnings, @@ -343,14 +358,17 @@ impl SemanticAnalyzer { } } ValueExpression::Function { name, args } => { + let function_name = name.to_lowercase(); + // Validate function-specific arity and signature rules - match name.to_lowercase().as_str() { + match function_name.as_str() { + // Aggregations "count" | "sum" | "avg" | "min" | "max" | "collect" => { if args.len() != 1 { return Err(GraphError::PlanError {
message: format!( "{} requires exactly 1 argument, got {}", - name.to_uppercase(), + function_name.to_uppercase(), args.len() ), location: snafu::Location::new(file!(), line!(), column!()), @@ -359,25 +377,51 @@ impl SemanticAnalyzer { // Additional validation for SUM, AVG, MIN, MAX: they require properties, not bare variables // Only COUNT and COLLECT allow bare variables (COUNT(*), COUNT(p), COLLECT(p)) - if matches!(name.to_lowercase().as_str(), "sum" | "avg" | "min" | "max") { + if matches!(function_name.as_str(), "sum" | "avg" | "min" | "max") { if let Some(ValueExpression::Variable(v)) = args.first() { return Err(GraphError::PlanError { message: format!( "{}({}) is invalid - {} requires a property like {}({}.property). You cannot {} a node/entity.", - name.to_uppercase(), v, name.to_uppercase(), name.to_uppercase(), v, name.to_lowercase() + function_name.to_uppercase(), v, function_name.to_uppercase(), function_name.to_uppercase(), v, function_name ), location: snafu::Location::new(file!(), line!(), column!()), }); } } } + // Scalar string functions + "tolower" | "lower" | "toupper" | "upper" => { + if args.len() != 1 { + return Err(GraphError::PlanError { + message: format!( + "{} requires exactly 1 argument, got {}", + function_name.to_uppercase(), + args.len() + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } + // Unknown/unimplemented scalar function _ => { - // Other functions - no validation yet + return Err(GraphError::UnsupportedFeature { + feature: format!( + "Cypher function '{}' is not implemented", + function_name + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); } } - // Validate arguments recursively + // Validate arguments recursively. + // Special-case COUNT(*) where '*' isn't a real variable. 
for arg in args { + if function_name == "count" + && matches!(arg, ValueExpression::Variable(v) if v == "*") + { + continue; + } self.analyze_value_expression(arg)?; } } @@ -453,6 +497,21 @@ impl SemanticAnalyzer { Ok(()) } + fn register_projection_alias(&mut self, alias: &str) { + if self.variables.contains_key(alias) { + return; + } + + let var_info = VariableInfo { + name: alias.to_string(), + variable_type: VariableType::Property, + labels: vec![], + properties: HashSet::new(), + defined_in: self.current_scope.clone(), + }; + self.variables.insert(alias.to_string(), var_info); + } + /// Validate property reference fn validate_property_reference(&self, prop_ref: &PropertyRef) -> Result<()> { if !self.variables.contains_key(&prop_ref.variable) { @@ -468,6 +527,9 @@ impl SemanticAnalyzer { fn analyze_return_clause(&mut self, return_clause: &ReturnClause) -> Result<()> { for item in &return_clause.items { self.analyze_value_expression(&item.expression)?; + if let Some(alias) = &item.alias { + self.register_projection_alias(alias); + } } Ok(()) } @@ -477,6 +539,9 @@ impl SemanticAnalyzer { // Validate WITH item expressions (similar to RETURN) for item in &with_clause.items { self.analyze_value_expression(&item.expression)?; + if let Some(alias) = &item.alias { + self.register_projection_alias(alias); + } } // Validate ORDER BY within WITH if present if let Some(order_by) = &with_clause.order_by { @@ -1131,6 +1196,38 @@ mod tests { ); } + #[test] + fn test_count_star_passes_validation() { + // COUNT(*) should be allowed (special-cased in semantic analysis) + let expr = ValueExpression::Function { + name: "count".to_string(), + args: vec![ValueExpression::Variable("*".to_string())], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result.errors.is_empty(), + "Expected COUNT(*) to pass semantic validation, got: {:?}", + result.errors + ); + } + + #[test] + fn test_unimplemented_scalar_function_fails_validation() { + let expr = 
ValueExpression::Function { + name: "replace".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("n", "name"))], + }; + let result = analyze_return_with_match("n", "Person", expr).unwrap(); + assert!( + result + .errors + .iter() + .any(|e| e.to_lowercase().contains("not implemented")), + "Expected semantic validation to reject unimplemented function, got: {:?}", + result.errors + ); + } + #[test] fn test_sum_with_variable_fails_validation() { let expr = ValueExpression::Function { diff --git a/crates/lance-graph/src/simple_executor/expr.rs b/crates/lance-graph/src/simple_executor/expr.rs index ec1fb694..bcdd2915 100644 --- a/crates/lance-graph/src/simple_executor/expr.rs +++ b/crates/lance-graph/src/simple_executor/expr.rs @@ -163,15 +163,97 @@ pub(crate) fn to_df_value_expr_simple( expr: &crate::ast::ValueExpression, ) -> datafusion::logical_expr::Expr { use crate::ast::ValueExpression as VE; - use datafusion::logical_expr::{col, lit}; + use datafusion::functions::string::{lower, upper}; + use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator}; match expr { VE::Property(prop) => col(&prop.property), VE::Variable(v) => col(v), VE::Literal(v) => to_df_literal(v), - VE::Function { .. } | VE::Arithmetic { .. 
} => lit(0), + VE::Function { name, args } => match name.to_lowercase().as_str() { + "tolower" | "lower" => { + if args.len() == 1 { + lower().call(vec![to_df_value_expr_simple(&args[0])]) + } else { + Expr::Literal(datafusion::scalar::ScalarValue::Null, None) + } + } + "toupper" | "upper" => { + if args.len() == 1 { + upper().call(vec![to_df_value_expr_simple(&args[0])]) + } else { + Expr::Literal(datafusion::scalar::ScalarValue::Null, None) + } + } + _ => Expr::Literal(datafusion::scalar::ScalarValue::Null, None), + }, + VE::Arithmetic { + left, + operator, + right, + } => { + use crate::ast::ArithmeticOperator as AO; + let l = to_df_value_expr_simple(left); + let r = to_df_value_expr_simple(right); + let op = match operator { + AO::Add => Operator::Plus, + AO::Subtract => Operator::Minus, + AO::Multiply => Operator::Multiply, + AO::Divide => Operator::Divide, + AO::Modulo => Operator::Modulo, + }; + Expr::BinaryExpr(BinaryExpr { + left: Box::new(l), + op, + right: Box::new(r), + }) + } VE::VectorDistance { .. } => lit(0.0f32), VE::VectorSimilarity { .. 
} => lit(1.0f32), VE::Parameter(_) => lit(0), VE::VectorLiteral(_) => lit(0.0f32), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::{ArithmeticOperator, PropertyRef, ValueExpression}; + use datafusion::logical_expr::Expr; + use datafusion::scalar::ScalarValue; + + #[test] + fn test_simple_expr_unknown_function_returns_null() { + let expr = ValueExpression::Function { + name: "replace".to_string(), + args: vec![ValueExpression::Property(PropertyRef::new("p", "name"))], + }; + let df_expr = to_df_value_expr_simple(&expr); + assert!(matches!(df_expr, Expr::Literal(ScalarValue::Null, _))); + } + + #[test] + fn test_simple_expr_lower_wrong_arity_returns_null() { + let expr = ValueExpression::Function { + name: "lower".to_string(), + args: vec![ + ValueExpression::Property(PropertyRef::new("p", "name")), + ValueExpression::Property(PropertyRef::new("p", "name")), + ], + }; + let df_expr = to_df_value_expr_simple(&expr); + assert!(matches!(df_expr, Expr::Literal(ScalarValue::Null, _))); + } + + #[test] + fn test_simple_expr_arithmetic_builds_binary_expr() { + let expr = ValueExpression::Arithmetic { + left: Box::new(ValueExpression::Variable("x".to_string())), + operator: ArithmeticOperator::Add, + right: Box::new(ValueExpression::Literal( + crate::ast::PropertyValue::Integer(1), + )), + }; + let df_expr = to_df_value_expr_simple(&expr); + assert!(matches!(df_expr, Expr::BinaryExpr(_))); + } +} diff --git a/crates/lance-graph/tests/test_datafusion_pipeline.rs b/crates/lance-graph/tests/test_datafusion_pipeline.rs index b9725752..ae439332 100644 --- a/crates/lance-graph/tests/test_datafusion_pipeline.rs +++ b/crates/lance-graph/tests/test_datafusion_pipeline.rs @@ -4507,6 +4507,42 @@ async fn test_with_post_match_chaining() { assert!(result.column_by_name("cnt").is_some()); } +// ============================================================================ +// Scalar Function / Semantic Validation Regression Tests +// 
============================================================================ + +#[tokio::test] +async fn test_unimplemented_scalar_function_errors() { + let person_batch = create_person_dataset(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new( + "MATCH (p:Person) RETURN p.name AS name, replace(p.name, 'A', 'a') AS replaced", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let err = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .expect_err("replace() should error until implemented"); + + let message = err.to_string().to_lowercase(); + assert!(message.contains("replace"), "unexpected error: {err}"); + assert!( + message.contains("not implemented") || message.contains("unsupported"), + "unexpected error: {err}" + ); +} + +// NOTE: Simple executor tests live in `tests/test_simple_executor_pipeline.rs`. 
+ // ============================================================================ // UNWIND Tests // ============================================================================ diff --git a/crates/lance-graph/tests/test_simple_executor_pipeline.rs b/crates/lance-graph/tests/test_simple_executor_pipeline.rs new file mode 100644 index 00000000..097b1166 --- /dev/null +++ b/crates/lance-graph/tests/test_simple_executor_pipeline.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{collections::HashMap, sync::Arc}; + +use arrow_array::{Int64Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use lance_graph::{CypherQuery, ExecutionStrategy, GraphConfig}; + +fn create_person_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let ids = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); + let names = Arc::new(StringArray::from(vec![ + "Alice", "Bob", "Charlie", "David", "Eve", + ])); + + RecordBatch::try_new(schema, vec![ids, names]).unwrap() +} + +#[tokio::test] +async fn test_tolower_works_in_simple_executor() { + let person_batch = create_person_batch(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + + let query = CypherQuery::new( + "MATCH (p:Person) RETURN p.name AS name, tolower(p.name) AS lowered ORDER BY name", + ) + .unwrap() + .with_config(config); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::Simple)) + .await + .unwrap(); + + let names = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + let lowered = result + .column_by_name("lowered") + .unwrap() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + + let got: Vec<(String, String)> =
(0..result.num_rows()) + .map(|i| (names.value(i).to_string(), lowered.value(i).to_string())) + .collect(); + + assert_eq!( + got, + vec![ + ("Alice".to_string(), "alice".to_string()), + ("Bob".to_string(), "bob".to_string()), + ("Charlie".to_string(), "charlie".to_string()), + ("David".to_string(), "david".to_string()), + ("Eve".to_string(), "eve".to_string()), + ] + ); +} diff --git a/python/python/tests/test_scalar_functions.py b/python/python/tests/test_scalar_functions.py new file mode 100644 index 00000000..8d3a0899 --- /dev/null +++ b/python/python/tests/test_scalar_functions.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import pyarrow as pa +import pytest +from lance_graph import CypherQuery, GraphConfig + + +def test_unimplemented_scalar_function_errors() -> None: + cfg = GraphConfig.builder().with_node_label("Person", "name").build() + datasets = {"Person": pa.table({"name": ["Alice", "BOB", "CaSeY"]})} + + query = CypherQuery( + "MATCH (p:Person) " + "RETURN p.name AS name, replace(p.name, 'A', 'a') AS replaced " + "ORDER BY name", + ).with_config(cfg) + + with pytest.raises(Exception) as excinfo: + query.execute(datasets) + + assert "replace" in str(excinfo.value).lower()