Skip to content

Commit e087704

Browse files
feat: support ip_type as str (#1029)
1 parent 5b14e33 commit e087704

File tree

7 files changed

+85
-13
lines changed

7 files changed

+85
-13
lines changed

README.md

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ from google.cloud.sql.connector import Connector, IPTypes
188188

189189
# Note: all parameters below are optional
190190
connector = Connector(
191-
ip_type=IPTypes.PUBLIC,
191+
ip_type="public", # can also be "private" or "psc"
192192
enable_iam_auth=False,
193193
timeout=30,
194194
credentials=custom_creds # google.auth.credentials.Credentials
@@ -261,18 +261,16 @@ using both public and private IP addresses, as well as
261261
with, set the `ip_type` keyword argument when initializing a `Connector()` or when
262262
calling `connector.connect()`.
263263

264-
Possible values for `ip_type` are `IPTypes.PUBLIC` (default value),
265-
`IPTypes.PRIVATE`, and `IPTypes.PSC`.
264+
Possible values for `ip_type` are `"public"` (default value),
265+
`"private"`, and `"psc"`.
266266

267267
Example:
268268

269269
```python
270-
from google.cloud.sql.connector import IPTypes
271-
272270
conn = connector.connect(
273271
"project:region:instance",
274272
"pymysql",
275-
ip_type=IPTypes.PRIVATE # use private IP
273+
ip_type="private" # use private IP
276274
... insert other kwargs ...
277275
)
278276
```
@@ -333,7 +331,7 @@ conn = connector.connect(
333331
db="my-db-name",
334332
active_directory_auth=True,
335333
server_name="private.[instance].[location].[project].cloudsql.[domain]",
336-
ip_type=IPTypes.PRIVATE
334+
ip_type="private"
337335
)
338336
```
339337

@@ -358,7 +356,7 @@ your web application through the following:
358356
```python
359357
from flask import Flask
360358
from flask_sqlalchemy import SQLAlchemy
361-
from google.cloud.sql.connector import Connector, IPTypes
359+
from google.cloud.sql.connector import Connector
362360

363361

364362
# initialize Python Connector object
@@ -372,7 +370,7 @@ def getconn():
372370
user="my-user",
373371
password="my-password",
374372
db="my-database",
375-
ip_type= IPTypes.PUBLIC # IPTypes.PRIVATE for private IP
373+
ip_type="public" # "private" for private IP
376374
)
377375
return conn
378376

@@ -407,7 +405,7 @@ from sqlalchemy import create_engine
407405
from sqlalchemy.engine import Engine
408406
from sqlalchemy.ext.declarative import declarative_base
409407
from sqlalchemy.orm import sessionmaker
410-
from google.cloud.sql.connector import Connector, IPTypes
408+
from google.cloud.sql.connector import Connector
411409

412410
# helper function to return SQLAlchemy connection pool
413411
def init_connection_pool(connector: Connector) -> Engine:
@@ -419,7 +417,7 @@ def init_connection_pool(connector: Connector) -> Engine:
419417
user="my-user",
420418
password="my-password",
421419
db="my-database",
422-
ip_type= IPTypes.PUBLIC # IPTypes.PRIVATE for private IP
420+
ip_type="public" # "private" for private IP
423421
)
424422
return conn
425423

google/cloud/sql/connector/connector.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class Connector:
8787

