From e0e8e94031bdd711bfe3361c0be3ce65dce1f1f7 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Mon, 25 May 2026 18:21:32 -0700 Subject: [PATCH 1/2] Add Ladybug graph demo --- examples/ladybug_demo.py | 147 ++++++++++++++++++++ fquery/cypher_builder.py | 2 +- fquery/ladybug.py | 272 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 7 + setup.py | 1 + tests/test_ladybug.py | 180 ++++++++++++++++++++++++ tests/test_ladybug_demo.py | 16 +++ 7 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 examples/ladybug_demo.py create mode 100644 fquery/ladybug.py create mode 100644 tests/test_ladybug.py create mode 100644 tests/test_ladybug_demo.py diff --git a/examples/ladybug_demo.py b/examples/ladybug_demo.py new file mode 100644 index 0000000..51fac84 --- /dev/null +++ b/examples/ladybug_demo.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import ast +from typing import List + +import ladybug as lb +import pyarrow as pa + +from fquery.arrow import arrow +from fquery.ladybug import ladybug, ladybug_graph +from fquery.view_model import edge, node + + +@ladybug +@arrow +@node +class User: + name: str + age: int + + @edge + async def follow(self) -> List["User"]: + yield self.follow + + @edge + async def reviews(self) -> List["Review"]: + yield self.reviews + + +@ladybug +@arrow +@node +class Review: + business: str + rating: int + + +UserQuery = User.query() +ReviewQuery = Review.query() + + +@ladybug_graph +class ReviewGraph: + User = User + Review = Review + + +def save_demo() -> None: + db = lb.Database(":memory:") + conn = lb.Connection(db) + + ReviewGraph.create_schema(conn) + + review = Review(id=10, business="Cafe", rating=5) + user = User(id=1, name="Ada", age=37) + follow1 = User(id=2, name="Grace", age=42) + follow2 = User(id=3, name="Linus", age=55) + user.reviews = [review] + user.follow = [follow1] + follow1.follow = [follow2] + + review.save(conn) + follow2.save(conn, include_edges=False) + follow1.save(conn) + user.write(conn) + + reviews_cypher = ( + UserQuery([1]) + .edge("reviews") + .project(["business", "rating"]) + .where(ast.Expr("review.rating >= 4")) + .to_cypher() + ) + assert reviews_cypher == ( + "MATCH (u:User)-[:REVIEWS]->(n1:Review)\n" + "WHERE n1.rating >= 4\n" + "RETURN n1.business, n1.rating" + ) + rows = conn.execute(reviews_cypher).get_all() + assert rows == [["Cafe", 5]] + print(reviews_cypher) + print(rows) + + follow_cypher = ( + UserQuery([1]).edge("follow").edge("follow").project(["name"]).to_cypher() + ) + assert follow_cypher == "MATCH (a:User)-[e:FOLLOW*2..2]-(b:User)\nRETURN b.name" + follow_rows = conn.execute(follow_cypher).get_all() + assert ["Linus"] in follow_rows + print(follow_cypher) + print(follow_rows) + + +def arrow_memory_demo() -> None: + db = lb.Database(":memory:") + conn = lb.Connection(db) + + users = User.to_arrow( + [ + User(id=1, name="Ada", age=37), + User(id=2, name="Grace", age=42), + User(id=3, name="Linus", age=55), + ] + ) + reviews = Review.to_arrow( + [ + Review(id=10, business="Cafe", rating=5), + Review(id=20, business="Deli", rating=4), + ] + ) + review_edges = { + "from": [1, 2], + "to": [10, 20], + } + + User.create_arrow_table(conn, users) + Review.create_arrow_table(conn, reviews) + User.create_arrow_rel_table( + conn, "follow", pa.table({"from": [1, 2], "to": [2, 3]}), User + ) + User.create_arrow_rel_table(conn, "reviews", pa.table(review_edges), Review) + + cypher = ( + UserQuery([1, 2]).edge("reviews").project(["business", "rating"]).to_cypher() + ) + result = User.query_as_arrow(conn, cypher, chunk_size=1024) + table = result.get_as_arrow() + assert table.to_pylist() == [ + {"n1.business": "Cafe", "n1.rating": 5}, + {"n1.business": "Deli", "n1.rating": 4}, + ] + print(table) + + follow_cypher = ( + UserQuery([1]).edge("follow").edge("follow").project(["name"]).to_cypher() + ) + assert follow_cypher == "MATCH (a:User)-[e:FOLLOW*2..2]-(b:User)\nRETURN b.name" + follow_result = User.query_as_arrow(conn, follow_cypher, chunk_size=1024) + follow_table = follow_result.get_as_arrow() + assert {"b.name": "Linus"} in follow_table.to_pylist() + print(follow_cypher) + print(follow_table) + + +if __name__ == "__main__": + save_demo() + arrow_memory_demo() diff --git a/fquery/cypher_builder.py b/fquery/cypher_builder.py index 88ce74c..2f49e19 100644 --- a/fquery/cypher_builder.py +++ b/fquery/cypher_builder.py @@ -36,7 +36,7 @@ def __init__(self, id1s): @staticmethod def table_from_query(query): - return query_type_name(query.__class__).capitalize() + return query_type_name(query.__class__) def _get_next_node_var(self): self.node_counter += 1 diff --git a/fquery/ladybug.py b/fquery/ladybug.py new file mode 100644 index 0000000..98fdc05 --- /dev/null +++ b/fquery/ladybug.py @@ -0,0 +1,272 @@ +import re +import types +from dataclasses import dataclass, fields, is_dataclass +from datetime import date, datetime, time +from typing import ( + Any, + ClassVar, + Dict, + List, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) + +from .arrow import table as arrow_table +from .view_model import get_edges, get_return_type + +LADYBUG_TYPEMAP = { + bool: "BOOL", + int: "INT64", + float: "DOUBLE", + str: "STRING", + bytes: "BLOB", + datetime: "TIMESTAMP", + date: "DATE", + time: "TIME", +} + +PARAM_RE = re.compile(r"\$([A-Za-z_][A-Za-z0-9_]*)") + + +UNION_TYPES = (Union,) +if hasattr(types, "UnionType"): + UNION_TYPES = UNION_TYPES + (types.UnionType,) + + +def ladybug(cls): + """ + Decorator that adds Ladybug persistence helpers to an fquery node. + """ + if not is_dataclass(cls) or "__dataclass_fields__" not in cls.__dict__: + cls = dataclass(kw_only=True)(cls) + return model(cls) + + +def ladybug_graph(cls): + """ + Decorator that registers a group of Ladybug-backed fquery nodes. + """ + + def create_schema(graph_cls, conn) -> None: + models = _graph_models(graph_cls) + for model_cls in models: + model_cls.create_schema(conn, include_edges=False) + for model_cls in models: + model_cls.create_edge_schema(conn) + + cls.create_schema = classmethod(create_schema) + return cls + + +def _graph_models(cls: Type) -> List[Type]: + return [ + value + for value in cls.__dict__.values() + if isinstance(value, type) and hasattr(value, "__ladybug_node_ddl__") + ] + + +def _table_name(cls: Type) -> str: + return getattr(cls, "__ladybug_table__", cls.__name__) + + +def _rel_table_name(edge_name: str) -> str: + return edge_name.upper() + + +def _unwrap_optional(annotation): + origin = get_origin(annotation) + args = get_args(annotation) + if origin in UNION_TYPES and type(None) in args: + non_none = [arg for arg in args if arg is not type(None)] + if len(non_none) == 1: + return non_none[0] + return annotation + + +def _ladybug_type(annotation) -> str: + annotation = _unwrap_optional(annotation) + origin = get_origin(annotation) + args = get_args(annotation) + + if origin in (list, List): + if len(args) != 1: + raise TypeError(f"List fields must specify one item type: {annotation!r}") + return f"{_ladybug_type(args[0])}[]" + + if origin in (dict, Dict): + return "MAP" + + try: + return LADYBUG_TYPEMAP[annotation] + except KeyError: + raise TypeError(f"Unsupported Ladybug field type {annotation!r}") from None + + +def _node_fields(cls: Type): + return [field for field, _ in _node_field_types(cls)] + + +def _node_field_types(cls: Type): + type_hints = get_type_hints(cls) + pairs = [] + for field in fields(cls): + annotation = type_hints.get(field.name, field.type) + if get_origin(annotation) is ClassVar or field.name.startswith("_"): + continue + pairs.append((field, annotation)) + return pairs + + +def _node_ddl(cls: Type) -> str: + columns = [] + for field, annotation in _node_field_types(cls): + col = f"{field.name} {_ladybug_type(annotation)}" + if field.name == "id": + col += " PRIMARY KEY" + columns.append(col) + if not any(field.name == "id" for field in _node_fields(cls)): + columns.insert(0, "id INT64 PRIMARY KEY") + return f"CREATE NODE TABLE {_table_name(cls)}({', '.join(columns)})" + + +def _edge_ddl(cls: Type, edge_name: str, dst_table_name: str) -> str: + return ( + f"CREATE REL TABLE {_rel_table_name(edge_name)}" + f"(FROM {_table_name(cls)} TO {dst_table_name})" + ) + + +def _node_params(obj) -> Dict[str, Any]: + return {field.name: getattr(obj, field.name) for field in _node_fields(type(obj))} + + +def _node_create(cls: Type) -> str: + props = ", ".join(f"{field.name}: ${field.name}" for field in _node_fields(cls)) + return f"CREATE (n:{_table_name(cls)} {{{props}}})" + + +def _edge_create(src_cls: Type, edge_name: str, dst_cls: Type) -> str: + return ( + f"MATCH (src:{_table_name(src_cls)} {{id: $src_id}}), " + f"(dst:{_table_name(dst_cls)} {{id: $dst_id}}) " + f"CREATE (src)-[:{_rel_table_name(edge_name)}]->(dst)" + ) + + +def _cypher_literal(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + if isinstance(value, (list, tuple)): + return "[" + ", ".join(_cypher_literal(item) for item in value) + "]" + raise TypeError(f"Unsupported Ladybug parameter value {value!r}") + + +def _inline_parameters(query: str, parameters: Dict[str, Any]) -> str: + def replace(match): + name = match.group(1) + if name not in parameters: + return match.group(0) + return _cypher_literal(parameters[name]) + + return PARAM_RE.sub(replace, query) + + +def _execute(conn, query: str, parameters: Dict[str, Any] = None): + if parameters is None: + return conn.execute(query) + try: + return conn.execute(query, parameters) + except ModuleNotFoundError as exc: + if exc.name != "numpy": + raise + return conn.execute(_inline_parameters(query, parameters)) + + +def _as_arrow_table(cls: Type, rows_or_table): + if hasattr(rows_or_table, "schema") and hasattr(rows_or_table, "to_pylist"): + return rows_or_table + if hasattr(cls, "to_arrow"): + return cls.to_arrow(rows_or_table) + return arrow_table(cls, rows_or_table) + + +def model(cls: Type) -> Type: + def create_schema(model_cls, conn, *, include_edges: bool = True) -> None: + conn.execute(_node_ddl(model_cls)) + if not include_edges: + return + model_cls.create_edge_schema(conn) + + def create_edge_schema(model_cls, conn) -> None: + for edge_name, edge_func in get_edges(model_cls).items(): + conn.execute( + _edge_ddl(model_cls, edge_name, get_return_type(edge_func._old)) + ) + + def create_arrow_table(model_cls, conn, rows_or_table, table_name: str = None): + table_name = table_name or _table_name(model_cls) + return conn.create_arrow_table( + table_name, _as_arrow_table(model_cls, rows_or_table) + ) + + def create_arrow_rel_table( + model_cls, + conn, + edge_name: str, + rows_or_table, + dst_cls: Type, + *, + layout="FLAT", + indptr_dataframe=None, + ): + return conn.create_arrow_rel_table( + _rel_table_name(edge_name), + rows_or_table, + _table_name(model_cls), + _table_name(dst_cls), + layout, + indptr_dataframe, + ) + + def query_as_arrow(model_cls, conn, query: str, chunk_size: int = 1024): + return conn.query_as_arrow(query, chunk_size) + + def save(self, conn, *, include_edges: bool = True) -> None: + _execute(conn, _node_create(type(self)), _node_params(self)) + if not include_edges: + return + for edge_name in get_edges(type(self)): + if edge_name not in self: + continue + targets = self[edge_name] + if targets is None: + continue + targets = targets if isinstance(targets, list) else [targets] + for target in targets: + _execute( + conn, + _edge_create(type(self), edge_name, type(target)), + {"src_id": self.id, "dst_id": target.id}, + ) + + cls.__ladybug_table__ = _table_name(cls) + cls.__ladybug_node_ddl__ = _node_ddl(cls) + cls.create_schema = classmethod(create_schema) + cls.create_edge_schema = classmethod(create_edge_schema) + cls.create_arrow_table = classmethod(create_arrow_table) + cls.create_arrow_rel_table = classmethod(create_arrow_rel_table) + cls.query_as_arrow = classmethod(query_as_arrow) + cls.save = save + cls.write = save + return cls diff --git a/pyproject.toml b/pyproject.toml index dc328e8..a4d685d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ django = [ ] malloy = [] cypher = [] +ladybug = [ + "ladybug; python_version >= '3.10'", +] test = [ "pytest", ] @@ -86,6 +89,9 @@ django = [ ] malloy = [] cypher = [] +ladybug = [ + "ladybug; python_version >= '3.10'", +] test = [ "pytest", ] @@ -98,6 +104,7 @@ dev = [ { include-group = "polars" }, { include-group = "pyarrow" }, { include-group = "django" }, + { include-group = "ladybug" }, ] [tool.isort] diff --git a/setup.py b/setup.py index 665e9c4..dd8b57f 100755 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "django": ["django"], "malloy": [], "cypher": [], + "ladybug": ["ladybug; python_version >= '3.10'"], "df": ["polars >= 0.12.0"], }, ) diff --git a/tests/test_ladybug.py b/tests/test_ladybug.py new file mode 100644 index 0000000..0b34b79 --- /dev/null +++ b/tests/test_ladybug.py @@ -0,0 +1,180 @@ +from typing import List + +from fquery.ladybug import ladybug, ladybug_graph +from fquery.view_model import edge, node + + +class FakeConnection: + def __init__(self): + self.executed = [] + self.arrow_tables = [] + self.arrow_rel_tables = [] + self.arrow_queries = [] + + def execute(self, query, parameters=None): + self.executed.append((query, parameters)) + return [] + + def create_arrow_table(self, table_name, dataframe): + self.arrow_tables.append((table_name, dataframe)) + return [] + + def create_arrow_rel_table( + self, + table_name, + dataframe, + src_table_name, + dst_table_name, + layout, + indptr_dataframe, + ): + self.arrow_rel_tables.append( + ( + table_name, + dataframe, + src_table_name, + dst_table_name, + layout, + indptr_dataframe, + ) + ) + return [] + + def query_as_arrow(self, query, chunk_size): + self.arrow_queries.append((query, chunk_size)) + return "arrow-result" + + +class FakeArrowTable: + schema = None + + def to_pylist(self): + return [] + + +@ladybug +@node +class LadybugUser: + name: str + age: int + + @edge + async def follow(self) -> List["LadybugUser"]: + yield self.follow + + @edge + async def reviews(self) -> List["LadybugReview"]: + yield self.reviews + + +@ladybug +@node +class LadybugReview: + business: str + rating: int + + +@ladybug_graph +class LadybugReviewGraph: + User = LadybugUser + Review = LadybugReview + + +def test_ladybug_schema_and_to_cypher(): + conn = FakeConnection() + + LadybugUser.create_schema(conn) + + assert conn.executed == [ + ( + "CREATE NODE TABLE LadybugUser(id INT64 PRIMARY KEY, name STRING, age INT64)", + None, + ), + ("CREATE REL TABLE FOLLOW(FROM LadybugUser TO LadybugUser)", None), + ("CREATE REL TABLE REVIEWS(FROM LadybugUser TO LadybugReview)", None), + ] + LadybugReview.query() + assert ( + LadybugUser.query()([1]) + .edge("reviews") + .project(["business", "rating"]) + .to_cypher() + == "MATCH (u:LadybugUser)-[:REVIEWS]->(n1:LadybugReview)\n" + "RETURN n1.business, n1.rating" + ) + assert ( + LadybugUser.query()([1]) + .edge("follow") + .edge("follow") + .project(["name"]) + .to_cypher() + == "MATCH (a:LadybugUser)-[e:FOLLOW*2..2]-(b:LadybugUser)\n" + "RETURN b.name" + ) + + +def test_ladybug_graph_schema_registration(): + conn = FakeConnection() + + LadybugReviewGraph.create_schema(conn) + + assert conn.executed == [ + ( + "CREATE NODE TABLE LadybugUser(id INT64 PRIMARY KEY, name STRING, age INT64)", + None, + ), + ( + "CREATE NODE TABLE LadybugReview(id INT64 PRIMARY KEY, business STRING, rating INT64)", + None, + ), + ("CREATE REL TABLE FOLLOW(FROM LadybugUser TO LadybugUser)", None), + ("CREATE REL TABLE REVIEWS(FROM LadybugUser TO LadybugReview)", None), + ] + + +def test_ladybug_save_writes_node_and_edges(): + conn = FakeConnection() + review = LadybugReview(id=10, business="Cafe", rating=5) + user = LadybugUser(id=1, name="Ada", age=37) + followed = LadybugUser(id=2, name="Grace", age=42) + user.follow = [followed] + user.reviews = [review] + + user.save(conn) + + assert conn.executed == [ + ( + "CREATE (n:LadybugUser {id: $id, name: $name, age: $age})", + {"id": 1, "name": "Ada", "age": 37}, + ), + ( + "MATCH (src:LadybugUser {id: $src_id}), " + "(dst:LadybugUser {id: $dst_id}) " + "CREATE (src)-[:FOLLOW]->(dst)", + {"src_id": 1, "dst_id": 2}, + ), + ( + "MATCH (src:LadybugUser {id: $src_id}), " + "(dst:LadybugReview {id: $dst_id}) " + "CREATE (src)-[:REVIEWS]->(dst)", + {"src_id": 1, "dst_id": 10}, + ), + ] + + +def test_ladybug_arrow_helpers(): + conn = FakeConnection() + table = FakeArrowTable() + rel_table = object() + cypher = "MATCH (u:LadybugUser) RETURN u.name" + + LadybugUser.create_arrow_table(conn, table) + LadybugUser.create_arrow_rel_table(conn, "reviews", rel_table, LadybugReview) + result = LadybugUser.query_as_arrow(conn, cypher, chunk_size=64) + + assert conn.arrow_tables == [("LadybugUser", table)] + assert conn.arrow_rel_tables == [ + ("REVIEWS", rel_table, "LadybugUser", "LadybugReview", "FLAT", None) + ] + assert result == "arrow-result" + assert conn.arrow_queries == [(cypher, 64)] diff --git a/tests/test_ladybug_demo.py b/tests/test_ladybug_demo.py new file mode 100644 index 0000000..c96aace --- /dev/null +++ b/tests/test_ladybug_demo.py @@ -0,0 +1,16 @@ +import pytest + +pytest.importorskip("ladybug") +pytest.importorskip("pyarrow") + + +def test_ladybug_demo_save_path(): + from examples.ladybug_demo import save_demo + + save_demo() + + +def test_ladybug_demo_arrow_memory_path(): + from examples.ladybug_demo import arrow_memory_demo + + arrow_memory_demo() From 3c46bdd949721f567b90bc3b93832c50a5f94b68 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Mon, 25 May 2026 18:31:26 -0700 Subject: [PATCH 2/2] Make Ladybug graph writes declarative --- examples/ladybug_demo.py | 138 +++++++++++++++++--------------- fquery/cypher_builder.py | 47 ++++++++++- fquery/ladybug.py | 148 +++++++++++++++++++++++++++++++---- fquery/query.py | 3 + pyproject.toml | 2 +- tests/test_ladybug.py | 164 ++++++++++++++++++++++++++------------- 6 files changed, 367 insertions(+), 135 deletions(-) diff --git a/examples/ladybug_demo.py b/examples/ladybug_demo.py index 51fac84..1e48279 100644 --- a/examples/ladybug_demo.py +++ b/examples/ladybug_demo.py @@ -7,72 +7,70 @@ import pyarrow as pa from fquery.arrow import arrow -from fquery.ladybug import ladybug, ladybug_graph +from fquery.ladybug import graph, graph_edge, ladybug from fquery.view_model import edge, node -@ladybug -@arrow -@node -class User: - name: str - age: int - - @edge - async def follow(self) -> List["User"]: - yield self.follow - - @edge - async def reviews(self) -> List["Review"]: - yield self.reviews - - -@ladybug -@arrow -@node -class Review: - business: str - rating: int - - +@graph +class ReviewGraph: + @ladybug + @arrow + @node + class User: + name: str + age: int + + @edge + async def follows(self) -> List["User"]: + yield self.follows + + @edge + async def reviews(self) -> List["Review"]: + yield self.reviews + + @ladybug + @arrow + @node + class Review: + business: str + rating: int + + +User = ReviewGraph.User +Review = ReviewGraph.Review UserQuery = User.query() ReviewQuery = Review.query() -@ladybug_graph -class ReviewGraph: - User = User - Review = Review - - def save_demo() -> None: db = lb.Database(":memory:") conn = lb.Connection(db) ReviewGraph.create_schema(conn) - review = Review(id=10, business="Cafe", rating=5) - user = User(id=1, name="Ada", age=37) - follow1 = User(id=2, name="Grace", age=42) - follow2 = User(id=3, name="Linus", age=55) - user.reviews = [review] - user.follow = [follow1] - follow1.follow = [follow2] - - review.save(conn) - follow2.save(conn, include_edges=False) - follow1.save(conn) - user.write(conn) + u1 = User(id=1, name="Ada", age=37) + u2 = User(id=2, name="Grace", age=42) + u3 = User(id=3, name="Linus", age=55) + r1 = Review(id=10, business="Cafe", rating=5) + graph = ReviewGraph( + nodes=[u1, u2, u3, r1], + edges=[ + graph_edge("reviews", u1, r1), + graph_edge("follows", u1, u2), + graph_edge("follows", u2, u3), + ], + ) + graph.save(conn) reviews_cypher = ( - UserQuery([1]) + UserQuery({"name": "Ada"}) .edge("reviews") .project(["business", "rating"]) .where(ast.Expr("review.rating >= 4")) .to_cypher() ) assert reviews_cypher == ( - "MATCH (u:User)-[:REVIEWS]->(n1:Review)\n" + "MATCH (u:User {name: 'Ada'})-[:REVIEWS]->(n1:Review)\n" "WHERE n1.rating >= 4\n" "RETURN n1.business, n1.rating" ) @@ -81,14 +79,20 @@ def save_demo() -> None: print(reviews_cypher) print(rows) - follow_cypher = ( - UserQuery([1]).edge("follow").edge("follow").project(["name"]).to_cypher() + follows_cypher = ( + UserQuery({"name": "Ada"}) + .edge("follows") + .edge("follows") + .project(["name"]) + .to_cypher() ) - assert follow_cypher == "MATCH (a:User)-[e:FOLLOW*2..2]-(b:User)\nRETURN b.name" - follow_rows = conn.execute(follow_cypher).get_all() - assert ["Linus"] in follow_rows - print(follow_cypher) - print(follow_rows) + assert follows_cypher == ( + "MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\nRETURN b.name" + ) + follows_rows = conn.execute(follows_cypher).get_all() + assert ["Linus"] in follows_rows + print(follows_cypher) + print(follows_rows) def arrow_memory_demo() -> None: @@ -116,30 +120,38 @@ def arrow_memory_demo() -> None: User.create_arrow_table(conn, users) Review.create_arrow_table(conn, reviews) User.create_arrow_rel_table( - conn, "follow", pa.table({"from": [1, 2], "to": [2, 3]}), User + conn, "follows", pa.table({"from": [1, 2], "to": [2, 3]}), User ) User.create_arrow_rel_table(conn, "reviews", pa.table(review_edges), Review) cypher = ( - UserQuery([1, 2]).edge("reviews").project(["business", "rating"]).to_cypher() + UserQuery({"name": "Ada"}) + .edge("reviews") + .project(["business", "rating"]) + .to_cypher() ) result = User.query_as_arrow(conn, cypher, chunk_size=1024) table = result.get_as_arrow() assert table.to_pylist() == [ {"n1.business": "Cafe", "n1.rating": 5}, - {"n1.business": "Deli", "n1.rating": 4}, ] print(table) - follow_cypher = ( - UserQuery([1]).edge("follow").edge("follow").project(["name"]).to_cypher() + follows_cypher = ( + UserQuery({"name": "Ada"}) + .edge("follows") + .edge("follows") + .project(["name"]) + .to_cypher() + ) + assert follows_cypher == ( + "MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\nRETURN b.name" ) - assert follow_cypher == "MATCH (a:User)-[e:FOLLOW*2..2]-(b:User)\nRETURN b.name" - follow_result = User.query_as_arrow(conn, follow_cypher, chunk_size=1024) - follow_table = follow_result.get_as_arrow() - assert {"b.name": "Linus"} in follow_table.to_pylist() - print(follow_cypher) - print(follow_table) + follows_result = User.query_as_arrow(conn, follows_cypher, chunk_size=1024) + follows_table = follows_result.get_as_arrow() + assert {"b.name": "Linus"} in follows_table.to_pylist() + print(follows_cypher) + print(follows_table) if __name__ == "__main__": diff --git a/fquery/cypher_builder.py b/fquery/cypher_builder.py index 2f49e19..d2e58b2 100644 --- a/fquery/cypher_builder.py +++ b/fquery/cypher_builder.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import ast import operator +from typing import Any from .naming import query_type_name from .visitor import Visitor @@ -38,6 +39,27 @@ def __init__(self, id1s): def table_from_query(query): return query_type_name(query.__class__) + @staticmethod + def _cypher_literal(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + raise TypeError(f"Unsupported Cypher property value {value!r}") + + def _node_pattern(self, var: str, label: str, props=None) -> str: + if not props: + return f"({var}:{label})" + prop_list = ", ".join( + f"{name}: {self._cypher_literal(value)}" for name, value in props.items() + ) + return f"({var}:{label} {{{prop_list}}})" + def _get_next_node_var(self): self.node_counter += 1 return f"n{self.node_counter}" @@ -45,7 +67,13 @@ def _get_next_node_var(self): async def visit_leaf(self, query): if not self.root_label: self.root_label = self.table_from_query(query) - self.match_parts = [f"({self.current_node}:{self.root_label})"] + self.match_parts = [ + self._node_pattern( + self.current_node, + self.root_label, + getattr(query, "_match_props", None), + ) + ] if query in self.visited: # Prevent infinite recursion @@ -92,7 +120,13 @@ async def visit_edge(self, query): root_query = root_query.child if hasattr(root_query, "__class__"): self.root_label = self.table_from_query(root_query) - self.match_parts = [f"({self.current_node}:{self.root_label})"] + self.match_parts = [ + self._node_pattern( + self.current_node, + self.root_label, + getattr(root_query, "_match_props", None), + ) + ] edge_name = query.edge_name @@ -106,6 +140,9 @@ async def visit_edge(self, query): ) if has_same_edge_child: + root_query = query + while hasattr(root_query, "child") and root_query.child: + root_query = root_query.child # Count the total number of consecutive edges of the same type hops = 1 # current edge current_query = query.child @@ -119,7 +156,11 @@ async def visit_edge(self, query): # This is the start of a multi-hop pattern (e.g., friend-of-friend, etc.) self.match_parts = [ - f"(a:{self.root_label})", + self._node_pattern( + "a", + self.root_label, + getattr(root_query, "_match_props", None), + ), f"[e:{edge_name.upper()}*{hops}..{hops}]", f"(b:{self.root_label})", ] diff --git a/fquery/ladybug.py b/fquery/ladybug.py index 98fdc05..8b5e78c 100644 --- a/fquery/ladybug.py +++ b/fquery/ladybug.py @@ -1,6 +1,6 @@ import re import types -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass, field, fields, is_dataclass from datetime import date, datetime, time from typing import ( Any, @@ -31,6 +31,13 @@ PARAM_RE = re.compile(r"\$([A-Za-z_][A-Za-z0-9_]*)") +@dataclass(frozen=True) +class GraphEdge: + name: str + src: Any + dst: Any + + UNION_TYPES = (Union,) if hasattr(types, "UnionType"): UNION_TYPES = UNION_TYPES + (types.UnionType,) @@ -45,10 +52,18 @@ def ladybug(cls): return model(cls) -def ladybug_graph(cls): +def graph(cls): """ Decorator that registers a group of Ladybug-backed fquery nodes. """ + if not is_dataclass(cls): + annotations = dict(getattr(cls, "__annotations__", {})) + annotations.setdefault("nodes", List[Any]) + annotations.setdefault("edges", List[GraphEdge]) + cls.__annotations__ = annotations + cls.nodes = field(default_factory=list) + cls.edges = field(default_factory=list) + cls = dataclass(kw_only=True)(cls) def create_schema(graph_cls, conn) -> None: models = _graph_models(graph_cls) @@ -57,10 +72,27 @@ def create_schema(graph_cls, conn) -> None: for model_cls in models: model_cls.create_edge_schema(conn) + def save(self, conn) -> None: + nodes, edges = _collect_graph(self) + for node in nodes: + node.save(conn, include_edges=False) + for src, edge_name, dst in edges: + _execute( + conn, + _edge_create(type(src), edge_name, type(dst)), + {"src_id": src.id, "dst_id": dst.id}, + ) + cls.create_schema = classmethod(create_schema) + cls.save = save + cls.write = save return cls +def graph_edge(name: str, src: Any, dst: Any) -> GraphEdge: + return GraphEdge(name, src, dst) + + def _graph_models(cls: Type) -> List[Type]: return [ value @@ -69,6 +101,91 @@ def _graph_models(cls: Type) -> List[Type]: ] +def _is_ladybug_node(value) -> bool: + return hasattr(type(value), "__ladybug_node_ddl__") + + +def _is_graph_edge(value) -> bool: + return isinstance(value, GraphEdge) + + +def _iter_values(value): + if value is None: + return + if isinstance(value, dict): + for item in value.values(): + yield item + return + if isinstance(value, (list, tuple, set)): + for item in value: + yield item + return + yield value + + +def _is_graph_container(value) -> bool: + return isinstance(value, (dict, list, tuple, set)) + + +def _graph_roots(graph): + if is_dataclass(graph): + for graph_field in fields(graph): + yield getattr(graph, graph_field.name) + return + for name, value in vars(graph).items(): + if not name.startswith("_"): + yield value + + +def _collect_graph(graph): + nodes = [] + edges = [] + seen_nodes = set() + seen_edges = set() + + def add_edge(edge_name, src, dst): + edge_key = (type(src), src.id, edge_name, type(dst), dst.id) + if edge_key in seen_edges: + return + seen_edges.add(edge_key) + edges.append((src, edge_name, dst)) + + def visit(value): + if _is_graph_edge(value): + visit(value.src) + visit(value.dst) + add_edge(value.name, value.src, value.dst) + return + if _is_ladybug_node(value): + visit_node(value) + return + for item in _iter_values(value): + if _is_graph_edge(item): + visit(item) + elif _is_ladybug_node(item): + visit_node(item) + elif _is_graph_container(item): + visit(item) + + def visit_node(item): + node_key = (type(item), item.id) + if node_key in seen_nodes: + return + seen_nodes.add(node_key) + nodes.append(item) + + for edge_name in get_edges(type(item)): + if edge_name not in item: + continue + for target in _iter_values(item[edge_name]): + visit(target) + add_edge(edge_name, item, target) + + for root in _graph_roots(graph): + visit(root) + return nodes, edges + + def _table_name(cls: Type) -> str: return getattr(cls, "__ladybug_table__", cls.__name__) @@ -107,28 +224,28 @@ def _ladybug_type(annotation) -> str: def _node_fields(cls: Type): - return [field for field, _ in _node_field_types(cls)] + return [model_field for model_field, _ in _node_field_types(cls)] def _node_field_types(cls: Type): type_hints = get_type_hints(cls) pairs = [] - for field in fields(cls): - annotation = type_hints.get(field.name, field.type) - if get_origin(annotation) is ClassVar or field.name.startswith("_"): + for model_field in fields(cls): + annotation = type_hints.get(model_field.name, model_field.type) + if get_origin(annotation) is ClassVar or model_field.name.startswith("_"): continue - pairs.append((field, annotation)) + pairs.append((model_field, annotation)) return pairs def _node_ddl(cls: Type) -> str: columns = [] - for field, annotation in _node_field_types(cls): - col = f"{field.name} {_ladybug_type(annotation)}" - if field.name == "id": + for model_field, annotation in _node_field_types(cls): + col = f"{model_field.name} {_ladybug_type(annotation)}" + if model_field.name == "id": col += " PRIMARY KEY" columns.append(col) - if not any(field.name == "id" for field in _node_fields(cls)): + if not any(model_field.name == "id" for model_field in _node_fields(cls)): columns.insert(0, "id INT64 PRIMARY KEY") return f"CREATE NODE TABLE {_table_name(cls)}({', '.join(columns)})" @@ -141,11 +258,16 @@ def _edge_ddl(cls: Type, edge_name: str, dst_table_name: str) -> str: def _node_params(obj) -> Dict[str, Any]: - return {field.name: getattr(obj, field.name) for field in _node_fields(type(obj))} + return { + model_field.name: getattr(obj, model_field.name) + for model_field in _node_fields(type(obj)) + } def _node_create(cls: Type) -> str: - props = ", ".join(f"{field.name}: ${field.name}" for field in _node_fields(cls)) + props = ", ".join( + f"{model_field.name}: ${model_field.name}" for model_field in _node_fields(cls) + ) return f"CREATE (n:{_table_name(cls)} {{{props}}})" diff --git a/fquery/query.py b/fquery/query.py index 1b96850..0d87b0b 100644 --- a/fquery/query.py +++ b/fquery/query.py @@ -72,9 +72,12 @@ def __init__( # from a child query. assert (bool(items) ^ bool(ids)) or child self._items = items + self._match_props = ids if isinstance(ids, dict) else {} if self._items: # pyre-fixme[16]: `ViewModel` has no attribute `id`. self.ids = [item.id for item in self._items] + elif isinstance(ids, dict): + self.ids = [] else: self.ids = ids or [] diff --git a/pyproject.toml b/pyproject.toml index a4d685d..7f6c9fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ django = [ malloy = [] cypher = [] ladybug = [ - "ladybug; python_version >= '3.10'", + "ladybug >= 0.17.0; python_version >= '3.10'", ] test = [ "pytest", diff --git a/tests/test_ladybug.py b/tests/test_ladybug.py index 0b34b79..45da95f 100644 --- a/tests/test_ladybug.py +++ b/tests/test_ladybug.py @@ -1,6 +1,6 @@ from typing import List -from fquery.ladybug import ladybug, ladybug_graph +from fquery.ladybug import graph, graph_edge, ladybug from fquery.view_model import edge, node @@ -52,110 +52,164 @@ def to_pylist(self): return [] -@ladybug -@node -class LadybugUser: - name: str - age: int +@graph +class ReviewGraph: + @ladybug + @node + class User: + name: str + age: int - @edge - async def follow(self) -> List["LadybugUser"]: - yield self.follow + @edge + async def follows(self) -> List["User"]: + yield self.follows - @edge - async def reviews(self) -> List["LadybugReview"]: - yield self.reviews + @edge + async def reviews(self) -> List["Review"]: + yield self.reviews + @ladybug + @node + class Review: + business: str + rating: int -@ladybug -@node -class LadybugReview: - business: str - rating: int - -@ladybug_graph -class LadybugReviewGraph: - User = LadybugUser - Review = LadybugReview +User = ReviewGraph.User +Review = ReviewGraph.Review def test_ladybug_schema_and_to_cypher(): conn = FakeConnection() - LadybugUser.create_schema(conn) + User.create_schema(conn) assert conn.executed == [ ( - "CREATE NODE TABLE LadybugUser(id INT64 PRIMARY KEY, name STRING, age INT64)", + "CREATE NODE TABLE User(id INT64 PRIMARY KEY, name STRING, age INT64)", None, ), - ("CREATE REL TABLE FOLLOW(FROM LadybugUser TO LadybugUser)", None), - ("CREATE REL TABLE REVIEWS(FROM LadybugUser TO LadybugReview)", None), + ("CREATE REL TABLE FOLLOWS(FROM User TO User)", None), + ("CREATE REL TABLE REVIEWS(FROM User TO Review)", None), ] - LadybugReview.query() + Review.query() assert ( - LadybugUser.query()([1]) + User.query()({"name": "Ada"}) .edge("reviews") .project(["business", "rating"]) .to_cypher() - == "MATCH (u:LadybugUser)-[:REVIEWS]->(n1:LadybugReview)\n" + == "MATCH (u:User {name: 'Ada'})-[:REVIEWS]->(n1:Review)\n" "RETURN n1.business, n1.rating" ) assert ( - LadybugUser.query()([1]) - .edge("follow") - .edge("follow") + User.query()({"name": "Ada"}) + .edge("follows") + .edge("follows") .project(["name"]) .to_cypher() - == "MATCH (a:LadybugUser)-[e:FOLLOW*2..2]-(b:LadybugUser)\n" + == "MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\n" "RETURN b.name" ) -def test_ladybug_graph_schema_registration(): +def test_graph_schema_registration(): conn = FakeConnection() - LadybugReviewGraph.create_schema(conn) + ReviewGraph.create_schema(conn) assert conn.executed == [ ( - "CREATE NODE TABLE LadybugUser(id INT64 PRIMARY KEY, name STRING, age INT64)", + "CREATE NODE TABLE User(id INT64 PRIMARY KEY, name STRING, age INT64)", None, ), ( - "CREATE NODE TABLE LadybugReview(id INT64 PRIMARY KEY, business STRING, rating INT64)", + "CREATE NODE TABLE Review(id INT64 PRIMARY KEY, business STRING, rating INT64)", None, ), - ("CREATE REL TABLE FOLLOW(FROM LadybugUser TO LadybugUser)", None), - ("CREATE REL TABLE REVIEWS(FROM LadybugUser TO LadybugReview)", None), + ("CREATE REL TABLE FOLLOWS(FROM User TO User)", None), + ("CREATE REL TABLE REVIEWS(FROM User TO Review)", None), ] def test_ladybug_save_writes_node_and_edges(): conn = FakeConnection() - review = LadybugReview(id=10, business="Cafe", rating=5) - user = LadybugUser(id=1, name="Ada", age=37) - followed = LadybugUser(id=2, name="Grace", age=42) - user.follow = [followed] + review = Review(id=10, business="Cafe", rating=5) + user = User(id=1, name="Ada", age=37) + followed = User(id=2, name="Grace", age=42) + user.follows = [followed] user.reviews = [review] user.save(conn) assert conn.executed == [ ( - "CREATE (n:LadybugUser {id: $id, name: $name, age: $age})", + "CREATE (n:User {id: $id, name: $name, age: $age})", {"id": 1, "name": "Ada", "age": 37}, ), ( - "MATCH (src:LadybugUser {id: $src_id}), " - "(dst:LadybugUser {id: $dst_id}) " - "CREATE (src)-[:FOLLOW]->(dst)", + "MATCH (src:User {id: $src_id}), " + "(dst:User {id: $dst_id}) " + "CREATE (src)-[:FOLLOWS]->(dst)", {"src_id": 1, "dst_id": 2}, ), ( - "MATCH (src:LadybugUser {id: $src_id}), " - "(dst:LadybugReview {id: $dst_id}) " + "MATCH (src:User {id: $src_id}), " + "(dst:Review {id: $dst_id}) " + "CREATE (src)-[:REVIEWS]->(dst)", + {"src_id": 1, "dst_id": 10}, + ), + ] + + +def test_graph_save_writes_nodes_then_edges(): + conn = FakeConnection() + u1 = User(id=1, name="Ada", age=37) + u2 = User(id=2, name="Grace", age=42) + u3 = User(id=3, name="Linus", age=55) + r1 = Review(id=10, business="Cafe", rating=5) + graph = ReviewGraph( + nodes=[u1, u2, u3, r1], + edges=[ + graph_edge("follows", u1, u2), + graph_edge("follows", u2, u3), + graph_edge("reviews", u1, r1), + ], + ) + + graph.save(conn) + + assert conn.executed == [ + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 1, "name": "Ada", "age": 37}, + ), + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 2, "name": "Grace", "age": 42}, + ), + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 3, "name": "Linus", "age": 55}, + ), + ( + "CREATE (n:Review {id: $id, business: $business, rating: $rating})", + {"id": 10, "business": "Cafe", "rating": 5}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:User {id: $dst_id}) " + "CREATE (src)-[:FOLLOWS]->(dst)", + {"src_id": 1, "dst_id": 2}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:User {id: $dst_id}) " + "CREATE (src)-[:FOLLOWS]->(dst)", + {"src_id": 2, "dst_id": 3}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:Review {id: $dst_id}) " "CREATE (src)-[:REVIEWS]->(dst)", {"src_id": 1, "dst_id": 10}, ), @@ -166,15 +220,15 @@ def test_ladybug_arrow_helpers(): conn = FakeConnection() table = FakeArrowTable() rel_table = object() - cypher = "MATCH (u:LadybugUser) RETURN u.name" + cypher = "MATCH (u:User) RETURN u.name" - LadybugUser.create_arrow_table(conn, table) - LadybugUser.create_arrow_rel_table(conn, "reviews", rel_table, LadybugReview) - result = LadybugUser.query_as_arrow(conn, cypher, chunk_size=64) + User.create_arrow_table(conn, table) + User.create_arrow_rel_table(conn, "reviews", rel_table, Review) + result = User.query_as_arrow(conn, cypher, chunk_size=64) - assert conn.arrow_tables == [("LadybugUser", table)] + assert conn.arrow_tables == [("User", table)] assert conn.arrow_rel_tables == [ - ("REVIEWS", rel_table, "LadybugUser", "LadybugReview", "FLAT", None) + ("REVIEWS", rel_table, "User", "Review", "FLAT", None) ] assert result == "arrow-result" assert conn.arrow_queries == [(cypher, 64)]