diff --git a/examples/ladybug_demo.py b/examples/ladybug_demo.py new file mode 100644 index 0000000..1e48279 --- /dev/null +++ b/examples/ladybug_demo.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import ast +from typing import List + +import ladybug as lb +import pyarrow as pa + +from fquery.arrow import arrow +from fquery.ladybug import graph, graph_edge, ladybug +from fquery.view_model import edge, node + + +@graph +class ReviewGraph: + @ladybug + @arrow + @node + class User: + name: str + age: int + + @edge + async def follows(self) -> List["User"]: + yield self.follows + + @edge + async def reviews(self) -> List["Review"]: + yield self.reviews + + @ladybug + @arrow + @node + class Review: + business: str + rating: int + + +User = ReviewGraph.User +Review = ReviewGraph.Review +UserQuery = User.query() +ReviewQuery = Review.query() + + +def save_demo() -> None: + db = lb.Database(":memory:") + conn = lb.Connection(db) + + ReviewGraph.create_schema(conn) + + u1 = User(id=1, name="Ada", age=37) + u2 = User(id=2, name="Grace", age=42) + u3 = User(id=3, name="Linus", age=55) + r1 = Review(id=10, business="Cafe", rating=5) + graph = ReviewGraph( + nodes=[u1, u2, u3, r1], + edges=[ + graph_edge("reviews", u1, r1), + graph_edge("follows", u1, u2), + graph_edge("follows", u2, u3), + ], + ) + graph.save(conn) + + reviews_cypher = ( + UserQuery({"name": "Ada"}) + .edge("reviews") + .project(["business", "rating"]) + .where(ast.Expr("review.rating >= 4")) + .to_cypher() + ) + assert reviews_cypher == ( + "MATCH (u:User {name: 'Ada'})-[:REVIEWS]->(n1:Review)\n" + "WHERE n1.rating >= 4\n" + "RETURN n1.business, n1.rating" + ) + rows = conn.execute(reviews_cypher).get_all() + assert rows == [["Cafe", 5]] + print(reviews_cypher) + print(rows) + + follows_cypher = ( + UserQuery({"name": "Ada"}) + .edge("follows") + .edge("follows") + .project(["name"]) + .to_cypher() + ) + assert follows_cypher == ( + "MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\nRETURN b.name" + ) + follows_rows = conn.execute(follows_cypher).get_all() + assert ["Linus"] in follows_rows + print(follows_cypher) + print(follows_rows) + + +def arrow_memory_demo() -> None: + db = lb.Database(":memory:") + conn = lb.Connection(db) + + users = User.to_arrow( + [ + User(id=1, name="Ada", age=37), + User(id=2, name="Grace", age=42), + User(id=3, name="Linus", age=55), + ] + ) + reviews = Review.to_arrow( + [ + Review(id=10, business="Cafe", rating=5), + Review(id=20, business="Deli", rating=4), + ] + ) + review_edges = { + "from": [1, 2], + "to": [10, 20], + } + + User.create_arrow_table(conn, users) + Review.create_arrow_table(conn, reviews) + User.create_arrow_rel_table( + conn, "follows", pa.table({"from": [1, 2], "to": [2, 3]}), User + ) + User.create_arrow_rel_table(conn, "reviews", pa.table(review_edges), Review) + + cypher = ( + UserQuery({"name": "Ada"}) + .edge("reviews") + .project(["business", "rating"]) + .to_cypher() + ) + result = User.query_as_arrow(conn, cypher, chunk_size=1024) + table = result.get_as_arrow() + assert table.to_pylist() == [ + {"n1.business": "Cafe", "n1.rating": 5}, + ] + print(table) + + follows_cypher = ( + UserQuery({"name": "Ada"}) + .edge("follows") + .edge("follows") + .project(["name"]) + .to_cypher() + ) + assert follows_cypher == ( + "MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\nRETURN b.name" + ) + follows_result = User.query_as_arrow(conn, follows_cypher, chunk_size=1024) + follows_table = follows_result.get_as_arrow() + assert {"b.name": "Linus"} in follows_table.to_pylist() + print(follows_cypher) + print(follows_table) + + +if __name__ == "__main__": + save_demo() + arrow_memory_demo() diff --git a/fquery/cypher_builder.py b/fquery/cypher_builder.py index 88ce74c..d2e58b2 100644 --- a/fquery/cypher_builder.py +++ b/fquery/cypher_builder.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import ast import operator +from typing import Any from .naming import query_type_name from .visitor import Visitor @@ -36,7 +37,28 @@ def __init__(self, id1s): @staticmethod def table_from_query(query): - return query_type_name(query.__class__).capitalize() + return query_type_name(query.__class__) + + @staticmethod + def _cypher_literal(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + raise TypeError(f"Unsupported Cypher property value {value!r}") + + def _node_pattern(self, var: str, label: str, props=None) -> str: + if not props: + return f"({var}:{label})" + prop_list = ", ".join( + f"{name}: {self._cypher_literal(value)}" for name, value in props.items() + ) + return f"({var}:{label} {{{prop_list}}})" def _get_next_node_var(self): self.node_counter += 1 @@ -45,7 +67,13 @@ def _get_next_node_var(self): async def visit_leaf(self, query): if not self.root_label: self.root_label = self.table_from_query(query) - self.match_parts = [f"({self.current_node}:{self.root_label})"] + self.match_parts = [ + self._node_pattern( + self.current_node, + self.root_label, + getattr(query, "_match_props", None), + ) + ] if query in self.visited: # Prevent infinite recursion @@ -92,7 +120,13 @@ async def visit_edge(self, query): root_query = root_query.child if hasattr(root_query, "__class__"): self.root_label = self.table_from_query(root_query) - self.match_parts = [f"({self.current_node}:{self.root_label})"] + self.match_parts = [ + self._node_pattern( + self.current_node, + self.root_label, + getattr(root_query, "_match_props", None), + ) + ] edge_name = query.edge_name @@ -106,6 +140,9 @@ async def visit_edge(self, query): ) if has_same_edge_child: + root_query = query + while hasattr(root_query, "child") and root_query.child: + root_query = root_query.child # Count the total number of consecutive edges of the same type hops = 1 # current edge current_query = query.child @@ -119,7 +156,11 @@ async def visit_edge(self, query): # This is the start of a multi-hop pattern (e.g., friend-of-friend, etc.) self.match_parts = [ - f"(a:{self.root_label})", + self._node_pattern( + "a", + self.root_label, + getattr(root_query, "_match_props", None), + ), f"[e:{edge_name.upper()}*{hops}..{hops}]", f"(b:{self.root_label})", ] diff --git a/fquery/ladybug.py b/fquery/ladybug.py new file mode 100644 index 0000000..8b5e78c --- /dev/null +++ b/fquery/ladybug.py @@ -0,0 +1,394 @@ +import re +import types +from dataclasses import dataclass, field, fields, is_dataclass +from datetime import date, datetime, time +from typing import ( + Any, + ClassVar, + Dict, + List, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) + +from .arrow import table as arrow_table +from .view_model import get_edges, get_return_type + +LADYBUG_TYPEMAP = { + bool: "BOOL", + int: "INT64", + float: "DOUBLE", + str: "STRING", + bytes: "BLOB", + datetime: "TIMESTAMP", + date: "DATE", + time: "TIME", +} + +PARAM_RE = re.compile(r"\$([A-Za-z_][A-Za-z0-9_]*)") + + +@dataclass(frozen=True) +class GraphEdge: + name: str + src: Any + dst: Any + + +UNION_TYPES = (Union,) +if hasattr(types, "UnionType"): + UNION_TYPES = UNION_TYPES + (types.UnionType,) + + +def ladybug(cls): + """ + Decorator that adds Ladybug persistence helpers to an fquery node. + """ + if not is_dataclass(cls) or "__dataclass_fields__" not in cls.__dict__: + cls = dataclass(kw_only=True)(cls) + return model(cls) + + +def graph(cls): + """ + Decorator that registers a group of Ladybug-backed fquery nodes. + """ + if not is_dataclass(cls): + annotations = dict(getattr(cls, "__annotations__", {})) + annotations.setdefault("nodes", List[Any]) + annotations.setdefault("edges", List[GraphEdge]) + cls.__annotations__ = annotations + cls.nodes = field(default_factory=list) + cls.edges = field(default_factory=list) + cls = dataclass(kw_only=True)(cls) + + def create_schema(graph_cls, conn) -> None: + models = _graph_models(graph_cls) + for model_cls in models: + model_cls.create_schema(conn, include_edges=False) + for model_cls in models: + model_cls.create_edge_schema(conn) + + def save(self, conn) -> None: + nodes, edges = _collect_graph(self) + for node in nodes: + node.save(conn, include_edges=False) + for src, edge_name, dst in edges: + _execute( + conn, + _edge_create(type(src), edge_name, type(dst)), + {"src_id": src.id, "dst_id": dst.id}, + ) + + cls.create_schema = classmethod(create_schema) + cls.save = save + cls.write = save + return cls + + +def graph_edge(name: str, src: Any, dst: Any) -> GraphEdge: + return GraphEdge(name, src, dst) + + +def _graph_models(cls: Type) -> List[Type]: + return [ + value + for value in cls.__dict__.values() + if isinstance(value, type) and hasattr(value, "__ladybug_node_ddl__") + ] + + +def _is_ladybug_node(value) -> bool: + return hasattr(type(value), "__ladybug_node_ddl__") + + +def _is_graph_edge(value) -> bool: + return isinstance(value, GraphEdge) + + +def _iter_values(value): + if value is None: + return + if isinstance(value, dict): + for item in value.values(): + yield item + return + if isinstance(value, (list, tuple, set)): + for item in value: + yield item + return + yield value + + +def _is_graph_container(value) -> bool: + return isinstance(value, (dict, list, tuple, set)) + + +def _graph_roots(graph): + if is_dataclass(graph): + for graph_field in fields(graph): + yield getattr(graph, graph_field.name) + return + for name, value in vars(graph).items(): + if not name.startswith("_"): + yield value + + +def _collect_graph(graph): + nodes = [] + edges = [] + seen_nodes = set() + seen_edges = set() + + def add_edge(edge_name, src, dst): + edge_key = (type(src), src.id, edge_name, type(dst), dst.id) + if edge_key in seen_edges: + return + seen_edges.add(edge_key) + edges.append((src, edge_name, dst)) + + def visit(value): + if _is_graph_edge(value): + visit(value.src) + visit(value.dst) + add_edge(value.name, value.src, value.dst) + return + if _is_ladybug_node(value): + visit_node(value) + return + for item in _iter_values(value): + if _is_graph_edge(item): + visit(item) + elif _is_ladybug_node(item): + visit_node(item) + elif _is_graph_container(item): + visit(item) + + def visit_node(item): + node_key = (type(item), item.id) + if node_key in seen_nodes: + return + seen_nodes.add(node_key) + nodes.append(item) + + for edge_name in get_edges(type(item)): + if edge_name not in item: + continue + for target in _iter_values(item[edge_name]): + visit(target) + add_edge(edge_name, item, target) + + for root in _graph_roots(graph): + visit(root) + return nodes, edges + + +def _table_name(cls: Type) -> str: + return getattr(cls, "__ladybug_table__", cls.__name__) + + +def _rel_table_name(edge_name: str) -> str: + return edge_name.upper() + + +def _unwrap_optional(annotation): + origin = get_origin(annotation) + args = get_args(annotation) + if origin in UNION_TYPES and type(None) in args: + non_none = [arg for arg in args if arg is not type(None)] + if len(non_none) == 1: + return non_none[0] + return annotation + + +def _ladybug_type(annotation) -> str: + annotation = _unwrap_optional(annotation) + origin = get_origin(annotation) + args = get_args(annotation) + + if origin in (list, List): + if len(args) != 1: + raise TypeError(f"List fields must specify one item type: {annotation!r}") + return f"{_ladybug_type(args[0])}[]" + + if origin in (dict, Dict): + return "MAP" + + try: + return LADYBUG_TYPEMAP[annotation] + except KeyError: + raise TypeError(f"Unsupported Ladybug field type {annotation!r}") from None + + +def _node_fields(cls: Type): + return [model_field for model_field, _ in _node_field_types(cls)] + + +def _node_field_types(cls: Type): + type_hints = get_type_hints(cls) + pairs = [] + for model_field in fields(cls): + annotation = type_hints.get(model_field.name, model_field.type) + if get_origin(annotation) is ClassVar or model_field.name.startswith("_"): + continue + pairs.append((model_field, annotation)) + return pairs + + +def _node_ddl(cls: Type) -> str: + columns = [] + for model_field, annotation in _node_field_types(cls): + col = f"{model_field.name} {_ladybug_type(annotation)}" + if model_field.name == "id": + col += " PRIMARY KEY" + columns.append(col) + if not any(model_field.name == "id" for model_field in _node_fields(cls)): + columns.insert(0, "id INT64 PRIMARY KEY") + return f"CREATE NODE TABLE {_table_name(cls)}({', '.join(columns)})" + + +def _edge_ddl(cls: Type, edge_name: str, dst_table_name: str) -> str: + return ( + f"CREATE REL TABLE {_rel_table_name(edge_name)}" + f"(FROM {_table_name(cls)} TO {dst_table_name})" + ) + + +def _node_params(obj) -> Dict[str, Any]: + return { + model_field.name: getattr(obj, model_field.name) + for model_field in _node_fields(type(obj)) + } + + +def _node_create(cls: Type) -> str: + props = ", ".join( + f"{model_field.name}: ${model_field.name}" for model_field in _node_fields(cls) + ) + return f"CREATE (n:{_table_name(cls)} {{{props}}})" + + +def _edge_create(src_cls: Type, edge_name: str, dst_cls: Type) -> str: + return ( + f"MATCH (src:{_table_name(src_cls)} {{id: $src_id}}), " + f"(dst:{_table_name(dst_cls)} {{id: $dst_id}}) " + f"CREATE (src)-[:{_rel_table_name(edge_name)}]->(dst)" + ) + + +def _cypher_literal(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + if isinstance(value, (list, tuple)): + return "[" + ", ".join(_cypher_literal(item) for item in value) + "]" + raise TypeError(f"Unsupported Ladybug parameter value {value!r}") + + +def _inline_parameters(query: str, parameters: Dict[str, Any]) -> str: + def replace(match): + name = match.group(1) + if name not in parameters: + return match.group(0) + return _cypher_literal(parameters[name]) + + return PARAM_RE.sub(replace, query) + + +def _execute(conn, query: str, parameters: Dict[str, Any] = None): + if parameters is None: + return conn.execute(query) + try: + return conn.execute(query, parameters) + except ModuleNotFoundError as exc: + if exc.name != "numpy": + raise + return conn.execute(_inline_parameters(query, parameters)) + + +def _as_arrow_table(cls: Type, rows_or_table): + if hasattr(rows_or_table, "schema") and hasattr(rows_or_table, "to_pylist"): + return rows_or_table + if hasattr(cls, "to_arrow"): + return cls.to_arrow(rows_or_table) + return arrow_table(cls, rows_or_table) + + +def model(cls: Type) -> Type: + def create_schema(model_cls, conn, *, include_edges: bool = True) -> None: + conn.execute(_node_ddl(model_cls)) + if not include_edges: + return + model_cls.create_edge_schema(conn) + + def create_edge_schema(model_cls, conn) -> None: + for edge_name, edge_func in get_edges(model_cls).items(): + conn.execute( + _edge_ddl(model_cls, edge_name, get_return_type(edge_func._old)) + ) + + def create_arrow_table(model_cls, conn, rows_or_table, table_name: str = None): + table_name = table_name or _table_name(model_cls) + return conn.create_arrow_table( + table_name, _as_arrow_table(model_cls, rows_or_table) + ) + + def create_arrow_rel_table( + model_cls, + conn, + edge_name: str, + rows_or_table, + dst_cls: Type, + *, + layout="FLAT", + indptr_dataframe=None, + ): + return conn.create_arrow_rel_table( + _rel_table_name(edge_name), + rows_or_table, + _table_name(model_cls), + _table_name(dst_cls), + layout, + indptr_dataframe, + ) + + def query_as_arrow(model_cls, conn, query: str, chunk_size: int = 1024): + return conn.query_as_arrow(query, chunk_size) + + def save(self, conn, *, include_edges: bool = True) -> None: + _execute(conn, _node_create(type(self)), _node_params(self)) + if not include_edges: + return + for edge_name in get_edges(type(self)): + if edge_name not in self: + continue + targets = self[edge_name] + if targets is None: + continue + targets = targets if isinstance(targets, list) else [targets] + for target in targets: + _execute( + conn, + _edge_create(type(self), edge_name, type(target)), + {"src_id": self.id, "dst_id": target.id}, + ) + + cls.__ladybug_table__ = _table_name(cls) + cls.__ladybug_node_ddl__ = _node_ddl(cls) + cls.create_schema = classmethod(create_schema) + cls.create_edge_schema = classmethod(create_edge_schema) + cls.create_arrow_table = classmethod(create_arrow_table) + cls.create_arrow_rel_table = classmethod(create_arrow_rel_table) + cls.query_as_arrow = classmethod(query_as_arrow) + cls.save = save + cls.write = save + return cls diff --git a/fquery/query.py b/fquery/query.py index 1b96850..0d87b0b 100644 --- a/fquery/query.py +++ b/fquery/query.py @@ -72,9 +72,12 @@ def __init__( # from a child query. assert (bool(items) ^ bool(ids)) or child self._items = items + self._match_props = ids if isinstance(ids, dict) else {} if self._items: # pyre-fixme[16]: `ViewModel` has no attribute `id`. self.ids = [item.id for item in self._items] + elif isinstance(ids, dict): + self.ids = [] else: self.ids = ids or [] diff --git a/pyproject.toml b/pyproject.toml index dc328e8..7f6c9fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ django = [ ] malloy = [] cypher = [] +ladybug = [ + "ladybug >= 0.17.0; python_version >= '3.10'", +] test = [ "pytest", ] @@ -86,6 +89,9 @@ django = [ ] malloy = [] cypher = [] +ladybug = [ + "ladybug; python_version >= '3.10'", +] test = [ "pytest", ] @@ -98,6 +104,7 @@ dev = [ { include-group = "polars" }, { include-group = "pyarrow" }, { include-group = "django" }, + { include-group = "ladybug" }, ] [tool.isort] diff --git a/setup.py b/setup.py index 665e9c4..dd8b57f 100755 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "django": ["django"], "malloy": [], "cypher": [], + "ladybug": ["ladybug; python_version >= '3.10'"], "df": ["polars >= 0.12.0"], }, ) diff --git a/tests/test_ladybug.py b/tests/test_ladybug.py new file mode 100644 index 0000000..45da95f --- /dev/null +++ b/tests/test_ladybug.py @@ -0,0 +1,234 @@ +from typing import List + +from fquery.ladybug import graph, graph_edge, ladybug +from fquery.view_model import edge, node + + +class FakeConnection: + def __init__(self): + self.executed = [] + self.arrow_tables = [] + self.arrow_rel_tables = [] + self.arrow_queries = [] + + def execute(self, query, parameters=None): + self.executed.append((query, parameters)) + return [] + + def create_arrow_table(self, table_name, dataframe): + self.arrow_tables.append((table_name, dataframe)) + return [] + + def create_arrow_rel_table( + self, + table_name, + dataframe, + src_table_name, + dst_table_name, + layout, + indptr_dataframe, + ): + self.arrow_rel_tables.append( + ( + table_name, + dataframe, + src_table_name, + dst_table_name, + layout, + indptr_dataframe, + ) + ) + return [] + + def query_as_arrow(self, query, chunk_size): + self.arrow_queries.append((query, chunk_size)) + return "arrow-result" + + +class FakeArrowTable: + schema = None + + def to_pylist(self): + return [] + + +@graph +class ReviewGraph: + @ladybug + @node + class User: + name: str + age: int + + @edge + async def follows(self) -> List["User"]: + yield self.follows + + @edge + async def reviews(self) -> List["Review"]: + yield self.reviews + + @ladybug + @node + class Review: + business: str + rating: int + + +User = ReviewGraph.User +Review = ReviewGraph.Review + + +def test_ladybug_schema_and_to_cypher(): + conn = FakeConnection() + + User.create_schema(conn) + + assert conn.executed == [ + ( + "CREATE NODE TABLE User(id INT64 PRIMARY KEY, name STRING, age INT64)", + None, + ), + ("CREATE REL TABLE FOLLOWS(FROM User TO User)", None), + ("CREATE REL TABLE REVIEWS(FROM User TO Review)", None), + ] + Review.query() + assert ( + User.query()({"name": "Ada"}) + .edge("reviews") + .project(["business", "rating"]) + .to_cypher() + == "MATCH (u:User {name: 'Ada'})-[:REVIEWS]->(n1:Review)\n" + "RETURN n1.business, n1.rating" + ) + assert ( + User.query()({"name": "Ada"}) + .edge("follows") + .edge("follows") + .project(["name"]) + .to_cypher() + == "MATCH (a:User {name: 'Ada'})-[e:FOLLOWS*2..2]-(b:User)\n" + "RETURN b.name" + ) + + +def test_graph_schema_registration(): + conn = FakeConnection() + + ReviewGraph.create_schema(conn) + + assert conn.executed == [ + ( + "CREATE NODE TABLE User(id INT64 PRIMARY KEY, name STRING, age INT64)", + None, + ), + ( + "CREATE NODE TABLE Review(id INT64 PRIMARY KEY, business STRING, rating INT64)", + None, + ), + ("CREATE REL TABLE FOLLOWS(FROM User TO User)", None), + ("CREATE REL TABLE REVIEWS(FROM User TO Review)", None), + ] + + +def test_ladybug_save_writes_node_and_edges(): + conn = FakeConnection() + review = Review(id=10, business="Cafe", rating=5) + user = User(id=1, name="Ada", age=37) + followed = User(id=2, name="Grace", age=42) + user.follows = [followed] + user.reviews = [review] + + user.save(conn) + + assert conn.executed == [ + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 1, "name": "Ada", "age": 37}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:User {id: $dst_id}) " + "CREATE (src)-[:FOLLOWS]->(dst)", + {"src_id": 1, "dst_id": 2}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:Review {id: $dst_id}) " + "CREATE (src)-[:REVIEWS]->(dst)", + {"src_id": 1, "dst_id": 10}, + ), + ] + + +def test_graph_save_writes_nodes_then_edges(): + conn = FakeConnection() + u1 = User(id=1, name="Ada", age=37) + u2 = User(id=2, name="Grace", age=42) + u3 = User(id=3, name="Linus", age=55) + r1 = Review(id=10, business="Cafe", rating=5) + graph = ReviewGraph( + nodes=[u1, u2, u3, r1], + edges=[ + graph_edge("follows", u1, u2), + graph_edge("follows", u2, u3), + graph_edge("reviews", u1, r1), + ], + ) + + graph.save(conn) + + assert conn.executed == [ + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 1, "name": "Ada", "age": 37}, + ), + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 2, "name": "Grace", "age": 42}, + ), + ( + "CREATE (n:User {id: $id, name: $name, age: $age})", + {"id": 3, "name": "Linus", "age": 55}, + ), + ( + "CREATE (n:Review {id: $id, business: $business, rating: $rating})", + {"id": 10, "business": "Cafe", "rating": 5}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:User {id: $dst_id}) " + "CREATE (src)-[:FOLLOWS]->(dst)", + {"src_id": 1, "dst_id": 2}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:User {id: $dst_id}) " + "CREATE (src)-[:FOLLOWS]->(dst)", + {"src_id": 2, "dst_id": 3}, + ), + ( + "MATCH (src:User {id: $src_id}), " + "(dst:Review {id: $dst_id}) " + "CREATE (src)-[:REVIEWS]->(dst)", + {"src_id": 1, "dst_id": 10}, + ), + ] + + +def test_ladybug_arrow_helpers(): + conn = FakeConnection() + table = FakeArrowTable() + rel_table = object() + cypher = "MATCH (u:User) RETURN u.name" + + User.create_arrow_table(conn, table) + User.create_arrow_rel_table(conn, "reviews", rel_table, Review) + result = User.query_as_arrow(conn, cypher, chunk_size=64) + + assert conn.arrow_tables == [("User", table)] + assert conn.arrow_rel_tables == [ + ("REVIEWS", rel_table, "User", "Review", "FLAT", None) + ] + assert result == "arrow-result" + assert conn.arrow_queries == [(cypher, 64)] diff --git a/tests/test_ladybug_demo.py b/tests/test_ladybug_demo.py new file mode 100644 index 0000000..c96aace --- /dev/null +++ b/tests/test_ladybug_demo.py @@ -0,0 +1,16 @@ +import pytest + +pytest.importorskip("ladybug") +pytest.importorskip("pyarrow") + + +def test_ladybug_demo_save_path(): + from examples.ladybug_demo import save_demo + + save_demo() + + +def test_ladybug_demo_arrow_memory_path(): + from examples.ladybug_demo import arrow_memory_demo + + arrow_memory_demo()