Skip to content

Commit f9576f3

Browse files
fix: remove enable_iam_auth from downstream kwargs and catch error (#273)
1 parent cdfcc72 commit f9576f3

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,18 @@ def connect(
108108
# Use the InstanceConnectionManager to establish an SSL Connection.
109109
#
110110
# Return a DBAPI connection
111-
111+
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
112112
if instance_connection_string in self._instances:
113113
icm = self._instances[instance_connection_string]
114+
if enable_iam_auth != icm._enable_iam_auth:
115+
raise ValueError(
116+
"connect() called with `enable_iam_auth={}`, but previously used "
117+
"enable_iam_auth={}`. If you require both for your use case, "
118+
"please use a new connector.Connector object.".format(
119+
enable_iam_auth, icm._enable_iam_auth
120+
)
121+
)
114122
else:
115-
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
116123
icm = InstanceConnectionManager(
117124
instance_connection_string,
118125
driver,

tests/unit/test_connector.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,25 @@
2727
from typing import Any
2828

2929

30+
class MockInstanceConnectionManager:
31+
_enable_iam_auth: bool
32+
33+
def __init__(
34+
self,
35+
enable_iam_auth: bool = False,
36+
) -> None:
37+
self._enable_iam_auth = enable_iam_auth
38+
39+
def connect(
40+
self,
41+
driver: str,
42+
ip_type: IPTypes,
43+
timeout: int,
44+
**kwargs: Any,
45+
) -> Any:
46+
return True
47+
48+
3049
def test_connect_timeout(
3150
fake_credentials: Credentials, async_loop: asyncio.AbstractEventLoop
3251
) -> None:
@@ -59,6 +78,31 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None:
5978
)
6079

6180

81+
def test_connect_enable_iam_auth_error() -> None:
82+
"""Test that calling connect() with different enable_iam_auth
83+
argument values throws error."""
84+
connect_string = "my-project:my-region:my-instance"
85+
default_connector = connector.Connector()
86+
with patch(
87+
"google.cloud.sql.connector.connector.InstanceConnectionManager"
88+
) as mock_icm:
89+
mock_icm.return_value = MockInstanceConnectionManager(enable_iam_auth=False)
90+
conn = default_connector.connect(
91+
connect_string,
92+
"pg8000",
93+
enable_iam_auth=False,
94+
)
95+
assert conn is True
96+
# try to connect using enable_iam_auth=True, should raise error
97+
pytest.raises(
98+
ValueError,
99+
default_connector.connect,
100+
connect_string,
101+
"pg8000",
102+
enable_iam_auth=True,
103+
)
104+
105+
62106
def test_default_Connector_Init() -> None:
63107
"""Test that default Connector __init__ sets properties properly."""
64108
default_connector = connector.Connector()

0 commit comments

Comments
 (0)