Skip to content

Commit 82da410

Browse files
authored
feat: add support for a custom user agent (#986)
1 parent 48a9ccb commit 82da410

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(
9191
loop: Optional[asyncio.AbstractEventLoop] = None,
9292
quota_project: Optional[str] = None,
9393
sqladmin_api_endpoint: str = "https://sqladmin.googleapis.com",
94+
user_agent: Optional[str] = None,
9495
) -> None:
9596
# if event loop is given, use for background tasks
9697
if loop:
@@ -115,6 +116,7 @@ def __init__(
115116
self._quota_project = quota_project
116117
self._sqladmin_api_endpoint = sqladmin_api_endpoint
117118
self._credentials = credentials
119+
self._user_agent = user_agent
118120

119121
def connect(
120122
self, instance_connection_string: str, driver: str, **kwargs: Any
@@ -211,6 +213,7 @@ async def connect_async(
211213
enable_iam_auth,
212214
self._quota_project,
213215
self._sqladmin_api_endpoint,
216+
user_agent=self._user_agent,
214217
)
215218
self._instances[instance_connection_string] = instance
216219

google/cloud/sql/connector/instance.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str:
134134
)
135135

136136

137+
def _format_user_agent(version: str, driver: str, custom: Optional[str]) -> str:
138+
agent = f"{APPLICATION_NAME}/{version}+{driver}"
139+
if custom:
140+
agent = f"{agent} {custom}"
141+
return agent
142+
143+
137144
class Instance:
138145
"""A class to manage the details of the connection to a Cloud SQL
139146
instance, including refreshing the credentials.
@@ -224,6 +231,7 @@ def __init__(
224231
enable_iam_auth: bool = False,
225232
quota_project: Optional[str] = None,
226233
sqladmin_api_endpoint: str = "https://sqladmin.googleapis.com",
234+
user_agent: Optional[str] = None,
227235
) -> None:
228236
# validate and parse instance connection name
229237
self._project, self._region, self._instance = _parse_instance_connection_name(
@@ -233,7 +241,11 @@ def __init__(
233241

234242
self._enable_iam_auth = enable_iam_auth
235243

236-
self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
244+
self._user_agent_string = _format_user_agent(
245+
version,
246+
driver_name,
247+
user_agent,
248+
)
237249
self._quota_project = quota_project
238250
self._sqladmin_api_endpoint = sqladmin_api_endpoint
239251
self._loop = loop

tests/unit/test_instance.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from google.cloud.sql.connector.instance import IPTypes
3535
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
3636
from google.cloud.sql.connector.utils import generate_keys
37+
from google.cloud.sql.connector.version import __version__ as version
3738

3839

3940
@pytest.fixture
@@ -84,7 +85,13 @@ async def test_Instance_init(
8485
)
8586
with patch("google.cloud.sql.connector.utils.default") as mock_auth:
8687
mock_auth.return_value = fake_credentials, None
87-
instance = Instance(connect_string, "pymysql", keys, event_loop)
88+
instance = Instance(
89+
connect_string,
90+
"pymysql",
91+
keys,
92+
event_loop,
93+
user_agent="custom/v1.0.0",
94+
)
8895
project_result = instance._project
8996
region_result = instance._region
9097
instance_result = instance._instance
@@ -93,6 +100,10 @@ async def test_Instance_init(
93100
and region_result == "test-region"
94101
and instance_result == "test-instance"
95102
)
103+
assert (
104+
instance._user_agent_string
105+
== f"cloud-sql-python-connector/{version}+pymysql custom/v1.0.0"
106+
)
96107
# cleanup instance
97108
await instance.close()
98109

0 commit comments

Comments
 (0)