Skip to content

Commit 060a78f

Browse files
authored
feat: add pytds support (#57)
* feat: add pytds support * move proxy port number to constant * scope dependencies in setup.py * Update requirements.txt * add type hints * update setup.py
1 parent 8150b6f commit 060a78f

File tree

6 files changed

+136
-8
lines changed

6 files changed

+136
-8
lines changed

.mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,10 @@ plugins = sqlmypy
77
ignore_missing_imports = True
88

99
[mypy-pg8000]
10+
ignore_missing_imports = True
11+
12+
[mypy-pytds]
13+
ignore_missing_imports = True
14+
15+
[mypy-pytest]
1016
ignore_missing_imports = True

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@
4444
if TYPE_CHECKING:
4545
import pymysql
4646
import pg8000
47+
import pytds
4748
logger = logging.getLogger(name=__name__)
4849

4950
APPLICATION_NAME = "cloud-sql-python-connector"
51+
SERVER_PROXY_PORT = 3307
5052

5153
# The default delay is set to 55 minutes since each ephemeral certificate is only
5254
# valid for an hour. This gives five minutes of buffer time.
@@ -351,6 +353,7 @@ async def _connect(self, driver: str, **kwargs: Any) -> Any:
351353
connect_func = {
352354
"pymysql": self._connect_with_pymysql,
353355
"pg8000": self._connect_with_pg8000,
356+
"pytds": self._connect_with_pytds,
354357
}
355358

356359
instance_data: InstanceMetadata = await self._current
@@ -391,7 +394,8 @@ def _connect_with_pymysql(
391394

392395
# Create socket and wrap with context.
393396
sock = ctx.wrap_socket(
394-
socket.create_connection((ip_address, 3307)), server_hostname=ip_address
397+
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
398+
server_hostname=ip_address,
395399
)
396400

397401
# Create pymysql connection object and hand in pre-made connection
@@ -431,7 +435,43 @@ def _connect_with_pg8000(
431435
database=db,
432436
password=passwd,
433437
host=ip_address,
434-
port=3307,
438+
port=SERVER_PROXY_PORT,
435439
ssl_context=ctx,
436440
**kwargs,
437441
)
442+
443+
def _connect_with_pytds(
444+
self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
445+
) -> "pytds.Connection":
446+
"""Helper function to create a pg8000 DB-API connection object.
447+
448+
:type ip_address: str
449+
:param ip_address: A string containing an IP address for the Cloud SQL
450+
instance.
451+
452+
:type ctx: ssl.SSLContext
453+
:param ctx: An SSLContext object created from the Cloud SQL server CA
454+
cert and ephemeral cert.
455+
456+
457+
:rtype: pytds.Connection
458+
:returns: A pytds Connection object for the Cloud SQL instance.
459+
"""
460+
try:
461+
import pytds
462+
except ImportError:
463+
raise ImportError(
464+
'Unable to import module "pytds." Please install and try again.'
465+
)
466+
user = kwargs.pop("user")
467+
db = kwargs.pop("db")
468+
passwd = kwargs.pop("password")
469+
470+
# Create socket and wrap with context.
471+
sock = ctx.wrap_socket(
472+
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
473+
server_hostname=ip_address,
474+
)
475+
return pytds.connect(
476+
ip_address, database=db, user=user, password=passwd, sock=sock, **kwargs
477+
)

requirements-test.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ pytest==6.2.2
22
mock==4.0.3
33
pytest-cov==2.11.1
44
pytest-asyncio==0.14.0
5-
SQLAlchemy==1.3.23
5+
SQLAlchemy==1.3.23
6+
sqlalchemy-pytds

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ aiohttp==3.7.4.post0
22
cryptography==3.4.7
33
PyMySQL==1.0.2
44
pg8000==1.19.2
5+
# python-tds==1.10.0
6+
git+https://github.com/denisenkom/pytds.git
57
pyopenssl==20.0.1
6-
pytest==6.2.3
78
Requests==2.25.1
89
google-api-python-client==2.2.0

setup.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@
3737
version = version["__version__"]
3838

3939
release_status = "Development Status :: 3 - Alpha"
40-
dependencies = [
40+
core_dependencies = [
4141
"aiohttp",
4242
"cryptography",
43-
"PyMySQL",
44-
"pytest",
43+
"pyopenssl",
4544
"Requests",
4645
"google-api-python-client",
4746
]
@@ -67,7 +66,12 @@
6766
platforms="Posix; MacOS X",
6867
packages=packages,
6968
namespace_packages=namespaces,
70-
install_requires=dependencies,
69+
install_requires=core_dependencies,
70+
extras_require={
71+
"pymysql": ["PyMySQL==1.0.2"],
72+
"pg8000": ["pg8000==1.19.2"],
73+
"pytds": ["python-tds @ git+https://github.com/denisenkom/pytds.git"]
74+
},
7175
python_requires=">=3.6",
7276
include_package_data=True,
7377
zip_safe=False,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
""""
2+
Copyright 2021 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import os
17+
import uuid
18+
from typing import Generator
19+
20+
import pytds
21+
import pytest
22+
import sqlalchemy
23+
from google.cloud.sql.connector import connector
24+
25+
table_name = f"books_{uuid.uuid4().hex}"
26+
27+
28+
def init_connection_engine() -> sqlalchemy.engine.Engine:
29+
def getconn() -> pytds.Connection:
30+
conn = connector.connect(
31+
os.environ["SQLSERVER_CONNECTION_NAME"],
32+
"pytds",
33+
user=os.environ["SQLSERVER_USER"],
34+
password=os.environ["SQLSERVER_PASS"],
35+
db=os.environ["SQLSERVER_DB"],
36+
)
37+
return conn
38+
39+
engine = sqlalchemy.create_engine(
40+
"mssql+pytds://localhost",
41+
creator=getconn,
42+
)
43+
engine.dialect.description_encoding = None
44+
return engine
45+
46+
47+
@pytest.fixture(name="pool")
48+
def setup() -> Generator:
49+
pool = init_connection_engine()
50+
51+
with pool.connect() as conn:
52+
conn.execute(
53+
f"CREATE TABLE {table_name}"
54+
" ( id CHAR(20) NOT NULL, title TEXT NOT NULL );"
55+
)
56+
57+
yield pool
58+
59+
with pool.connect() as conn:
60+
conn.execute(f"DROP TABLE {table_name}")
61+
62+
63+
def test_pooled_connection_with_pytds(pool: sqlalchemy.engine.Engine) -> None:
64+
insert_stmt = sqlalchemy.text(
65+
f"INSERT INTO {table_name} (id, title) VALUES (:id, :title)",
66+
)
67+
with pool.connect() as conn:
68+
conn.execute(insert_stmt, id="book1", title="Book One")
69+
conn.execute(insert_stmt, id="book2", title="Book Two")
70+
71+
select_stmt = sqlalchemy.text(f"SELECT title FROM {table_name} ORDER BY ID;")
72+
with pool.connect() as conn:
73+
rows = conn.execute(select_stmt).fetchall()
74+
titles = [row[0] for row in rows]
75+
76+
assert titles == ["Book One", "Book Two"]

0 commit comments

Comments
 (0)