Skip to content

Commit 0f8aac5

Browse files
authored
chore: add static type checking with flake8-annotations and mypy (#74)
* add type checking to lint session * add missing type annotations * remove unused utils.connect function * add static type checking with mypy
1 parent 07de2a2 commit 0f8aac5

17 files changed

+117
-91
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
2-
ignore = E203, E266, E501, W503
2+
ignore = E203, E266, E501, W503, ANN101
33
exclude =
44
# Exclude generated code.
55
**/proto/**

.mypy.ini

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[mypy]
2+
python_version = 3.6
3+
warn_unused_configs = True
4+
plugins = sqlmypy
5+
6+
[mypy-google.auth.*]
7+
ignore_missing_imports = True
8+
9+
[mypy-pg8000]
10+
ignore_missing_imports = True

google/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Iterable
16+
1517
try:
1618
import pkg_resources
1719

1820
pkg_resources.declare_namespace(__name__)
1921
except ImportError:
2022
import pkgutil
2123

22-
__path__ = pkgutil.extend_path(__path__, __name__)
24+
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__)

google/cloud/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Iterable
1415

1516
try:
1617
import pkg_resources
@@ -19,4 +20,4 @@
1920
except ImportError:
2021
import pkgutil
2122

22-
__path__ = pkgutil.extend_path(__path__, __name__)
23+
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__)

google/cloud/sql/connector/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
limitations under the License.
1515
"""
1616

17+
from typing import Iterable
18+
1719
from .connector import connect
1820
from .instance_connection_manager import CloudSQLConnectionError
1921

22+
2023
__ALL__ = [connect, CloudSQLConnectionError]
2124

2225
try:
@@ -26,4 +29,4 @@
2629
except ImportError:
2730
import pkgutil
2831

29-
__path__ = pkgutil.extend_path(__path__, __name__)
32+
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__)

google/cloud/sql/connector/connector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121
from google.cloud.sql.connector.utils import generate_keys
2222

2323
from threading import Thread
24-
from typing import Optional
24+
from typing import Any, Dict, Optional
2525

2626

2727
# This thread is used to background processing
2828
_thread: Optional[Thread] = None
2929
_loop: Optional[asyncio.AbstractEventLoop] = None
3030
_keys: Optional[concurrent.futures.Future] = None
3131

32-
_instances = {}
32+
_instances: Dict[str, InstanceConnectionManager] = {}
3333

3434

3535
def _get_loop() -> asyncio.AbstractEventLoop:
@@ -41,14 +41,14 @@ def _get_loop() -> asyncio.AbstractEventLoop:
4141
return _loop
4242

4343

44-
def _get_keys() -> concurrent.futures.Future:
44+
def _get_keys(loop: asyncio.AbstractEventLoop) -> concurrent.futures.Future:
4545
global _keys
4646
if _keys is None:
47-
_keys = asyncio.run_coroutine_threadsafe(generate_keys(), _loop)
47+
_keys = asyncio.run_coroutine_threadsafe(generate_keys(), loop)
4848
return _keys
4949

5050

51-
def connect(instance_connection_string, driver: str, **kwargs):
51+
def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any:
5252
"""Prepares and returns a database connection object and starts a
5353
background thread to refresh the certificates and metadata.
5454
@@ -81,7 +81,7 @@ def connect(instance_connection_string, driver: str, **kwargs):
8181
if instance_connection_string in _instances:
8282
icm = _instances[instance_connection_string]
8383
else:
84-
keys = _get_keys()
84+
keys = _get_keys(loop)
8585
icm = InstanceConnectionManager(instance_connection_string, driver, keys, loop)
8686
_instances[instance_connection_string] = icm
8787

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,22 @@
2828
import ssl
2929
import socket
3030
from tempfile import NamedTemporaryFile
31-
from typing import Any, Awaitable
31+
from typing import (
32+
Any,
33+
Awaitable,
34+
Coroutine,
35+
IO,
36+
Optional,
37+
TYPE_CHECKING,
38+
Union,
39+
)
3240

3341
from functools import partial
3442
import logging
3543

44+
if TYPE_CHECKING:
45+
import pymysql
46+
import pg8000
3647
logger = logging.getLogger(name=__name__)
3748

3849
APPLICATION_NAME = "cloud-sql-python-connector"
@@ -47,25 +58,25 @@ class ConnectionSSLContext(ssl.SSLContext):
4758
required for compatibility with pg8000 driver.
4859
"""
4960

50-
def __init__(self, *args, **kwargs):
61+
def __init__(self, *args: Any, **kwargs: Any) -> None:
5162
self.request_ssl = False
5263
super(ConnectionSSLContext, self).__init__(*args, **kwargs)
5364

5465

5566
class InstanceMetadata:
5667
ip_address: str
57-
_ca_fileobject: NamedTemporaryFile
58-
_cert_fileobject: NamedTemporaryFile
59-
_key_fileobject: NamedTemporaryFile
68+
_ca_fileobject: IO
69+
_cert_fileobject: IO
70+
_key_fileobject: IO
6071
context: ssl.SSLContext
6172

6273
def __init__(
6374
self,
6475
ephemeral_cert: str,
6576
ip_address: str,
66-
private_key: str,
77+
private_key: bytes,
6778
server_ca_cert: str,
68-
):
79+
) -> None:
6980
self.ip_address = ip_address
7081

7182
self._ca_fileobject = NamedTemporaryFile(suffix=".pem")
@@ -96,8 +107,8 @@ class CloudSQLConnectionError(Exception):
96107
correctly.
97108
"""
98109

99-
def __init__(self, *args, **kwargs) -> None:
100-
super(CloudSQLConnectionError, self).__init__(self, *args, **kwargs)
110+
def __init__(self, *args: Any) -> None:
111+
super(CloudSQLConnectionError, self).__init__(self, *args)
101112

102113

103114
class InstanceConnectionManager:
@@ -124,9 +135,9 @@ class InstanceConnectionManager:
124135
# while developing on Windows.
125136
# Link to Github issue:
126137
# https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/22
127-
_loop: asyncio.AbstractEventLoop = None
138+
_loop: asyncio.AbstractEventLoop
128139

129-
__client_session: aiohttp.ClientSession = None
140+
__client_session: Optional[aiohttp.ClientSession] = None
130141

131142
@property
132143
def _client_session(self) -> aiohttp.ClientSession:
@@ -140,16 +151,17 @@ def _client_session(self) -> aiohttp.ClientSession:
140151
)
141152
return self.__client_session
142153

143-
_credentials: Credentials = None
154+
_credentials: Optional[Credentials] = None
155+
_keys: Awaitable
144156

145-
_instance_connection_string: str = None
146-
_user_agent_string: str = None
147-
_instance: str = None
148-
_project: str = None
149-
_region: str = None
157+
_instance_connection_string: str
158+
_user_agent_string: str
159+
_instance: str
160+
_project: str
161+
_region: str
150162

151-
_current: asyncio.Task = None
152-
_next: asyncio.Task = None
163+
_current: Union[Coroutine, asyncio.Task]
164+
_next: Union[Coroutine, asyncio.Task]
153165

154166
def __init__(
155167
self,
@@ -174,7 +186,7 @@ def __init__(
174186

175187
self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
176188
self._loop = loop
177-
self._keys: Awaitable = asyncio.wrap_future(keys, loop=self._loop)
189+
self._keys = asyncio.wrap_future(keys, loop=self._loop)
178190
self._auth_init()
179191

180192
logger.debug("Updating instance data")
@@ -183,17 +195,17 @@ def __init__(
183195
self._next = self._current
184196
asyncio.run_coroutine_threadsafe(self._current, self._loop)
185197

186-
def __del__(self):
198+
def __del__(self) -> None:
187199
"""Deconstructor to make sure ClientSession is closed and tasks have
188200
finished to have a graceful exit.
189201
"""
190202
logger.debug("Entering deconstructor")
191203

192204
async def _deconstruct() -> None:
193-
if self._current is not None:
205+
if isinstance(self._current, asyncio.Task):
194206
logger.debug("Waiting for _current to be cancelled")
195207
self._current.cancel()
196-
if self._next is not None:
208+
if isinstance(self._next, asyncio.Task):
197209
logger.debug("Waiting for _next to be cancelled")
198210
self._next.cancel()
199211
if not self._client_session.closed:
@@ -293,9 +305,9 @@ async def _schedule_refresh(self, delay: int) -> asyncio.Task:
293305
logger.debug("Schedule refresh task cancelled.")
294306
raise e
295307

296-
return self._perform_refresh()
308+
return await self._perform_refresh()
297309

298-
def connect(self, driver: str, timeout: int, **kwargs):
310+
def connect(self, driver: str, timeout: int, **kwargs: Any) -> Any:
299311
"""A method that returns a DB-API connection to the database.
300312
301313
:type driver: str
@@ -308,7 +320,7 @@ def connect(self, driver: str, timeout: int, **kwargs):
308320
:returns: A DB-API connection to the primary IP of the database.
309321
"""
310322

311-
connect_future = asyncio.run_coroutine_threadsafe(
323+
connect_future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
312324
self._connect(driver, **kwargs), self._loop
313325
)
314326

@@ -320,7 +332,7 @@ def connect(self, driver: str, timeout: int, **kwargs):
320332
else:
321333
return connection
322334

323-
async def _connect(self, driver: str, **kwargs) -> Any:
335+
async def _connect(self, driver: str, **kwargs: Any) -> Any:
324336
"""A method that returns a DB-API connection to the database.
325337
326338
:type driver: str
@@ -354,7 +366,9 @@ async def _connect(self, driver: str, **kwargs) -> Any:
354366

355367
return await self._loop.run_in_executor(None, connect_partial)
356368

357-
def _connect_with_pymysql(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
369+
def _connect_with_pymysql(
370+
self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
371+
) -> "pymysql.connections.Connection":
358372
"""Helper function to create a pymysql DB-API connection object.
359373
360374
:type ip_address: str
@@ -385,7 +399,9 @@ def _connect_with_pymysql(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
385399
conn.connect(sock)
386400
return conn
387401

388-
def _connect_with_pg8000(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
402+
def _connect_with_pg8000(
403+
self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
404+
) -> "pg8000.dbapi.Connection":
389405
"""Helper function to create a pg8000 DB-API connection object.
390406
391407
:type ip_address: str
@@ -409,7 +425,7 @@ def _connect_with_pg8000(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
409425
user = kwargs.pop("user")
410426
db = kwargs.pop("db")
411427
passwd = kwargs.pop("password")
412-
ctx.request_ssl = False
428+
setattr(ctx, "request_ssl", False)
413429
return pg8000.dbapi.connect(
414430
user,
415431
database=db,

google/cloud/sql/connector/refresh_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from google.auth.credentials import Credentials
2121
import google.auth.transport.requests
2222
import json
23-
from typing import Dict, Union
23+
from typing import Any, Dict
2424

2525
import logging
2626

@@ -34,7 +34,7 @@ async def _get_metadata(
3434
credentials: Credentials,
3535
project: str,
3636
instance: str,
37-
) -> Dict[str, Union[Dict, str]]:
37+
) -> Dict[str, Any]:
3838
"""Requests metadata from the Cloud SQL Instance
3939
and returns a dictionary containing the IP addresses and certificate
4040
authority of the Cloud SQL Instance.

google/cloud/sql/connector/utils.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
limitations under the License.
1515
"""
1616

17-
import pymysql.cursors
1817
from cryptography.hazmat.backends import default_backend
1918
from cryptography.hazmat.primitives import serialization
2019
from cryptography.hazmat.primitives.asymmetric import rsa
2120

21+
from typing import Tuple
2222

23-
async def generate_keys():
23+
24+
async def generate_keys() -> Tuple[bytes, str]:
2425
"""A helper function to generate the private and public keys.
2526
2627
backend - The value specified is default_backend(). This is because the
@@ -57,27 +58,7 @@ async def generate_keys():
5758
return priv_key, pub_key
5859

5960

60-
def connect(host, user, password, db_name):
61-
"""
62-
Connect method to be used as a custom creator in the SQLAlchemy engine
63-
creation.
64-
"""
65-
return pymysql.connect(
66-
host=host,
67-
user=user,
68-
password=password,
69-
db=db_name,
70-
ssl={
71-
"ssl": {
72-
"ca": "./ca.pem",
73-
"cert": "./cert.pem",
74-
"key": "./priv.pem",
75-
} # noqa: E501
76-
},
77-
)
78-
79-
80-
def write_to_file(serverCaCert, ephemeralCert, priv_key):
61+
def write_to_file(serverCaCert: str, ephemeralCert: str, priv_key: bytes) -> None:
8162
"""
8263
Helper function to write the serverCaCert, ephemeral certificate and
8364
private key to .pem files

noxfile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,17 @@
2626
if os.path.exists("samples"):
2727
BLACK_PATHS.append("samples")
2828

29-
3029
@nox.session(python="3.7")
3130
def lint(session):
3231
"""Run linters.
3332
Returns a failure if the linters find linting errors or sufficiently
3433
serious code quality issues.
3534
"""
36-
session.install("flake8", "black")
35+
session.install("flake8", "flake8-annotations", "black", "mypy", "sqlalchemy-stubs")
3736
session.install("-r", "requirements.txt")
3837
session.run("black", "--check", *BLACK_PATHS)
3938
session.run("flake8", "google", "tests")
39+
session.run("mypy", "google", "tests")
4040

4141

4242
@nox.session(python="3.6")

0 commit comments

Comments
 (0)