|
36 | 36 | from typing import ( |
37 | 37 | Any, |
38 | 38 | Awaitable, |
39 | | - Coroutine, |
40 | 39 | Dict, |
41 | 40 | Optional, |
42 | 41 | TYPE_CHECKING, |
43 | | - Union, |
44 | 42 | ) |
45 | 43 |
|
46 | 44 | from functools import partial |
@@ -220,8 +218,8 @@ def _client_session(self) -> aiohttp.ClientSession: |
220 | 218 | _project: str |
221 | 219 | _region: str |
222 | 220 |
|
223 | | - _current: Union[Coroutine, asyncio.Task] |
224 | | - _next: Union[Coroutine, asyncio.Task] |
| 221 | + _current: asyncio.Task |
| 222 | + _next: asyncio.Task |
225 | 223 |
|
226 | 224 | def __init__( |
227 | 225 | self, |
@@ -252,11 +250,13 @@ def __init__( |
252 | 250 | self._keys = asyncio.wrap_future(keys, loop=self._loop) |
253 | 251 | self._auth_init() |
254 | 252 |
|
255 | | - logger.debug("Updating instance data") |
| 253 | + async def _set_instance_data() -> None: |
| 254 | + logger.debug("Updating instance data") |
| 255 | + self._current = self._loop.create_task(self._get_instance_data()) |
| 256 | + self._next = self._loop.create_task(self._schedule_refresh()) |
256 | 257 |
|
257 | | - self._current = self._perform_refresh() |
258 | | - self._next = self._current |
259 | | - asyncio.run_coroutine_threadsafe(self._current, self._loop) |
| 258 | + init_future = asyncio.run_coroutine_threadsafe(_set_instance_data(), self._loop) |
| 259 | + init_future.result() |
260 | 260 |
|
261 | 261 | def __del__(self) -> None: |
262 | 262 | """Deconstructor to make sure ClientSession is closed and tasks have |
@@ -381,12 +381,39 @@ async def _perform_refresh(self) -> asyncio.Task: |
381 | 381 |
|
382 | 382 | logger.debug("Entered _perform_refresh") |
383 | 383 |
|
384 | | - self._current = self._loop.create_task(self._get_instance_data()) |
385 | | - # Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes. |
| 384 | + refresh_task = self._loop.create_task(self._get_instance_data()) |
386 | 385 |
|
387 | | - self._next = self._loop.create_task(self._schedule_refresh()) |
| 386 | + def _refresh_callback(task: asyncio.Task) -> None: |
| 387 | + try: |
| 388 | + task.result() |
| 389 | + except Exception as e: |
| 390 | + logger.warn( |
| 391 | + "An error occurred while performing refresh. Retrying immediately.", |
| 392 | + e, |
| 393 | + ) |
| 394 | + instance_data = None |
| 395 | + try: |
| 396 | + instance_data = self._current.result() |
| 397 | + except Exception: |
| 398 | + # Current result is invalid, no-op |
| 399 | + logger.debug("Current instance data is invalid.") |
| 400 | + if ( |
| 401 | + instance_data is None |
| 402 | + or instance_data.expiration < datetime.datetime.now() |
| 403 | + ): |
| 404 | + self._current = task |
| 405 | + # TODO: Implement force refresh method and a rate-limiter for perform_refresh |
| 406 | + # Retry by scheduling a refresh 60s from now. |
| 407 | + self._next = self._loop.create_task(self._schedule_refresh(60)) |
| 408 | + |
| 409 | + else: |
| 410 | + self._current = refresh_task |
| 411 | + # Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes. |
| 412 | + self._next = self._loop.create_task(self._schedule_refresh()) |
388 | 413 |
|
389 | | - return self._current |
| 414 | + refresh_task.add_done_callback(_refresh_callback) |
| 415 | + |
| 416 | + return refresh_task |
390 | 417 |
|
391 | 418 | async def _schedule_refresh(self, delay: Optional[int] = None) -> asyncio.Task: |
392 | 419 | """A coroutine that sleeps for the specified amount of time before |
@@ -466,7 +493,9 @@ async def _connect( |
466 | 493 | "pytds": self._connect_with_pytds, |
467 | 494 | } |
468 | 495 |
|
469 | | - instance_data: InstanceMetadata = await self._current |
| 496 | + instance_data: InstanceMetadata |
| 497 | + |
| 498 | + instance_data = await self._current |
470 | 499 | ip_address: str = instance_data.get_preferred_ip(ip_type) |
471 | 500 |
|
472 | 501 | try: |
|
0 commit comments