Skip to content

Commit 0d74f5d

Browse files
refactor: update IPTypes (#1031)
1 parent e087704 commit 0d74f5d

File tree

4 files changed

+72
-34
lines changed

4 files changed

+72
-34
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
2-
ignore = E203, E266, E501, W503, ANN101, ANN401
2+
ignore = E203, E266, E501, W503, ANN101, ANN102, ANN401
33
exclude =
44
# Exclude generated code.
55
**/proto/**

google/cloud/sql/connector/connector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@
4848
class Connector:
4949
"""A class to configure and create connections to Cloud SQL instances.
5050
51-
:type ip_type: IPTypes
51+
:type ip_type: str | IPTypes
5252
:param ip_type
53-
The IP type (public or private) used to connect. IP types
54-
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.
53+
The IP type used to connect. IP types can be either IPTypes.PUBLIC
54+
("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"), or IPTypes.PSC ("PSC").
5555
5656
:type enable_iam_auth: bool
5757
:param enable_iam_auth
@@ -135,7 +135,7 @@ def __init__(
135135
self._user_agent = user_agent
136136
# if ip_type is str, convert to IPTypes enum
137137
if isinstance(ip_type, str):
138-
ip_type = IPTypes._get_ip_type_from_str(ip_type)
138+
ip_type = IPTypes._from_str(ip_type)
139139
self._ip_type = ip_type
140140

141141
def connect(
@@ -257,7 +257,7 @@ async def connect_async(
257257
ip_type = kwargs.pop("ip_type", self._ip_type)
258258
# if ip_type is str, convert to IPTypes enum
259259
if isinstance(ip_type, str):
260-
ip_type = IPTypes._get_ip_type_from_str(ip_type)
260+
ip_type = IPTypes._from_str(ip_type)
261261
kwargs["timeout"] = kwargs.get("timeout", self._timeout)
262262

263263
# Host and ssl options come from the certificates and metadata, so we don't

google/cloud/sql/connector/instance.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,19 @@ class IPTypes(Enum):
6464
PRIVATE: str = "PRIVATE"
6565
PSC: str = "PSC"
6666

67-
@staticmethod
68-
def _get_ip_type_from_str(ip_type_str: str) -> IPTypes:
69-
"""Utility method to convert IP type from a str into IPTypes."""
70-
if ip_type_str.lower() == "public":
71-
ip_type = IPTypes.PUBLIC
72-
elif ip_type_str.lower() == "private":
73-
ip_type = IPTypes.PRIVATE
74-
elif ip_type_str.lower() == "psc":
75-
ip_type = IPTypes.PSC
76-
else:
77-
raise ValueError(
78-
f"Incorrect value for ip_type, got '{ip_type_str}'. Want one of: 'public', 'private' or 'psc'."
79-
)
80-
return ip_type
67+
@classmethod
68+
def _missing_(cls, value: object) -> None:
69+
raise ValueError(
70+
f"Incorrect value for ip_type, got '{value}'. Want one of: "
71+
f"{', '.join([repr(m.value) for m in cls])}, 'PUBLIC'."
72+
)
73+
74+
@classmethod
75+
def _from_str(cls, ip_type_str: str) -> IPTypes:
76+
"""Convert IP type from a str into IPTypes."""
77+
if ip_type_str.upper() == "PUBLIC":
78+
ip_type_str = "PRIMARY"
79+
return cls(ip_type_str.upper())
8180

8281

8382
class ConnectionInfo:

tests/unit/test_connector.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
limitations under the License.
1515
"""
1616
import asyncio
17+
from typing import Union
1718

1819
from google.auth.credentials import Credentials
1920
from mock import patch
@@ -131,20 +132,56 @@ async def test_Connector_Init_async_context_manager(
131132

132133

133134
@pytest.mark.parametrize(
134-
"ip_type",
135-
["public", "private", "psc", "PUBLIC", "PRIVATE", "PSC"],
135+
"ip_type, expected",
136+
[
137+
(
138+
"private",
139+
IPTypes.PRIVATE,
140+
),
141+
(
142+
"PRIVATE",
143+
IPTypes.PRIVATE,
144+
),
145+
(
146+
IPTypes.PRIVATE,
147+
IPTypes.PRIVATE,
148+
),
149+
(
150+
"public",
151+
IPTypes.PUBLIC,
152+
),
153+
(
154+
"PUBLIC",
155+
IPTypes.PUBLIC,
156+
),
157+
(
158+
IPTypes.PUBLIC,
159+
IPTypes.PUBLIC,
160+
),
161+
(
162+
"psc",
163+
IPTypes.PSC,
164+
),
165+
(
166+
"PSC",
167+
IPTypes.PSC,
168+
),
169+
(
170+
IPTypes.PSC,
171+
IPTypes.PSC,
172+
),
173+
],
136174
)
137-
def test_Connector_Init_ip_type_str(
138-
ip_type: str, fake_credentials: Credentials
175+
def test_Connector_init_ip_type(
176+
ip_type: Union[str, IPTypes], expected: IPTypes, fake_credentials: Credentials
139177
) -> None:
140-
"""Test that Connector properly sets ip_type when given str."""
141-
with Connector(ip_type=ip_type, credentials=fake_credentials) as connector:
142-
if ip_type.lower() == "public":
143-
assert connector._ip_type == IPTypes.PUBLIC
144-
if ip_type.lower() == "private":
145-
assert connector._ip_type == IPTypes.PRIVATE
146-
if ip_type.lower() == "psc":
147-
assert connector._ip_type == IPTypes.PSC
178+
"""
179+
Test to check whether the __init__ method of Connector
180+
properly sets ip_type.
181+
"""
182+
connector = Connector(credentials=fake_credentials, ip_type=ip_type)
183+
assert connector._ip_type == expected
184+
connector.close()
148185

149186

150187
def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None:
@@ -154,7 +191,8 @@ def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None:
154191
Connector(ip_type=bad_ip_type, credentials=fake_credentials)
155192
assert (
156193
exc_info.value.args[0]
157-
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'public', 'private' or 'psc'."
194+
== f"Incorrect value for ip_type, got '{bad_ip_type.upper()}'. "
195+
"Want one of: 'PRIMARY', 'PRIVATE', 'PSC', 'PUBLIC'."
158196
)
159197

160198

@@ -176,7 +214,8 @@ def test_Connector_connect_bad_ip_type(
176214
)
177215
assert (
178216
exc_info.value.args[0]
179-
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'public', 'private' or 'psc'."
217+
== f"Incorrect value for ip_type, got '{bad_ip_type.upper()}'. "
218+
"Want one of: 'PRIMARY', 'PRIVATE', 'PSC', 'PUBLIC'."
180219
)
181220

182221

0 commit comments

Comments
 (0)