Skip to content

Commit 657e76c

Browse files
committed
Implement graph construct for cypher endpoints
1 parent e96d0f3 commit 657e76c

File tree

11 files changed

+227
-49
lines changed

11 files changed

+227
-49
lines changed

graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from types import TracebackType
55
from typing import NamedTuple, Type
66

7+
from pandas import DataFrame
8+
79
from graphdatascience.procedure_surface.api.base_result import BaseResult
810
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
911
from graphdatascience.procedure_surface.api.catalog.graph_info import GraphInfo, GraphInfoWithDegrees
@@ -14,31 +16,76 @@
1416

1517

1618
class CatalogEndpoints(ABC):
19+
@abstractmethod
20+
def construct(
21+
self,
22+
graph_name: str,
23+
nodes: DataFrame | list[DataFrame],
24+
relationships: DataFrame | list[DataFrame] | None = None,
25+
concurrency: int = 4,
26+
undirected_relationship_types: list[str] | None = None,
27+
) -> GraphV2:
28+
"""Construct a graph from a list of node and relationship dataframes.
29+
30+
Parameters
31+
----------
32+
graph_name
33+
Name of the graph to construct
34+
nodes
35+
Node dataframes. A dataframe should follow the schema:
36+
37+
- `nodeId` to uniquely identify the node across all dataframes
38+
- `labels` to specify the labels of the node as a list of strings (optional)
39+
- other columns are treated as node properties
40+
relationships
41+
Relationship dataframes. A dataframe should follow the schema:
42+
43+
- `sourceNodeId` to identify the start node of the relationship
44+
- `targetNodeId` to identify the end node of the relationship
45+
- `relationshipType` to specify the type of the relationship (optional)
46+
- other columns are treated as relationship properties
47+
concurrency
48+
Number of concurrent threads to use.
49+
undirected_relationship_types
50+
List of relationship types to treat as undirected.
51+
52+
Returns
53+
-------
54+
GraphV2
55+
Constructed graph object.
56+
"""
57+
1758
@abstractmethod
1859
def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
1960
"""List graphs in the graph catalog.
2061
21-
Args:
22-
G (GraphV2 | str | None, optional): GraphV2 object or name to filter results.
23-
If None, list all graphs. Defaults to None.
62+
Parameters
63+
----------
64+
G
65+
GraphV2 object or name to filter results. If None, list all graphs.
2466
25-
Returns:
26-
list[GraphListResult]: List of graph metadata objects containing information like
27-
graph name, node count, relationship count, etc.
67+
Returns
68+
-------
69+
list[GraphInfoWithDegrees]
70+
List of graph metadata objects containing information like node count.
2871
"""
2972
pass
3073

3174
@abstractmethod
3275
def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
3376
"""Drop a graph from the graph catalog.
3477
35-
Args:
36-
G (GraphV2 | str): GraphV2 object or name to drop.
37-
fail_if_missing (bool): Whether to fail if the graph is missing. Defaults to True.
78+
Parameters
79+
----------
80+
G
81+
Graph to drop, given by name or as a graph object.
82+
fail_if_missing
83+
Whether to fail if the graph is missing
3884
39-
Returns:
40-
GraphListResult: GraphV2 metadata object containing information like
41-
graph name, node count, relationship count, etc.
85+
Returns
86+
-------
87+
GraphInfo | None
88+
Metadata object for the dropped graph, containing information like node count.
4289
"""
4390

