Skip to content

Commit d4d9b15

Browse files
chore: refactor credentials to be initialized in Connector (#995)
1 parent 811f661 commit d4d9b15

File tree

8 files changed

+146
-195
lines changed

8 files changed

+146
-195
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
import socket
2222
from threading import Thread
2323
from types import TracebackType
24-
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
24+
from typing import Any, Dict, Optional, Type
25+
26+
import google.auth
27+
from google.auth.credentials import Credentials
28+
from google.auth.credentials import with_scopes_if_required
2529

2630
import google.cloud.sql.connector.asyncpg as asyncpg
2731
from google.cloud.sql.connector.exceptions import ConnectorLoopError
@@ -34,9 +38,6 @@
3438
from google.cloud.sql.connector.utils import format_database_user
3539
from google.cloud.sql.connector.utils import generate_keys
3640

37-
if TYPE_CHECKING:
38-
from google.auth.credentials import Credentials
39-
4041
logger = logging.getLogger(name=__name__)
4142

4243
ASYNC_DRIVERS = ["asyncpg"]
@@ -109,13 +110,26 @@ def __init__(
109110
)
110111
self._instances: Dict[str, Instance] = {}
111112

113+
# initialize credentials
114+
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
115+
if credentials:
116+
# verfiy custom credentials are proper type
117+
# and atleast base class of google.auth.credentials
118+
if not isinstance(credentials, Credentials):
119+
raise TypeError(
120+
"credentials must be of type google.auth.credentials.Credentials,"
121+
f" got {type(credentials)}"
122+
)
123+
self._credentials = with_scopes_if_required(credentials, scopes=scopes)
124+
# otherwise use application default credentials
125+
else:
126+
self._credentials, _ = google.auth.default(scopes=scopes)
112127
# set default params for connections
113128
self._timeout = timeout
114129
self._enable_iam_auth = enable_iam_auth
115130
self._ip_type = ip_type
116131
self._quota_project = quota_project
117132
self._sqladmin_api_endpoint = sqladmin_api_endpoint
118-
self._credentials = credentials
119133
self._user_agent = user_agent
120134

121135
def connect(

google/cloud/sql/connector/exceptions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ class PlatformNotSupportedError(Exception):
4848
pass
4949

5050

51-
class CredentialsTypeError(Exception):
52-
"""
53-
Raised when credentials parameter is not proper type.
54-
"""
55-
56-
pass
57-
58-
5951
class AutoIAMAuthNotSupported(Exception):
6052
"""
6153
Exception to be raised when Automatic IAM Authentication is not

google/cloud/sql/connector/instance.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,12 @@
2828

2929
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
3030
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
31-
from google.cloud.sql.connector.exceptions import CredentialsTypeError
3231
from google.cloud.sql.connector.exceptions import TLSVersionError
3332
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
3433
from google.cloud.sql.connector.refresh_utils import _get_ephemeral
3534
from google.cloud.sql.connector.refresh_utils import _get_metadata
3635
from google.cloud.sql.connector.refresh_utils import _is_valid
3736
from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh
38-
from google.cloud.sql.connector.utils import _auth_init
3937
from google.cloud.sql.connector.utils import write_to_file
4038
from google.cloud.sql.connector.version import __version__ as version
4139

@@ -157,7 +155,6 @@ class Instance:
157155
:type credentials: google.auth.credentials.Credentials
158156
:param credentials
159157
Credentials object used to authenticate connections to Cloud SQL server.
160-
If not specified, Application Default Credentials are used.
161158
162159
:param enable_iam_auth
163160
Enables automatic IAM database authentication for Postgres or MySQL
@@ -206,7 +203,7 @@ def _client_session(self) -> aiohttp.ClientSession:
206203
self.__client_session = aiohttp.ClientSession(headers=headers)
207204
return self.__client_session
208205

209-
_credentials: Optional[Credentials] = None
206+
_credentials: Credentials
210207
_keys: asyncio.Future
211208

212209
_instance_connection_string: str
@@ -227,7 +224,7 @@ def __init__(
227224
driver_name: str,
228225
keys: asyncio.Future,
229226
loop: asyncio.AbstractEventLoop,
230-
credentials: Optional[Credentials] = None,
227+
credentials: Credentials,
231228
enable_iam_auth: bool = False,
232229
quota_project: Optional[str] = None,
233230
sqladmin_api_endpoint: str = "https://sqladmin.googleapis.com",
@@ -250,13 +247,7 @@ def __init__(
250247
self._sqladmin_api_endpoint = sqladmin_api_endpoint
251248
self._loop = loop
252249
self._keys = keys
253-
# validate credentials type
254-
if not isinstance(credentials, Credentials) and credentials is not None:
255-
raise CredentialsTypeError(
256-
"Arg credentials must be type 'google.auth.credentials.Credentials' "
257-
"or None (to use Application Default Credentials)"
258-
)
259-
self._credentials = _auth_init(credentials)
250+
self._credentials = credentials
260251
self._refresh_rate_limiter = AsyncRateLimiter(
261252
max_capacity=2, rate=1 / 30, loop=self._loop
262253
)

google/cloud/sql/connector/utils.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16-
from typing import Optional, Tuple
16+
from typing import Tuple
1717

1818
from cryptography.hazmat.backends import default_backend
1919
from cryptography.hazmat.primitives import serialization
2020
from cryptography.hazmat.primitives.asymmetric import rsa
21-
from google.auth import default
22-
from google.auth.credentials import Credentials
23-
from google.auth.credentials import with_scopes_if_required
2421

2522

2623
async def generate_keys() -> Tuple[bytes, str]:
@@ -104,23 +101,3 @@ def format_database_user(database_version: str, user: str) -> str:
104101
return user.split("@")[0]
105102

106103
return user
107-
108-
109-
def _auth_init(credentials: Optional[Credentials]) -> Credentials:
110-
"""Creates google.auth credentials object with scopes required to make
111-
calls to the Cloud SQL Admin APIs.
112-
113-
:type credentials: google.auth.credentials.Credentials
114-
:param credentials
115-
Credentials object used to authenticate connections to Cloud SQL server.
116-
If not specified, Application Default Credentials are used.
117-
"""
118-
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
119-
# if Credentials object is passed in, use for authentication
120-
if isinstance(credentials, Credentials):
121-
credentials = with_scopes_if_required(credentials, scopes=scopes)
122-
# otherwise use application default credentials
123-
else:
124-
credentials, _ = default(scopes=scopes)
125-
126-
return credentials

tests/conftest.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from google.auth.credentials import Credentials
2424
from google.auth.credentials import with_scopes_if_required
2525
from google.oauth2 import service_account
26-
from mock import patch
2726
import pytest # noqa F401 Needed to run the tests
2827
from unit.mocks import FakeCSQLInstance # type: ignore
2928

@@ -146,78 +145,79 @@ async def instance(
146145
keys = asyncio.create_task(generate_keys())
147146
_, client_key = await keys
148147

149-
with patch("google.cloud.sql.connector.utils.default") as mock_auth:
150-
mock_auth.return_value = fake_credentials, None
151-
# mock Cloud SQL Admin API calls
152-
with aioresponses() as mocked:
153-
mocked.get(
154-
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}/connectSettings",
155-
status=200,
156-
body=mock_instance.connect_settings(),
157-
repeat=True,
158-
)
159-
mocked.post(
160-
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}:generateEphemeralCert",
161-
status=200,
162-
body=mock_instance.generate_ephemeral(client_key),
163-
repeat=True,
164-
)
165-
166-
instance = Instance(
167-
f"{mock_instance.project}:{mock_instance.region}:{mock_instance.name}",
168-
"pg8000",
169-
keys,
170-
loop,
171-
)
172-
173-
yield instance
174-
await instance.close()
148+
# mock Cloud SQL Admin API calls
149+
with aioresponses() as mocked:
150+
mocked.get(
151+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}/connectSettings",
152+
status=200,
153+
body=mock_instance.connect_settings(),
154+
repeat=True,
155+
)
156+
mocked.post(
157+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{mock_instance.project}/instances/{mock_instance.name}:generateEphemeralCert",
158+
status=200,
159+
body=mock_instance.generate_ephemeral(client_key),
160+
repeat=True,
161+
)
162+
163+
instance = Instance(
164+
f"{mock_instance.project}:{mock_instance.region}:{mock_instance.name}",
165+
"pg8000",
166+
keys,
167+
loop,
168+
fake_credentials,
169+
)
170+
171+
yield instance
172+
await instance.close()
175173

176174

177175
@pytest.fixture
178176
async def connector(fake_credentials: Credentials) -> AsyncGenerator[Connector, None]:
179177
instance_connection_name = "my-project:my-region:my-instance"
180178
project, region, instance_name = instance_connection_name.split(":")
181179
# initialize connector
182-
connector = Connector()
183-
with patch("google.cloud.sql.connector.utils.default") as mock_auth:
184-
mock_auth.return_value = fake_credentials, None
185-
# mock Cloud SQL Admin API calls
186-
mock_instance = FakeCSQLInstance(project, region, instance_name)
187-
188-
async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]:
189-
"""
190-
Helper method to await keys of Connector in tests prior to
191-
initializing an Instance object.
192-
"""
193-
return await future
194-
195-
# converting asyncio.Future into concurrent.Future
196-
# await keys in background thread so that .result() is set
197-
# required because keys are needed for mocks, but are not awaited
198-
# in the code until Instance() is initialized
199-
_, client_key = asyncio.run_coroutine_threadsafe(
200-
wait_for_keys(connector._keys), connector._loop
201-
).result()
202-
with aioresponses() as mocked:
203-
mocked.get(
204-
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings",
205-
status=200,
206-
body=mock_instance.connect_settings(),
207-
repeat=True,
208-
)
209-
mocked.post(
210-
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}:generateEphemeralCert",
211-
status=200,
212-
body=mock_instance.generate_ephemeral(client_key),
213-
repeat=True,
214-
)
215-
# initialize Instance using mocked API calls
216-
instance = Instance(
217-
instance_connection_name, "pg8000", connector._keys, connector._loop
218-
)
219-
220-
connector._instances[instance_connection_name] = instance
221-
222-
yield connector
223-
connector.close()
180+
connector = Connector(credentials=fake_credentials)
181+
# mock Cloud SQL Admin API calls
182+
mock_instance = FakeCSQLInstance(project, region, instance_name)
183+
184+
async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]:
185+
"""
186+
Helper method to await keys of Connector in tests prior to
187+
initializing an Instance object.
188+
"""
189+
return await future
190+
191+
# converting asyncio.Future into concurrent.Future
192+
# await keys in background thread so that .result() is set
193+
# required because keys are needed for mocks, but are not awaited
194+
# in the code until Instance() is initialized
195+
_, client_key = asyncio.run_coroutine_threadsafe(
196+
wait_for_keys(connector._keys), connector._loop
197+
).result()
198+
with aioresponses() as mocked:
199+
mocked.get(
200+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings",
201+
status=200,
202+
body=mock_instance.connect_settings(),
203+
repeat=True,
204+
)
205+
mocked.post(
206+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}:generateEphemeralCert",
207+
status=200,
208+
body=mock_instance.generate_ephemeral(client_key),
209+
repeat=True,
210+
)
211+
# initialize Instance using mocked API calls
212+
instance = Instance(
213+
instance_connection_name,
214+
"pg8000",
215+
connector._keys,
216+
connector._loop,
217+
fake_credentials,
218+
)
219+
220+
connector._instances[instance_connection_name] = instance
221+
222+
yield connector
223+
connector.close()

0 commit comments

Comments
 (0)