Skip to content

Commit 1d91915

Browse files
committed
Use more generic semantic_version for non gds version
1 parent 91566dd commit 1d91915

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from pyarrow.types import is_dictionary
1919
from tenacity import retry, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
2020

21-
from graphdatascience.server_version.server_version import ServerVersion
22-
21+
from ..semantic_version.semantic_version import SemanticVersion
2322
from ..version import __version__
2423
from .arrow_endpoint_version import ArrowEndpointVersion
2524
from .arrow_info import ArrowInfo
@@ -676,7 +675,7 @@ def _do_get(
676675
message=r"Passing a BlockManager to DataFrame is deprecated",
677676
)
678677

679-
if ServerVersion.from_string(pandas.__version__) >= ServerVersion(2, 0, 0):
678+
if SemanticVersion.from_string(pandas.__version__) >= SemanticVersion(2, 0, 0):
680679
return arrow_table.to_pandas(types_mapper=pandas.ArrowDtype) # type: ignore
681680
else:
682681
arrow_table = self._sanitize_arrow_table(arrow_table)

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..error.endpoint_suggester import generate_suggestive_error_message
1414
from ..error.gds_not_installed import GdsNotFound
1515
from ..error.unable_to_connect import UnableToConnectError
16+
from ..semantic_version.semantic_version import SemanticVersion
1617
from ..server_version.server_version import ServerVersion
1718
from ..version import __version__
1819
from .cypher_graph_constructor import CypherGraphConstructor
@@ -24,7 +25,7 @@
2425
class Neo4jQueryRunner(QueryRunner):
2526
_AURA_DS_PROTOCOL = "neo4j+s"
2627
_LOG_POLLING_INTERVAL = 0.5
27-
_NEO4J_DRIVER_VERSION = ServerVersion.from_string(neo4j.__version__)
28+
_NEO4J_DRIVER_VERSION = SemanticVersion.from_string(neo4j.__version__)
2829

2930
@staticmethod
3031
def create_for_db(
@@ -60,7 +61,7 @@ def create_for_db(
6061
else:
6162
raise ValueError(f"Invalid endpoint type: {type(endpoint)}")
6263

63-
if Neo4jQueryRunner._NEO4J_DRIVER_VERSION >= ServerVersion(5, 21, 0):
64+
if Neo4jQueryRunner._NEO4J_DRIVER_VERSION >= SemanticVersion(5, 21, 0):
6465
notifications_logger = logging.getLogger("neo4j.notifications")
6566
# the client does not expose YIELD fields so we just skip these warnings for now
6667
notifications_logger.addFilter(
@@ -93,7 +94,7 @@ def create_for_session(
9394
instance_description="GDS Session",
9495
)
9596

96-
if Neo4jQueryRunner._NEO4J_DRIVER_VERSION >= ServerVersion(5, 21, 0):
97+
if Neo4jQueryRunner._NEO4J_DRIVER_VERSION >= SemanticVersion(5, 21, 0):
9798
notifications_logger = logging.getLogger("neo4j.notifications")
9899
# the client does not expose YIELD fields so we just skip these warnings for now
99100
notifications_logger.addFilter(
@@ -175,13 +176,13 @@ def run_cypher(
175176

176177
df = result.to_df()
177178

178-
if self._NEO4J_DRIVER_VERSION < ServerVersion(5, 0, 0):
179+
if self._NEO4J_DRIVER_VERSION < SemanticVersion(5, 0, 0):
179180
self._last_bookmarks = [session.last_bookmark()]
180181
else:
181182
self._last_bookmarks = session.last_bookmarks()
182183

183184
if (
184-
Neo4jQueryRunner._NEO4J_DRIVER_VERSION >= ServerVersion(5, 21, 0)
185+
Neo4jQueryRunner._NEO4J_DRIVER_VERSION >= SemanticVersion(5, 21, 0)
185186
and result._warn_notification_severity == "WARNING"
186187
):
187188
# the client does not expose YIELD fields so we just skip these warnings for now
@@ -342,7 +343,7 @@ def _verify_connectivity(
342343
retrys = 0
343344
while retrys < retry_config.max_retries:
344345
try:
345-
if self._NEO4J_DRIVER_VERSION < ServerVersion(5, 0, 0):
346+
if self._NEO4J_DRIVER_VERSION < SemanticVersion(5, 0, 0):
346347
warnings.filterwarnings(
347348
"ignore",
348349
category=neo4j.ExperimentalWarning,
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from __future__ import annotations
2+
3+
import re
4+
5+
6+
class InvalidServerVersionError(Exception):
7+
pass
8+
9+
10+
class SemanticVersion:
11+
"""
12+
A representation of a semantic version, such as for python packages.
13+
"""
14+
15+
def __init__(self, major: int, minor: int, patch: int):
16+
self.major = major
17+
self.minor = minor
18+
self.patch = patch
19+
20+
@classmethod
21+
def from_string(cls, version: str) -> SemanticVersion:
22+
server_version_match = re.search(r"^(\d+)\.(\d+)\.(\d+)", version)
23+
if not server_version_match:
24+
raise InvalidServerVersionError(f"{version} is not a valid semantic version")
25+
26+
return cls(*map(int, server_version_match.groups()))
27+
28+
def __lt__(self, other: SemanticVersion) -> bool:
29+
if self.major != other.major:
30+
return self.major < other.major
31+
32+
if self.minor != other.minor:
33+
return self.minor < other.minor
34+
35+
if self.patch != other.patch:
36+
return self.patch < other.patch
37+
38+
return False
39+
40+
def __ge__(self, other: SemanticVersion) -> bool:
41+
return not self.__lt__(other)
42+
43+
def __str__(self) -> str:
44+
return f"{self.major}.{self.minor}.{self.patch}"

0 commit comments

Comments
 (0)