2828import ssl
2929import socket
3030from tempfile import NamedTemporaryFile
31- from typing import Any , Awaitable
31+ from typing import (
32+ Any ,
33+ Awaitable ,
34+ Coroutine ,
35+ IO ,
36+ Optional ,
37+ TYPE_CHECKING ,
38+ Union ,
39+ )
3240
3341from functools import partial
3442import logging
3543
44+ if TYPE_CHECKING :
45+ import pymysql
46+ import pg8000
3647logger = logging .getLogger (name = __name__ )
3748
3849APPLICATION_NAME = "cloud-sql-python-connector"
@@ -47,25 +58,25 @@ class ConnectionSSLContext(ssl.SSLContext):
4758 required for compatibility with pg8000 driver.
4859 """
4960
50- def __init__ (self , * args , ** kwargs ) :
61+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
5162 self .request_ssl = False
5263 super (ConnectionSSLContext , self ).__init__ (* args , ** kwargs )
5364
5465
5566class InstanceMetadata :
5667 ip_address : str
57- _ca_fileobject : NamedTemporaryFile
58- _cert_fileobject : NamedTemporaryFile
59- _key_fileobject : NamedTemporaryFile
68+ _ca_fileobject : IO
69+ _cert_fileobject : IO
70+ _key_fileobject : IO
6071 context : ssl .SSLContext
6172
6273 def __init__ (
6374 self ,
6475 ephemeral_cert : str ,
6576 ip_address : str ,
66- private_key : str ,
77+ private_key : bytes ,
6778 server_ca_cert : str ,
68- ):
79+ ) -> None :
6980 self .ip_address = ip_address
7081
7182 self ._ca_fileobject = NamedTemporaryFile (suffix = ".pem" )
@@ -96,8 +107,8 @@ class CloudSQLConnectionError(Exception):
96107 correctly.
97108 """
98109
99- def __init__ (self , * args , ** kwargs ) -> None :
100- super (CloudSQLConnectionError , self ).__init__ (self , * args , ** kwargs )
110+ def __init__ (self , * args : Any ) -> None :
111+ super (CloudSQLConnectionError , self ).__init__ (self , * args )
101112
102113
103114class InstanceConnectionManager :
@@ -124,9 +135,9 @@ class InstanceConnectionManager:
124135 # while developing on Windows.
125136 # Link to Github issue:
126137 # https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/22
127- _loop : asyncio .AbstractEventLoop = None
138+ _loop : asyncio .AbstractEventLoop
128139
129- __client_session : aiohttp .ClientSession = None
140+ __client_session : Optional [ aiohttp .ClientSession ] = None
130141
131142 @property
132143 def _client_session (self ) -> aiohttp .ClientSession :
@@ -140,16 +151,17 @@ def _client_session(self) -> aiohttp.ClientSession:
140151 )
141152 return self .__client_session
142153
143- _credentials : Credentials = None
154+ _credentials : Optional [Credentials ] = None
155+ _keys : Awaitable
144156
145- _instance_connection_string : str = None
146- _user_agent_string : str = None
147- _instance : str = None
148- _project : str = None
149- _region : str = None
157+ _instance_connection_string : str
158+ _user_agent_string : str
159+ _instance : str
160+ _project : str
161+ _region : str
150162
151- _current : asyncio .Task = None
152- _next : asyncio .Task = None
163+ _current : Union [ Coroutine , asyncio .Task ]
164+ _next : Union [ Coroutine , asyncio .Task ]
153165
154166 def __init__ (
155167 self ,
@@ -174,7 +186,7 @@ def __init__(
174186
175187 self ._user_agent_string = f"{ APPLICATION_NAME } /{ version } +{ driver_name } "
176188 self ._loop = loop
177- self ._keys : Awaitable = asyncio .wrap_future (keys , loop = self ._loop )
189+ self ._keys = asyncio .wrap_future (keys , loop = self ._loop )
178190 self ._auth_init ()
179191
180192 logger .debug ("Updating instance data" )
@@ -183,17 +195,17 @@ def __init__(
183195 self ._next = self ._current
184196 asyncio .run_coroutine_threadsafe (self ._current , self ._loop )
185197
186- def __del__ (self ):
198+ def __del__ (self ) -> None :
187199 """Deconstructor to make sure ClientSession is closed and tasks have
188200 finished to have a graceful exit.
189201 """
190202 logger .debug ("Entering deconstructor" )
191203
192204 async def _deconstruct () -> None :
193- if self ._current is not None :
205+ if isinstance ( self ._current , asyncio . Task ) :
194206 logger .debug ("Waiting for _current to be cancelled" )
195207 self ._current .cancel ()
196- if self ._next is not None :
208+ if isinstance ( self ._next , asyncio . Task ) :
197209 logger .debug ("Waiting for _next to be cancelled" )
198210 self ._next .cancel ()
199211 if not self ._client_session .closed :
@@ -293,9 +305,9 @@ async def _schedule_refresh(self, delay: int) -> asyncio.Task:
293305 logger .debug ("Schedule refresh task cancelled." )
294306 raise e
295307
296- return self ._perform_refresh ()
308+ return await self ._perform_refresh ()
297309
298- def connect (self , driver : str , timeout : int , ** kwargs ) :
310+ def connect (self , driver : str , timeout : int , ** kwargs : Any ) -> Any :
299311 """A method that returns a DB-API connection to the database.
300312
301313 :type driver: str
@@ -308,7 +320,7 @@ def connect(self, driver: str, timeout: int, **kwargs):
308320 :returns: A DB-API connection to the primary IP of the database.
309321 """
310322
311- connect_future = asyncio .run_coroutine_threadsafe (
323+ connect_future : concurrent . futures . Future = asyncio .run_coroutine_threadsafe (
312324 self ._connect (driver , ** kwargs ), self ._loop
313325 )
314326
@@ -320,7 +332,7 @@ def connect(self, driver: str, timeout: int, **kwargs):
320332 else :
321333 return connection
322334
323- async def _connect (self , driver : str , ** kwargs ) -> Any :
335+ async def _connect (self , driver : str , ** kwargs : Any ) -> Any :
324336 """A method that returns a DB-API connection to the database.
325337
326338 :type driver: str
@@ -354,7 +366,9 @@ async def _connect(self, driver: str, **kwargs) -> Any:
354366
355367 return await self ._loop .run_in_executor (None , connect_partial )
356368
357- def _connect_with_pymysql (self , ip_address : str , ctx : ssl .SSLContext , ** kwargs ):
369+ def _connect_with_pymysql (
370+ self , ip_address : str , ctx : ssl .SSLContext , ** kwargs : Any
371+ ) -> "pymysql.connections.Connection" :
358372 """Helper function to create a pymysql DB-API connection object.
359373
360374 :type ip_address: str
@@ -385,7 +399,9 @@ def _connect_with_pymysql(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
385399 conn .connect (sock )
386400 return conn
387401
388- def _connect_with_pg8000 (self , ip_address : str , ctx : ssl .SSLContext , ** kwargs ):
402+ def _connect_with_pg8000 (
403+ self , ip_address : str , ctx : ssl .SSLContext , ** kwargs : Any
404+ ) -> "pg8000.dbapi.Connection" :
389405 """Helper function to create a pg8000 DB-API connection object.
390406
391407 :type ip_address: str
@@ -409,7 +425,7 @@ def _connect_with_pg8000(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
409425 user = kwargs .pop ("user" )
410426 db = kwargs .pop ("db" )
411427 passwd = kwargs .pop ("password" )
412- ctx . request_ssl = False
428+ setattr ( ctx , " request_ssl" , False )
413429 return pg8000 .dbapi .connect (
414430 user ,
415431 database = db ,
0 commit comments