Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 153 additions & 22 deletions rust/lance-graph/src/datafusion_planner/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,16 +686,27 @@ impl DataFusionPlanner {
/// generates join keys based on the id fields of those shared variables.
///
/// Supports both node variables and relationship variables:
/// - Node variables: Join on node ID field (e.g., `b__id`)
/// - Relationship variables: Currently unsupported - returns empty keys
/// - **Node variables**: Join on node ID field (e.g., `b__id`)
/// - **Relationship variables**: Join on composite keys (src_id AND dst_id)
///
/// # Example
/// # Examples
///
/// **Node variable join:**
/// ```text
/// Left pattern: (a:Person)-[:KNOWS]->(b:Person) -> variables: [a, b]
/// Right pattern: (b:Person)-[:WORKS_AT]->(c:Company) -> variables: [b, c]
/// Left: (a:Person)-[:KNOWS]->(b:Person) -> variables: [a, b]
/// Right: (b:Person)-[:WORKS_AT]->(c:Company) -> variables: [b, c]
/// Shared: [b]
/// Result: (left_keys=["b__id"], right_keys=["b__id"])
/// ```
///
/// **Relationship variable join:**
/// ```text
/// Left: (a:Person)-[r:KNOWS]->(b:Person) -> variables: [a, b, r]
/// Right: (c:Person)-[r:KNOWS]->(d:Person) -> variables: [c, d, r]
/// Shared: [r]
/// Result: (left_keys=["r__src_id", "r__dst_id"],
/// right_keys=["r__src_id", "r__dst_id"])
/// ```
fn infer_join_keys(
&self,
ctx: &PlanningContext,
Expand Down Expand Up @@ -735,24 +746,39 @@ impl DataFusionPlanner {
left_keys.push(left_key);
right_keys.push(right_key);
}
} else {
// Not a node variable - check if it's a relationship variable
// Look up the relationship instance by its alias (the variable name)
if let Some(rel_instance) = ctx
.analysis
.relationship_instances
.iter()
.find(|r| r.alias == *var)
{
// Get the relationship mapping to find src/dst field names
if let Some(rel_map) = self
.config
.relationship_mappings
.get(&rel_instance.rel_type)
{
// Generate composite join keys for both src_id and dst_id
// This ensures we're matching the exact same relationship instance
// The columns are qualified as: {alias}__{original_field_name}
// Example: var="r", source_id_field="src_person_id"
// -> "r__src_person_id"
let left_src = format!("{}__{}", var, &rel_map.source_id_field);
let right_src = format!("{}__{}", var, &rel_map.source_id_field);
let left_dst = format!("{}__{}", var, &rel_map.target_id_field);
let right_dst = format!("{}__{}", var, &rel_map.target_id_field);

left_keys.push(left_src);
right_keys.push(right_src);
left_keys.push(left_dst);
right_keys.push(right_dst);
}
}
// If not found in either node or relationship variables, skip it
}
// If not a node variable, it might be a relationship variable
// TODO: Implement relationship variable join key generation
//
// For now, we skip relationship variables (they won't generate keys).
// This means patterns with only shared relationship variables will fall back
// to cross join (or error for outer joins).
//
// To implement this:
// 1. Look up the relationship instance in ctx.analysis.relationship_instances
// using the variable name as the key
// 2. Get the relationship mapping from self.config.relationship_mappings
// using the relationship type
// 3. Generate join keys based on a unique relationship ID column
// (may need to add an ID field to RelationshipMapping if not present)
// 4. Consider how to handle the fact that relationships are represented as
// joins in the physical plan - you may need to join on both src_id and dst_id
// to ensure the same relationship instance is matched
}

(left_keys, right_keys)
Expand Down Expand Up @@ -2246,4 +2272,109 @@ mod tests {
"Shared variables should include 'r'"
);
}

