Skip to content

Commit 9c8ed67

Browse files
authored
fix: only replace refresh result if successful or current is invalid (#135)
* fix: only replace refresh result if successful or current result is invalid * set initial value of current to a Future instead of a coroutine * add comment explaining _current replacement logic * linting fix * remove async from refresh_callback * set initial value of _current to Task
1 parent a54a8f3 commit 9c8ed67

File tree

1 file changed

+42
-13
lines changed

1 file changed

+42
-13
lines changed

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@
3636
from typing import (
3737
Any,
3838
Awaitable,
39-
Coroutine,
4039
Dict,
4140
Optional,
4241
TYPE_CHECKING,
43-
Union,
4442
)
4543

4644
from functools import partial
@@ -220,8 +218,8 @@ def _client_session(self) -> aiohttp.ClientSession:
220218
_project: str
221219
_region: str
222220

223-
_current: Union[Coroutine, asyncio.Task]
224-
_next: Union[Coroutine, asyncio.Task]
221+
_current: asyncio.Task
222+
_next: asyncio.Task
225223

226224
def __init__(
227225
self,
@@ -252,11 +250,13 @@ def __init__(
252250
self._keys = asyncio.wrap_future(keys, loop=self._loop)
253251
self._auth_init()
254252

255-
logger.debug("Updating instance data")
253+
async def _set_instance_data() -> None:
254+
logger.debug("Updating instance data")
255+
self._current = self._loop.create_task(self._get_instance_data())
256+
self._next = self._loop.create_task(self._schedule_refresh())
256257

257-
self._current = self._perform_refresh()
258-
self._next = self._current
259-
asyncio.run_coroutine_threadsafe(self._current, self._loop)
258+
init_future = asyncio.run_coroutine_threadsafe(_set_instance_data(), self._loop)
259+
init_future.result()
260260

261261
def __del__(self) -> None:
262262
"""Deconstructor to make sure ClientSession is closed and tasks have
@@ -381,12 +381,39 @@ async def _perform_refresh(self) -> asyncio.Task:
381381

382382
logger.debug("Entered _perform_refresh")
383383

384-
self._current = self._loop.create_task(self._get_instance_data())
385-
# Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes.
384+
refresh_task = self._loop.create_task(self._get_instance_data())
386385

387-
self._next = self._loop.create_task(self._schedule_refresh())
386+
def _refresh_callback(task: asyncio.Task) -> None:
387+
try:
388+
task.result()
389+
except Exception as e:
390+
logger.warn(
391+
"An error occurred while performing refresh. Retrying immediately.",
392+
e,
393+
)
394+
instance_data = None
395+
try:
396+
instance_data = self._current.result()
397+
except Exception:
398+
# Current result is invalid, no-op
399+
logger.debug("Current instance data is invalid.")
400+
if (
401+
instance_data is None
402+
or instance_data.expiration < datetime.datetime.now()
403+
):
404+
self._current = task
405+
# TODO: Implement force refresh method and a rate-limiter for perform_refresh
406+
# Retry by scheduling a refresh 60s from now.
407+
self._next = self._loop.create_task(self._schedule_refresh(60))
408+
409+
else:
410+
self._current = refresh_task
411+
# Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes.
412+
self._next = self._loop.create_task(self._schedule_refresh())
388413

389-
return self._current
414+
refresh_task.add_done_callback(_refresh_callback)
415+
416+
return refresh_task
390417

391418
async def _schedule_refresh(self, delay: Optional[int] = None) -> asyncio.Task:
392419
"""A coroutine that sleeps for the specified amount of time before
@@ -466,7 +493,9 @@ async def _connect(
466493
"pytds": self._connect_with_pytds,
467494
}
468495

469-
instance_data: InstanceMetadata = await self._current
496+
instance_data: InstanceMetadata
497+
498+
instance_data = await self._current
470499
ip_address: str = instance_data.get_preferred_ip(ip_type)
471500

472501
try:

0 commit comments

Comments
 (0)