Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
- add TLS support

## [v0.2.0](https://github.com/Mapepire-IBMi/mapepire-python/releases/tag/v0.2.0) - 2024-11-26
- replace `websocket-client` with `websockets`
Expand Down
15 changes: 13 additions & 2 deletions mapepire_python/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from websockets import ConcurrencyError, ConnectionClosed, InvalidHandshake, InvalidURI

from mapepire_python.data_types import DaemonServer
from mapepire_python.ssl import get_certificate

ReturnType = TypeVar("ReturnType")

Expand All @@ -24,8 +25,18 @@ def _create_ssl_context(self, db2_server: DaemonServer):
if db2_server.ignoreUnauthorized:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif db2_server.ca:
ssl_context.load_verify_locations(cadata=db2_server.ca)
else:
if db2_server.ca:
ssl_context.load_verify_locations(cadata=db2_server.ca)
else:
cert = get_certificate(db2_server)
if cert:
ssl_context.load_verify_locations(cadata=cert)
else:
raise ssl.SSLError("Failed to retrieve server certificate")

ssl_context.check_hostname = True
ssl_context.verify_mode = ssl.CERT_REQUIRED
return ssl_context


Expand Down
63 changes: 53 additions & 10 deletions tests/tls_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
import ssl

import pytest

from mapepire_python.client.sql_job import SQLJob
from mapepire_python.data_types import DaemonServer
from mapepire_python.ssl import get_certificate

# Fetch environment variables
server = os.getenv("VITE_SERVER")
Expand All @@ -15,15 +20,53 @@
creds = DaemonServer(host=server, port=port, user=user, password=password, ignoreUnauthorized=False)


# def test_get_cert():
# cert = get_certificate(creds)
# print(cert)
# assert cert != None
def test_get_cert():
cert = get_certificate(creds)
print(cert)
assert cert != None


def test_verify_cert():
cert = get_certificate(creds)
creds.ignoreUnauthorized = False
creds.ca = cert
job = SQLJob()
result = job.connect(creds)
assert result["success"]


def test_verify_cert_not_provided():
creds.ignoreUnauthorized = False
job = SQLJob()
result = job.connect(creds)
assert result["success"]


def test_bad_cert():
badCert = """-----BEGIN CERTIFICATE-----
mIIDhTCCAm2gAwIBAgIEYRpOADANBgkqhkiG9w0BAQsFADBzMRAwDgYDVQQIEwdV
bmtub3duMRAwDgYDVQQGEwdVbmtub3duMRYwFAYDVQQKEw1EYjIgZm9yIElCTSBp
MRowGAYDVQQLExFXZWIgU29ja2V0IFNlcnZlcjEZMBcGA1UEAxMQT1NTQlVJTEQu
UlpLSC5ERTAeFw0yNDA4MjMxODE2MDJaFw0zNDA4MjUxODE2MDJaMHMxEDAOBgNV
BAgTB1Vua25vd24xEDAOBgNVBAYTB1Vua25vd24xFjAUBgNVBAoTDURiMiBmb3Ig
SUJNIGkxGjAYBgNVBAsTEVdlYiBTb2NrZXQgU2VydmVyMRkwFwYDVQQDExBPU1NC
VUlMRC5SWktILkRFMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAhKBx
5KoTsBs3dHibT/j8ycApa6teJOclaiCl9fX5IwKP0dli5qZ91t5sZ+51qS3mgLny
zMWBCaSIQYsuDEE374lHYpYB6wh/00VE1NseJpHbqbCQz1GSUz/d4tK4R1qx0Gv0
lpKOd8/oMLUZ24FCEUKaqeQBxTzlQxkI9t2DbIRwS6U6oc4uj5DN2EIU+mfLb17y
j8iA7VMKsRmoke2vyOLXJJYJeASNI02AbHcbYkd6BaoyNeb3BlpssEhgZribWmdy
FhrJldpGtJyirvABaZQaEFelEqmSVbdPWccX3JWQdorZoNVXCypxJatxOZAhCg6f
iu3AceHUr+dMAS8z4QIDAQABoyEwHzAdBgNVHQ4EFgQU6VvyCjQ5574xtCg0oypV
zHP0vAMwDQYJKoZIhvcNAQELBQADggEBAD4bKhansD+uuUPYaIvPwyclr4zPvuyg
QAFu5oILqddzgPGIwogbxTxQkNjEGyorFJj1vJBCVIq4zJJ0DIv57BK/oVMy4Byl
6zMhJTjS74assgjCq1pVjIBtc2PCfiWxzo0wQCOEL8gsNCy/w5EaIATKfLtx6+Fd
CHsadf7fvJnLnK3FXOStAnN31ISSTwsvsRobdXX70nlYM/2OaZQsIlndftVRbI39
2+94KHciPSwo/4fu+FLuvOm37GS+/ST3BKDSvwRJRxUc0r8lo1STiQz0cXC6uqDd
79/VBUN4NLZ3mBVk2FGuazIu9n80+o0fI5sg1ucQ/hBt8WR8iQ6sZUc=
-----END CERTIFICATE-----"""

# def test_verify_cert():
# cert = get_certificate(creds)
# creds.ca = cert
# job = SQLJob()
# result = job.connect(creds)
# print(result)
creds.ca = badCert
creds.ignoreUnauthorized = False
job = SQLJob()
with pytest.raises(ssl.SSLError) as err:
res = job.connect(creds)
Loading