#[test]
fn test_relationship_variable_join_key_inference() {
// Test that the join key inference logic correctly handles relationship variables
//
// Note: This tests the key generation logic, not the full plan execution.
// In practice, joining on shared relationship variables across disconnected patterns
// doesn't make semantic sense in Cypher (a relationship can't have two sources).
//
// The implementation correctly:
// 1. Detects relationship variables in both patterns
// 2. Generates composite keys (src_id + dst_id) for relationship variables
// 3. Generates single keys for node variables
use crate::datafusion_planner::analysis;
use crate::logical_plan::LogicalOperator;

let cfg = crate::config::GraphConfig::builder()
.with_node_label("Person", "id")
.with_relationship("KNOWS", "src_person_id", "dst_person_id")
.build()
.unwrap();
let planner = DataFusionPlanner::with_catalog(cfg, make_catalog());

// Left: (a:Person)-[r1:KNOWS]->(b:Person)
let scan_a = LogicalOperator::ScanByLabel {
variable: "a".to_string(),
label: "Person".to_string(),
properties: Default::default(),
};
let expand_left = LogicalOperator::Expand {
input: Box::new(scan_a),
source_variable: "a".to_string(),
target_variable: "b".to_string(),
target_label: "Person".to_string(),
relationship_types: vec!["KNOWS".to_string()],
direction: crate::ast::RelationshipDirection::Outgoing,
relationship_variable: Some("r1".to_string()),
properties: Default::default(),
target_properties: Default::default(),
};

// Right: (b:Person)-[r2:KNOWS]->(c:Person) - shares node 'b'
let scan_b = LogicalOperator::ScanByLabel {
variable: "b".to_string(),
label: "Person".to_string(),
properties: Default::default(),
};
let expand_right = LogicalOperator::Expand {
input: Box::new(scan_b),
source_variable: "b".to_string(),
target_variable: "c".to_string(),
target_label: "Person".to_string(),
relationship_types: vec!["KNOWS".to_string()],
direction: crate::ast::RelationshipDirection::Outgoing,
relationship_variable: Some("r2".to_string()),
properties: Default::default(),
target_properties: Default::default(),
};

// Analyze both patterns to build the context
let left_analysis = analysis::analyze(&expand_left).unwrap();
let left_ctx = analysis::PlanningContext::new(&left_analysis);

// Test the key inference logic directly
let (left_keys, right_keys) =
planner.infer_join_keys(&left_ctx, &expand_left, &expand_right);

// Should generate join keys for shared node variable 'b'
assert!(
!left_keys.is_empty(),
"Should generate join keys for shared node 'b'"
);
assert_eq!(
left_keys.len(),
right_keys.len(),
"Left and right keys should match"
);

// Should contain b__id (the shared node)
assert!(
left_keys.iter().any(|k| k.contains("b__id")),
"Should have join key for shared node 'b': {:?}",
left_keys
);

// Verify that relationship variables r1 and r2 are collected
let left_vars = planner.extract_variables(&expand_left);
let right_vars = planner.extract_variables(&expand_right);

assert!(left_vars.contains(&"r1".to_string()), "Left should have r1");
assert!(
right_vars.contains(&"r2".to_string()),
"Right should have r2"
);

// r1 and r2 are different, so they shouldn't be in shared variables
let shared: Vec<String> = left_vars
.iter()
.filter(|v| right_vars.contains(v))
.cloned()
.collect();
assert!(!shared.contains(&"r1".to_string()), "r1 is not shared");
assert!(!shared.contains(&"r2".to_string()), "r2 is not shared");
assert!(shared.contains(&"b".to_string()), "b is shared");
}
}
157 changes: 157 additions & 0 deletions rust/lance-graph/tests/test_datafusion_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2912,3 +2912,160 @@ async fn test_datafusion_disconnected_with_distinct() {
.collect();
assert_eq!(name_set, expected);
}

