Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions examples/ladybug_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import annotations

import ast
from typing import List

import ladybug as lb
import pyarrow as pa

from fquery.arrow import arrow
from fquery.ladybug import graph, graph_edge, ladybug
from fquery.view_model import edge, node


@graph
class ReviewGraph:
@ladybug
@arrow
@node
class User:
name: str
age: int

@edge
async def follows(self) -> List["User"]:
yield self.follows

@edge
async def reviews(self) -> List["Review"]:
yield self.reviews

@ladybug
@arrow
@node
class Review:
business: str
rating: int


User = ReviewGraph.User
Review = ReviewGraph.Review
UserQuery = User.query()
ReviewQuery = Review.query()


def save_demo() -> None:
db = lb.Database(":memory:")
conn = lb.Connection(db)

ReviewGraph.create_schema(conn)

u1 = User(id=1, name="Ada", age=37)
u2 = User(id=2, name="Grace", age=42)
u3 = User(id=3, name="Linus", age=55)
r1 = Review(id=10, business="Cafe", rating=5)
graph = ReviewGraph(
nodes=[u1, u2, u3, r1],
edges=[
graph_edge("reviews", u1, r1),
graph_edge("follows", u1, u2),
graph_edge("follows", u2, u3),
],
)
graph.save(conn)

reviews_cypher = (
UserQuery({"name": "Ada"})
.edge("reviews")
.project(["business", "rating"])
.where(ast.Expr("review.rating >= 4"))
.to_cypher()
)
assert reviews_cypher == (
"MATCH (u:User {name: 'Ada'})-[:REVIEWS]->(n1:Review)\n"
"WHERE n1.rating >= 4\n"
"RETURN n1.business, n1.rating"
)
rows = conn.execute(reviews_cypher).get_all()
assert rows == [["Cafe", 5]]
print(reviews_cypher)
print(rows)

follows_cypher = (
UserQuery({"name": "Ada"})
.edge("follows")
.edge("follows")
.project(["name"])
.to_cypher()
)
assert follows_cypher == (
"MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\nRETURN b.name"
)
follows_rows = conn.execute(follows_cypher).get_all()
assert ["Linus"] in follows_rows
print(follows_cypher)
print(follows_rows)


def arrow_memory_demo() -> None:
db = lb.Database(":memory:")
conn = lb.Connection(db)

users = User.to_arrow(
[
User(id=1, name="Ada", age=37),
User(id=2, name="Grace", age=42),
User(id=3, name="Linus", age=55),
]
)
reviews = Review.to_arrow(
[
Review(id=10, business="Cafe", rating=5),
Review(id=20, business="Deli", rating=4),
]
)
review_edges = {
"from": [1, 2],
"to": [10, 20],
}

User.create_arrow_table(conn, users)
Review.create_arrow_table(conn, reviews)
User.create_arrow_rel_table(
conn, "follows", pa.table({"from": [1, 2], "to": [2, 3]}), User
)
User.create_arrow_rel_table(conn, "reviews", pa.table(review_edges), Review)

cypher = (
UserQuery({"name": "Ada"})
.edge("reviews")
.project(["business", "rating"])
.to_cypher()
)
result = User.query_as_arrow(conn, cypher, chunk_size=1024)
table = result.get_as_arrow()
assert table.to_pylist() == [
{"n1.business": "Cafe", "n1.rating": 5},
]
print(table)

follows_cypher = (
UserQuery({"name": "Ada"})
.edge("follows")
.edge("follows")
.project(["name"])
.to_cypher()
)
assert follows_cypher == (
"MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\nRETURN b.name"
)
follows_result = User.query_as_arrow(conn, follows_cypher, chunk_size=1024)
follows_table = follows_result.get_as_arrow()
assert {"b.name": "Linus"} in follows_table.to_pylist()
print(follows_cypher)
print(follows_table)


if __name__ == "__main__":
save_demo()
arrow_memory_demo()
49 changes: 45 additions & 4 deletions fquery/cypher_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import ast
import operator
from typing import Any

from .naming import query_type_name
from .visitor import Visitor
Expand Down Expand Up @@ -36,7 +37,28 @@ def __init__(self, id1s):

@staticmethod
def table_from_query(query):
return query_type_name(query.__class__).capitalize()
return query_type_name(query.__class__)

@staticmethod
def _cypher_literal(value: Any) -> str:
if value is None:
return "null"
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, str):
escaped = value.replace("\\", "\\\\").replace("'", "\\'")
return f"'{escaped}'"
raise TypeError(f"Unsupported Cypher property value {value!r}")

def _node_pattern(self, var: str, label: str, props=None) -> str:
if not props:
return f"({var}:{label})"
prop_list = ", ".join(
f"{name}: {self._cypher_literal(value)}" for name, value in props.items()
)
return f"({var}:{label} {{{prop_list}}})"

def _get_next_node_var(self):
self.node_counter += 1
Expand All @@ -45,7 +67,13 @@ def _get_next_node_var(self):
async def visit_leaf(self, query):
if not self.root_label:
self.root_label = self.table_from_query(query)
self.match_parts = [f"({self.current_node}:{self.root_label})"]
self.match_parts = [
self._node_pattern(
self.current_node,
self.root_label,
getattr(query, "_match_props", None),
)
]

if query in self.visited:
# Prevent infinite recursion
Expand Down Expand Up @@ -92,7 +120,13 @@ async def visit_edge(self, query):
root_query = root_query.child
if hasattr(root_query, "__class__"):
self.root_label = self.table_from_query(root_query)
self.match_parts = [f"({self.current_node}:{self.root_label})"]
self.match_parts = [
self._node_pattern(
self.current_node,
self.root_label,
getattr(root_query, "_match_props", None),
)
]

edge_name = query.edge_name

Expand All @@ -106,6 +140,9 @@ async def visit_edge(self, query):
)

if has_same_edge_child:
root_query = query
while hasattr(root_query, "child") and root_query.child:
root_query = root_query.child
# Count the total number of consecutive edges of the same type
hops = 1 # current edge
current_query = query.child
Expand All @@ -119,7 +156,11 @@ async def visit_edge(self, query):

# This is the start of a multi-hop pattern (e.g., friend-of-friend, etc.)
self.match_parts = [
f"(a:{self.root_label})",
self._node_pattern(
"a",
self.root_label,
getattr(root_query, "_match_props", None),
),
f"[e:{edge_name.upper()}*{hops}..{hops}]",
f"(b:{self.root_label})",
]
Expand Down
Loading
Loading