Skip to content

Commit a6433b9

Browse files
feat: configure pg8000 connection with SSLSocket (#789)
1 parent c4cd9bc commit a6433b9

File tree

4 files changed

+25
-19
lines changed

4 files changed

+25
-19
lines changed

google/cloud/sql/connector/instance.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,6 @@ def __init__(
104104
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
105105
)
106106
self.context.minimum_version = ssl.TLSVersion.TLSv1_2
107-
108-
# add request_ssl attribute to ssl.SSLContext, required for pg8000 driver
109-
self.context.request_ssl = False # type: ignore
110-
111107
self.expiration = expiration
112108

113109
# tmpdir and its contents are automatically deleted after the CA cert

google/cloud/sql/connector/pg8000.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
import socket
1617
import ssl
1718
from typing import Any, TYPE_CHECKING
1819

@@ -39,26 +40,26 @@ def connect(
3940
:rtype: pg8000.dbapi.Connection
4041
:returns: A pg8000 Connection object for the Cloud SQL instance.
4142
"""
42-
# Connecting through pg8000 is done by passing in an SSL Context and setting the
43-
# "request_ssl" attr to false. This works because when "request_ssl" is false,
44-
# the driver skips the database level SSL/TLS exchange, but still uses the
45-
# ssl_context (if it is not None) to create the connection.
4643
try:
4744
import pg8000
4845
except ImportError:
4946
raise ImportError(
5047
'Unable to import module "pg8000." Please install and try again.'
5148
)
49+
50+
# Create socket and wrap with context.
51+
sock = ctx.wrap_socket(
52+
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
53+
server_hostname=ip_address,
54+
)
55+
5256
user = kwargs.pop("user")
5357
db = kwargs.pop("db")
5458
passwd = kwargs.pop("password", None)
55-
setattr(ctx, "request_ssl", False)
5659
return pg8000.dbapi.connect(
5760
user,
5861
database=db,
5962
password=passwd,
60-
host=ip_address,
61-
port=SERVER_PROXY_PORT,
62-
ssl_context=ctx,
63+
sock=sock,
6364
**kwargs,
6465
)

tests/unit/test_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_Connector_connect(connector: Connector) -> None:
108108
"""Test that Connector.connect can properly return a DB API connection."""
109109
connect_string = "my-project:my-region:my-instance"
110110
# patch db connection creation
111-
with patch("pg8000.dbapi.connect") as mock_connect:
111+
with patch("google.cloud.sql.connector.pg8000.connect") as mock_connect:
112112
mock_connect.return_value = True
113113
connection = connector.connect(
114114
connect_string, "pg8000", user="my-user", password="my-pass", db="my-db"

tests/unit/test_pg8000.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,32 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16-
import ssl
16+
from functools import partial
1717
from typing import Any
1818

1919
from mock import patch
20+
from mocks import create_ssl_context
21+
import pytest
2022

2123
from google.cloud.sql.connector.pg8000 import connect
2224

2325

24-
def test_pg8000(kwargs: Any) -> None:
26+
@pytest.mark.usefixtures("server")
27+
@pytest.mark.asyncio
28+
async def test_pg8000(kwargs: Any) -> None:
2529
"""Test to verify that pg8000 gets to proper connection call."""
26-
ip_addr = "0.0.0.0"
27-
context = ssl.create_default_context()
30+
ip_addr = "127.0.0.1"
31+
# build ssl.SSLContext
32+
context = await create_ssl_context()
33+
# force all wrap_socket calls to have do_handshake_on_connect=False
34+
setattr(
35+
context,
36+
"wrap_socket",
37+
partial(context.wrap_socket, do_handshake_on_connect=False),
38+
)
2839
with patch("pg8000.dbapi.connect") as mock_connect:
2940
mock_connect.return_value = True
3041
connection = connect(ip_addr, context, **kwargs)
3142
assert connection is True
32-
# verify ssl.SSLContext has 'request_ssl' attribute set to false
33-
assert context.request_ssl is False # type: ignore
3443
# verify that driver connection call would be made
3544
assert mock_connect.assert_called_once

0 commit comments

Comments
 (0)