#[tokio::test]
async fn test_datafusion_shared_node_variable_join() {
// This should join on shared variable 'b' using b.id
let config = create_graph_config();
let person_batch = create_person_dataset();
let knows_batch = create_knows_dataset();

let query = CypherQuery::new(
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person) \
RETURN a.name, b.name, c.name ORDER BY a.name, c.name",
)
.unwrap()
.with_config(config);

let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), person_batch);
datasets.insert("KNOWS".to_string(), knows_batch);

let result = query.execute_datafusion(datasets).await.unwrap();

// This is a two-hop path query that should use join key inference on 'b'
// Alice(1) -> Bob(2) -> Charlie(3)
// Alice(1) -> Bob(2) -> David(4)
assert!(
result.num_rows() >= 2,
"Should have at least 2 two-hop paths"
);

let a_names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let b_names = result
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let c_names = result
.column(2)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();

// Verify at least one path: Alice -> Bob -> Charlie
let mut found_path = false;
for i in 0..result.num_rows() {
if a_names.value(i) == "Alice" && b_names.value(i) == "Bob" && c_names.value(i) == "Charlie"
{
found_path = true;
break;
}
}
assert!(found_path, "Should find path: Alice -> Bob -> Charlie");
}

#[tokio::test]
async fn test_datafusion_shared_variable_with_filter() {
let config = create_graph_config();
let person_batch = create_person_dataset();
let knows_batch = create_knows_dataset();

let query = CypherQuery::new(
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person) \
WHERE a.age > 20 AND c.age < 40 \
RETURN a.name, b.name, c.name",
)
.unwrap()
.with_config(config);

let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), person_batch);
datasets.insert("KNOWS".to_string(), knows_batch);

let result = query.execute_datafusion(datasets).await.unwrap();

// Should successfully execute with join key inference + filters
assert!(result.num_rows() > 0, "Should have results with filters");

// Verify all results satisfy the age constraints
// All 'a' nodes should have age > 20 (excludes no one in our dataset)
// All 'c' nodes should have age < 40 (excludes David who is 40)
for i in 0..result.num_rows() {
let c_name = result
.column(2)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.value(i);
assert_ne!(c_name, "David", "David (age 40) should be filtered out");
}
}

#[tokio::test]
async fn test_datafusion_multiple_shared_variables() {
let config = create_graph_config();
let person_batch = create_person_dataset();
let knows_batch = create_knows_dataset();

let query = CypherQuery::new(
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person), (c)-[:KNOWS]->(d:Person) \
RETURN a.name, b.name, c.name, d.name",
)
.unwrap()
.with_config(config);

let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), person_batch);
datasets.insert("KNOWS".to_string(), knows_batch);

let result = query.execute_datafusion(datasets).await.unwrap();

// This is a three-hop path query using join key inference on 'b' and 'c'
// Should successfully execute (may have 0 or more results depending on data)
assert_eq!(result.num_columns(), 4);
}

#[tokio::test]
async fn test_datafusion_shared_variable_distinct() {
let config = create_graph_config();
let person_batch = create_person_dataset();
let knows_batch = create_knows_dataset();

let query = CypherQuery::new(
"MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:KNOWS]->(c:Person) \
RETURN DISTINCT b.name ORDER BY b.name",
)
.unwrap()
.with_config(config);

let mut datasets = HashMap::new();
datasets.insert("Person".to_string(), person_batch);
datasets.insert("KNOWS".to_string(), knows_batch);

let result = query.execute_datafusion(datasets).await.unwrap();

// Should return distinct intermediate nodes that have both incoming and outgoing KNOWS edges
assert!(result.num_rows() > 0, "Should have intermediate nodes");
assert_eq!(result.num_columns(), 1);

let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();

// Verify no duplicates
let name_set: std::collections::HashSet<String> = (0..result.num_rows())
.map(|i| names.value(i).to_string())
.collect();
assert_eq!(
name_set.len(),
result.num_rows(),
"DISTINCT should eliminate duplicates"
);
}
Loading
Loading