8888
def __init__(
8989
self,
90-
ip_type: IPTypes = IPTypes.PUBLIC,
90+
ip_type: str | IPTypes = IPTypes.PUBLIC,
9191
enable_iam_auth: bool = False,
9292
timeout: int = 30,
9393
credentials: Optional[Credentials] = None,
@@ -130,10 +130,13 @@ def __init__(
130130
# set default params for connections
131131
self._timeout = timeout
132132
self._enable_iam_auth = enable_iam_auth
133-
self._ip_type = ip_type
134133
self._quota_project = quota_project
135134
self._sqladmin_api_endpoint = sqladmin_api_endpoint
136135
self._user_agent = user_agent
136+
# if ip_type is str, convert to IPTypes enum
137+
if isinstance(ip_type, str):
138+
ip_type = IPTypes._get_ip_type_from_str(ip_type)
139+
self._ip_type = ip_type
137140

138141
def connect(
139142
self, instance_connection_string: str, driver: str, **kwargs: Any
@@ -252,6 +255,9 @@ async def connect_async(
252255
raise KeyError(f"Driver '{driver}' is not supported.")
253256

254257
ip_type = kwargs.pop("ip_type", self._ip_type)
258+
# if ip_type is str, convert to IPTypes enum
259+
if isinstance(ip_type, str):
260+
ip_type = IPTypes._get_ip_type_from_str(ip_type)
255261
kwargs["timeout"] = kwargs.get("timeout", self._timeout)
256262

257263
# Host and ssl options come from the certificates and metadata, so we don't

google/cloud/sql/connector/instance.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,21 @@ class IPTypes(Enum):
6464
PRIVATE: str = "PRIVATE"
6565
PSC: str = "PSC"
6666

67+
@staticmethod
68+
def _get_ip_type_from_str(ip_type_str: str) -> IPTypes:
69+
"""Utility method to convert IP type from a str into IPTypes."""
70+
if ip_type_str.lower() == "public":
71+
ip_type = IPTypes.PUBLIC
72+
elif ip_type_str.lower() == "private":
73+
ip_type = IPTypes.PRIVATE
74+
elif ip_type_str.lower() == "psc":
75+
ip_type = IPTypes.PSC
76+
else:
77+
raise ValueError(
78+
f"Incorrect value for ip_type, got '{ip_type_str}'. Want one of: 'public', 'private' or 'psc'."
79+
)
80+
return ip_type
81+
6782

6883
class ConnectionInfo:
6984
ip_addrs: Dict[str, Any]

tests/system/test_pg8000_connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def getconn() -> pg8000.dbapi.Connection:
4242
user=os.environ["POSTGRES_USER"],
4343
password=os.environ["POSTGRES_PASS"],
4444
db=os.environ["POSTGRES_DB"],
45+
ip_type="public", # can also be "private" or "psc"
4546
)
4647
return conn
4748

tests/system/test_pymysql_connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def getconn() -> pymysql.connections.Connection:
4242
user=os.environ["MYSQL_USER"],
4343
password=os.environ["MYSQL_PASS"],
4444
db=os.environ["MYSQL_DB"],
45+
ip_type="public", # can also be "private" or "psc"
4546
)
4647
return conn
4748

tests/system/test_pytds_connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def getconn() -> pytds.Connection:
4343
user=os.environ["SQLSERVER_USER"],
4444
password=os.environ["SQLSERVER_PASS"],
4545
db=os.environ["SQLSERVER_DB"],
46+
ip_type="public", # can also be "private" or "psc"
4647
)
4748
return conn
4849

tests/unit/test_connector.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,56 @@ async def test_Connector_Init_async_context_manager(
130130
assert connector._loop == loop
131131

132132

133+
@pytest.mark.parametrize(
134+
"ip_type",
135+
["public", "private", "psc", "PUBLIC", "PRIVATE", "PSC"],
136+
)
137+
def test_Connector_Init_ip_type_str(
138+
ip_type: str, fake_credentials: Credentials
139+
) -> None:
140+
"""Test that Connector properly sets ip_type when given str."""
141+
with Connector(ip_type=ip_type, credentials=fake_credentials) as connector:
142+
if ip_type.lower() == "public":
143+
assert connector._ip_type == IPTypes.PUBLIC
144+
if ip_type.lower() == "private":
145+
assert connector._ip_type == IPTypes.PRIVATE
146+
if ip_type.lower() == "psc":
147+
assert connector._ip_type == IPTypes.PSC
148+
149+
150+
def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None:
151+
"""Test that Connector errors due to bad ip_type str."""
152+
bad_ip_type = "bad-ip-type"
153+
with pytest.raises(ValueError) as exc_info:
154+
Connector(ip_type=bad_ip_type, credentials=fake_credentials)
155+
assert (
156+
exc_info.value.args[0]
157+
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'public', 'private' or 'psc'."
158+
)
159+
160+
161+
def test_Connector_connect_bad_ip_type(
162+
fake_credentials: Credentials, fake_client: CloudSQLClient
163+
) -> None:
164+
"""Test that Connector.connect errors due to bad ip_type str."""
165+
with Connector(credentials=fake_credentials) as connector:
166+
connector._client = fake_client
167+
bad_ip_type = "bad-ip-type"
168+
with pytest.raises(ValueError) as exc_info:
169+
connector.connect(
170+
"test-project:test-region:test-instance",
171+
"pg8000",
172+
user="my-user",
173+
password="my-pass",
174+
db="my-db",
175+
ip_type=bad_ip_type,
176+
)
177+
assert (
178+
exc_info.value.args[0]
179+
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'public', 'private' or 'psc'."
180+
)
181+
182+
133183
@pytest.mark.asyncio
134184
async def test_Connector_connect_async(
135185
fake_credentials: Credentials, fake_client: CloudSQLClient

0 commit comments

Comments
 (0)