4491
@abstractmethod
@@ -68,9 +115,10 @@ def filter(
68115
job_id
69116
Identifier for the computation.
70117
71-
Returns:
72-
GraphWithFilterResult: tuple of the filtered graph object and the information like
73-
graph name, node count, relationship count, etc.
118+
Returns
119+
-------
120+
GraphWithFilterResult
121+
Tuple of the filtered graph object and information such as graph name, node count, relationship count, etc.
74122
"""
75123
pass
76124

graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Any, NamedTuple, Type
66
from uuid import uuid4
77

8+
from pandas import DataFrame
9+
810
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
911
from graphdatascience.arrow_client.v2.job_client import JobClient
1012
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
@@ -52,15 +54,6 @@ def __init__(
5254
protocol_version = ProtocolVersionResolver(query_runner).resolve()
5355
self._project_protocol = ProjectProtocol.select(protocol_version)
5456

55-
def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
56-
graph_name: str | None = None
57-
if isinstance(G, GraphV2):
58-
graph_name = G.name()
59-
elif isinstance(G, str):
60-
graph_name = G
61-
62-
return self._graph_backend.list(graph_name)
63-
6457
def project(
6558
self,
6659
graph_name: str,
@@ -137,6 +130,16 @@ def project(
137130

138131
return GraphWithProjectResult(get_graph(graph_name, self._arrow_client), job_result)
139132

133+
def construct(
    self,
    graph_name: str,
    nodes: DataFrame | list[DataFrame],
    relationships: DataFrame | list[DataFrame] | None = None,
    concurrency: int = 4,
    undirected_relationship_types: list[str] | None = None,
) -> GraphV2:
    """Reject graph construction, which this endpoint does not implement.

    All arguments are accepted for interface compatibility and are not used.

    Raises
    ------
    NotImplementedError
        Always.
    """
    unsupported_message = "Graph construction is not yet supported via V2 endpoints."
    raise NotImplementedError(unsupported_message)
142+
140143
def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
141144
graph_name = G.name() if isinstance(G, GraphV2) else G
142145

@@ -212,6 +215,15 @@ def generate(
212215
GraphGenerationStats(**JobClient.get_summary(self._arrow_client, job_id)),
213216
)
214217

218+
def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
    """List graphs via the graph backend, optionally restricted to one graph.

    Parameters
    ----------
    G
        GraphV2 object or graph name to filter by. Any other value (including
        None) results in `None` being passed to the backend as the name.

    Returns
    -------
    list[GraphInfoWithDegrees]
        Whatever the backend's `list` call returns for the resolved name.
    """
    # A graph object is identified by its name.
    if isinstance(G, GraphV2):
        return self._graph_backend.list(G.name())

    # Only a string is forwarded as a name; everything else becomes None.
    return self._graph_backend.list(G if isinstance(G, str) else None)
226+
215227
@property
216228
def sample(self) -> GraphSamplingEndpoints:
217229
return GraphSamplingArrowEndpoints(self._arrow_client, show_progress=self._show_progress)

graphdatascience/procedure_surface/cypher/catalog/node_properties_cypher_endpoints.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
NodePropertySpec,
1111
)
1212
from graphdatascience.procedure_surface.api.default_values import ALL_LABELS
13+
from graphdatascience.procedure_surface.cypher.catalog.utils import require_database
1314
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1415
from graphdatascience.procedure_surface.utils.result_utils import join_db_node_properties, transpose_property_columns
1516
from graphdatascience.query_runner.query_runner import QueryRunner
@@ -35,9 +36,7 @@ def stream(
3536
db_node_properties: list[str] | None = None,
3637
) -> DataFrame:
3738
if self._gds_arrow_client is not None:
38-
database = self._query_runner.database()
39-
if database is None:
40-
raise ValueError("The database is not set")
39+
database = require_database(self._query_runner)
4140

4241
result = self._gds_arrow_client.get_node_properties(
4342
G.name(), database, node_properties, node_labels, list_node_labels or False, concurrency

graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
RelationshipsWriteResult,
1515
)
1616
from graphdatascience.procedure_surface.api.default_values import ALL_TYPES
17+
from graphdatascience.procedure_surface.cypher.catalog.utils import require_database
1718
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1819

1920

@@ -36,9 +37,7 @@ def stream(
3637
effective_rel_types = relationship_types if relationship_types is not None else ["*"]
3738

3839
if self._gds_arrow_client is not None:
39-
database = self._query_runner.database()
40-
if database is None:
41-
raise ValueError("The database is not set")
40+
database = require_database(self._query_runner)
4241

4342
if relationship_properties:
4443
return self._gds_arrow_client.get_relationship_properties(
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from graphdatascience.query_runner.query_runner import QueryRunner
2+
3+
4+
def require_database(query_runner: QueryRunner) -> str:
    """Return the database reported by the query runner, or fail if there is none.

    Parameters
    ----------
    query_runner
        Query runner whose `database()` value is looked up.

    Returns
    -------
    str
        The value returned by `query_runner.database()`.

    Raises
    ------
    ValueError
        If `query_runner.database()` returns None.
    """
    if (target_db := query_runner.database()) is not None:
        return target_db

    raise ValueError(
        "For this call you must have explicitly specified a valid Neo4j database to target, "
        "using `gds.set_database`."
    )

graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from types import TracebackType
55
from typing import Any, NamedTuple, Type
66

7+
from pandas import DataFrame
8+
9+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
710
from graphdatascience.procedure_surface.api.catalog.catalog_endpoints import (
811
CatalogEndpoints,
912
GraphFilterResult,
@@ -15,7 +18,11 @@
1518
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
1619
from graphdatascience.procedure_surface.api.catalog.graph_info import GraphInfo, GraphInfoWithDegrees
1720
from graphdatascience.procedure_surface.api.catalog.graph_sampling_endpoints import GraphSamplingEndpoints
18-
from graphdatascience.procedure_surface.cypher.catalog.graph_backend_cypher import get_graph
21+
from graphdatascience.procedure_surface.cypher.catalog.graph_backend_cypher import CypherGraphBackend, get_graph
22+
from graphdatascience.procedure_surface.cypher.catalog.utils import require_database
23+
from graphdatascience.query_runner.arrow_graph_constructor import ArrowGraphConstructor
24+
from graphdatascience.query_runner.cypher_graph_constructor import CypherGraphConstructor
25+
from graphdatascience.query_runner.graph_constructor import GraphConstructor
1926

2027
from ...call_parameters import CallParameters
2128
from ...query_runner.query_runner import QueryRunner
@@ -28,8 +35,46 @@
2835

2936

3037
class CatalogCypherEndpoints(CatalogEndpoints):
31-
def __init__(self, query_runner: QueryRunner):
38+
def __init__(self, query_runner: QueryRunner, arrow_client: GdsArrowClient | None = None):
3239
self._query_runner = query_runner
40+
self._arrow_client = arrow_client
41+
42+
def construct(
    self,
    graph_name: str,
    nodes: DataFrame | list[DataFrame],
    relationships: DataFrame | list[DataFrame] | None = None,
    concurrency: int | None = None,
    undirected_relationship_types: list[str] | None = None,
) -> GraphV2:
    """Construct a graph from node and relationship dataframes.

    Parameters
    ----------
    graph_name
        Name of the graph to construct.
    nodes
        One node dataframe or a list of them; a single dataframe is wrapped in a list.
    relationships
        One relationship dataframe, a list of them, or None (treated as an empty list).
    concurrency
        Forwarded unchanged to the selected graph constructor.
    undirected_relationship_types
        Forwarded unchanged to the selected graph constructor.

    Returns
    -------
    GraphV2
        Graph object for `graph_name`, backed by a `CypherGraphBackend`.

    Raises
    ------
    ValueError
        Raised by `require_database` when an Arrow client is configured but the
        query runner has no database set.
    """
    # Normalise both inputs to lists of dataframes.
    node_dfs = [nodes] if isinstance(nodes, DataFrame) else nodes

    relationship_dfs: list[DataFrame]
    if relationships is None:
        relationship_dfs = []
    elif isinstance(relationships, DataFrame):
        relationship_dfs = [relationships]
    else:
        relationship_dfs = relationships

    # Use the Arrow constructor when an Arrow client was supplied, else plain Cypher.
    graph_constructor: GraphConstructor
    if self._arrow_client is None:
        graph_constructor = CypherGraphConstructor(
            query_runner=self._query_runner,
            graph_name=graph_name,
            concurrency=concurrency,
            undirected_relationship_types=undirected_relationship_types,
        )
    else:
        graph_constructor = ArrowGraphConstructor(
            database=require_database(self._query_runner),
            graph_name=graph_name,
            flight_client=self._arrow_client,
            concurrency=concurrency,
            undirected_relationship_types=undirected_relationship_types,
        )

    graph_constructor.run(node_dfs=node_dfs, relationship_dfs=relationship_dfs)

    backend = CypherGraphBackend(graph_name, self._query_runner)
    return GraphV2(name=graph_name, backend=backend)
3378

3479
def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
3580
graph_name = G if isinstance(G, str) else G.name() if G is not None else None

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def __init__(
2222
database: str,
2323
graph_name: str,
2424
flight_client: GdsArrowClient,
25-
concurrency: int,
26-
undirected_relationship_types: list[str] | None,
25+
concurrency: int | None = None,
26+
undirected_relationship_types: list[str] | None = None,
2727
chunk_size: int = 10_000,
2828
):
2929
self._database = database

graphdatascience/query_runner/cypher_graph_constructor.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,13 @@ def __init__(
5858
self,
5959
query_runner: QueryRunner,
6060
graph_name: str,
61-
concurrency: int,
62-
undirected_relationship_types: list[str] | None,
63-
server_version: ServerVersion,
61+
concurrency: int | None = None,
62+
undirected_relationship_types: list[str] | None = None,
6463
):
6564
self._query_runner = query_runner
6665
self._concurrency = concurrency
6766
self._graph_name = graph_name
68-
self._server_version = server_version
67+
self._server_version = query_runner.server_version()
6968
self._undirected_relationship_types = undirected_relationship_types
7069

7170
def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> None:
@@ -81,9 +80,9 @@ def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> N
8180
self.CypherProjectionRunner(
8281
self._query_runner,
8382
self._graph_name,
83+
self._server_version,
8484
self._concurrency,
8585
self._undirected_relationship_types,
86-
self._server_version,
8786
).run(node_dfs, relationship_dfs)
8887
else:
8988
assert not self._undirected_relationship_types, "This should have been raised earlier."
@@ -130,9 +129,9 @@ def __init__(
130129
self,
131130
query_runner: QueryRunner,
132131
graph_name: str,
133-
concurrency: int,
134-
undirected_relationship_types: list[str] | None,
135132
server_version: ServerVersion,
133+
concurrency: int | None = None,
134+
undirected_relationship_types: list[str] | None = None,
136135
):
137136
self._query_runner = query_runner
138137
self._concurrency = concurrency
@@ -359,9 +358,9 @@ def rels_config_part(self, rel_cols: list[EntityColumnSchema], rel_properties_ke
359358
return rels_config_fields
360359

361360
class LegacyCypherProjectionRunner:
362-
def __init__(self, query_runner: QueryRunner, graph_name: str, concurrency: int):
361+
def __init__(self, query_runner: QueryRunner, graph_name: str, concurrency: int | None = None):
363362
self._query_runner = query_runner
364-
self._concurrency = concurrency
363+
self._concurrency = concurrency if concurrency is not None else 4
365364
self._graph_name = graph_name
366365

367366
def run(self, node_df: DataFrame, relationship_df: DataFrame) -> None:

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,7 @@ def __del__(self) -> None:
375375
def create_graph_constructor(
376376
self, graph_name: str, concurrency: int, undirected_relationship_types: list[str] | None
377377
) -> GraphConstructor:
378-
return CypherGraphConstructor(
379-
self, graph_name, concurrency, undirected_relationship_types, self.server_version()
380-
)
378+
return CypherGraphConstructor(self, graph_name, concurrency, undirected_relationship_types)
381379

382380
def set_show_progress(self, show_progress: bool) -> None:
383381
self._show_progress = show_progress

0 commit comments

Comments
 (0)