Skip to content
4 changes: 4 additions & 0 deletions crates/lance-graph/src/datafusion_planner/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
_ => {
// Unsupported function - return NULL which coerces to any type
// This prevents type coercion errors in both string and numeric contexts
//
// TODO(#107): Now that semantic analysis rejects unknown functions, consider
// upgrading this to a hard internal error (e.g. `unreachable!()` or returning
// a planner/execution error) to catch validator regressions early.
Expr::Literal(datafusion::scalar::ScalarValue::Null, None)
}
}
Expand Down
67 changes: 64 additions & 3 deletions crates/lance-graph/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,13 @@ impl CypherQuery {

// Phase 1: Semantic Analysis
let mut analyzer = SemanticAnalyzer::new(config.clone());
analyzer.analyze(&self.ast)?;
let semantic = analyzer.analyze(&self.ast)?;
if !semantic.errors.is_empty() {
return Err(GraphError::PlanError {
message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")),
location: snafu::Location::new(file!(), line!(), column!()),
});
}

// Phase 2: Graph Logical Plan
let mut logical_planner = LogicalPlanner::new();
Expand Down Expand Up @@ -878,13 +884,24 @@ impl CypherQuery {
&self,
datasets: HashMap<String, arrow::record_batch::RecordBatch>,
) -> Result<arrow::record_batch::RecordBatch> {
use crate::semantic::SemanticAnalyzer;
use arrow::compute::concat_batches;
use datafusion::datasource::MemTable;
use datafusion::prelude::*;
use std::sync::Arc;

// Require a config for now, even if we don't fully exploit it yet
let _config = self.require_config()?;
let config = self.require_config()?.clone();

// Ensure we don't silently ignore unsupported features (e.g. scalar functions).
let mut analyzer = SemanticAnalyzer::new(config);
let semantic = analyzer.analyze(&self.ast)?;
if !semantic.errors.is_empty() {
return Err(GraphError::PlanError {
message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")),
location: snafu::Location::new(file!(), line!(), column!()),
});
}

if datasets.is_empty() {
return Err(GraphError::PlanError {
Expand Down Expand Up @@ -958,7 +975,14 @@ impl CypherQuery {
.return_clause
.items
.iter()
.map(|item| to_df_value_expr_simple(&item.expression))
.map(|item| {
let expr = to_df_value_expr_simple(&item.expression);
if let Some(alias) = &item.alias {
expr.alias(alias)
} else {
expr
}
})
.collect();
if !proj_exprs.is_empty() {
df = df.select(proj_exprs).map_err(|e| GraphError::PlanError {
Expand Down Expand Up @@ -2249,4 +2273,41 @@ mod tests {
"expected unsupported feature error, got {err:?}"
);
}

#[tokio::test]
async fn test_execute_fails_on_semantic_error() {
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use std::collections::HashMap;
use std::sync::Arc;

let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
let batch = RecordBatch::new_empty(schema);
let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), batch);

let cfg = GraphConfig::builder()
.with_node_label("Person", "id")
.build()
.unwrap();

// Query referencing undefined variable 'x'
let query = CypherQuery::new("MATCH (n:Person) RETURN x.name")
.unwrap()
.with_config(cfg);

let result = query.execute_simple(datasets).await;

assert!(result.is_err());
match result {
Err(GraphError::PlanError { message, .. }) => {
assert!(message.contains("Semantic analysis failed"));
assert!(message.contains("Undefined variable: 'x'"));
}
_ => panic!(
"Expected PlanError with semantic failure message, got {:?}",
result
),
}
}
}
113 changes: 105 additions & 8 deletions crates/lance-graph/src/semantic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ pub enum ScopeType {
/// Semantic analysis result with validated and enriched AST
#[derive(Debug, Clone)]
pub struct SemanticResult {
pub query: CypherQuery,
pub variables: HashMap<String, VariableInfo>,
pub errors: Vec<String>,
pub warnings: Vec<String>,
Expand Down Expand Up @@ -106,6 +105,23 @@ impl SemanticAnalyzer {
}
}

// Phase 4: Variable discovery in post-WITH READING clauses (query chaining)
self.current_scope = ScopeType::Match;
for clause in &query.post_with_reading_clauses {
match clause {
ReadingClause::Match(match_clause) => {
if let Err(e) = self.analyze_match_clause(match_clause) {
errors.push(format!("Post-WITH MATCH clause error: {}", e));
}
}
ReadingClause::Unwind(unwind_clause) => {
if let Err(e) = self.analyze_unwind_clause(unwind_clause) {
errors.push(format!("Post-WITH UNWIND clause error: {}", e));
}
}
}
}

// Phase 4: Validate post-WITH WHERE clause if present
if let Some(post_where) = &query.post_with_where_clause {
self.current_scope = ScopeType::PostWithWhere;
Expand Down Expand Up @@ -135,7 +151,6 @@ impl SemanticAnalyzer {
self.validate_types(&mut errors);

Ok(SemanticResult {
query: query.clone(),
variables: self.variables.clone(),
errors,
warnings,
Expand Down Expand Up @@ -343,14 +358,17 @@ impl SemanticAnalyzer {
}
}
ValueExpression::Function { name, args } => {
let function_name = name.to_lowercase();

// Validate function-specific arity and signature rules
match name.to_lowercase().as_str() {
match function_name.as_str() {
// Aggregations
"count" | "sum" | "avg" | "min" | "max" | "collect" => {
if args.len() != 1 {
return Err(GraphError::PlanError {
message: format!(
"{} requires exactly 1 argument, got {}",
name.to_uppercase(),
function_name.to_uppercase(),
args.len()
),
location: snafu::Location::new(file!(), line!(), column!()),
Expand All @@ -359,25 +377,51 @@ impl SemanticAnalyzer {

// Additional validation for SUM, AVG, MIN, MAX: they require properties, not bare variables
// Only COUNT and COLLECT allow bare variables (COUNT(*), COUNT(p), COLLECT(p))
if matches!(name.to_lowercase().as_str(), "sum" | "avg" | "min" | "max") {
if matches!(function_name.as_str(), "sum" | "avg" | "min" | "max") {
if let Some(ValueExpression::Variable(v)) = args.first() {
return Err(GraphError::PlanError {
message: format!(
"{}({}) is invalid - {} requires a property like {}({}.property). You cannot {} a node/entity.",
name.to_uppercase(), v, name.to_uppercase(), name.to_uppercase(), v, name.to_lowercase()
function_name.to_uppercase(), v, function_name.to_uppercase(), function_name.to_uppercase(), v, function_name
),
location: snafu::Location::new(file!(), line!(), column!()),
});
}
}
}
// Scalar string functions
"tolower" | "lower" | "toupper" | "upper" => {
if args.len() != 1 {
return Err(GraphError::PlanError {
message: format!(
"{} requires exactly 1 argument, got {}",
function_name.to_uppercase(),
args.len()
),
location: snafu::Location::new(file!(), line!(), column!()),
});
}
}
// Unknown/unimplemented scalar function
_ => {
// Other functions - no validation yet
return Err(GraphError::UnsupportedFeature {
feature: format!(
"Cypher function '{}' is not implemented",
function_name
),
location: snafu::Location::new(file!(), line!(), column!()),
});
}
}

// Validate arguments recursively
// Validate arguments recursively.
// Special-case COUNT(*) where '*' isn't a real variable.
for arg in args {
if function_name == "count"
&& matches!(arg, ValueExpression::Variable(v) if v == "*")
{
continue;
}
self.analyze_value_expression(arg)?;
}
}
Expand Down Expand Up @@ -453,6 +497,21 @@ impl SemanticAnalyzer {
Ok(())
}

fn register_projection_alias(&mut self, alias: &str) {
if self.variables.contains_key(alias) {
return;
}

let var_info = VariableInfo {
name: alias.to_string(),
variable_type: VariableType::Property,
labels: vec![],
properties: HashSet::new(),
defined_in: self.current_scope.clone(),
};
self.variables.insert(alias.to_string(), var_info);
}

/// Validate property reference
fn validate_property_reference(&self, prop_ref: &PropertyRef) -> Result<()> {
if !self.variables.contains_key(&prop_ref.variable) {
Expand All @@ -468,6 +527,9 @@ impl SemanticAnalyzer {
fn analyze_return_clause(&mut self, return_clause: &ReturnClause) -> Result<()> {
for item in &return_clause.items {
self.analyze_value_expression(&item.expression)?;
if let Some(alias) = &item.alias {
self.register_projection_alias(alias);
}
}
Ok(())
}
Expand All @@ -477,6 +539,9 @@ impl SemanticAnalyzer {
// Validate WITH item expressions (similar to RETURN)
for item in &with_clause.items {
self.analyze_value_expression(&item.expression)?;
if let Some(alias) = &item.alias {
self.register_projection_alias(alias);
}
}
// Validate ORDER BY within WITH if present
if let Some(order_by) = &with_clause.order_by {
Expand Down Expand Up @@ -1131,6 +1196,38 @@ mod tests {
);
}

#[test]
fn test_count_star_passes_validation() {
// COUNT(*) should be allowed (special-cased in semantic analysis)
let expr = ValueExpression::Function {
name: "count".to_string(),
args: vec![ValueExpression::Variable("*".to_string())],
};
let result = analyze_return_with_match("n", "Person", expr).unwrap();
assert!(
result.errors.is_empty(),
"Expected COUNT(*) to pass semantic validation, got: {:?}",
result.errors
);
}

#[test]
fn test_unimplemented_scalar_function_fails_validation() {
let expr = ValueExpression::Function {
name: "replace".to_string(),
args: vec![ValueExpression::Property(PropertyRef::new("n", "name"))],
};
let result = analyze_return_with_match("n", "Person", expr).unwrap();
assert!(
result
.errors
.iter()
.any(|e| e.to_lowercase().contains("not implemented")),
"Expected semantic validation to reject unimplemented function, got: {:?}",
result.errors
);
}

#[test]
fn test_sum_with_variable_fails_validation() {
let expr = ValueExpression::Function {
Expand Down
Loading
Loading