From 1df489515dbf0235cbd94946405ec76feaf65559 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Sun, 22 Jan 2023 13:59:01 -0600 Subject: [PATCH] Add SSLContext methods and properties --- src/truststore/_api.py | 141 ++++++++++++++++++++++++++++++++++++--- tests/test_sslcontext.py | 36 ++++++++++ 2 files changed, 169 insertions(+), 8 deletions(-) create mode 100644 tests/test_sslcontext.py diff --git a/src/truststore/_api.py b/src/truststore/_api.py index 463bdbf..4f2030d 100644 --- a/src/truststore/_api.py +++ b/src/truststore/_api.py @@ -2,7 +2,7 @@ import platform import socket import ssl -from typing import Any +import typing from _ssl import ENCODING_DER # type: ignore[import] @@ -14,6 +14,10 @@ from ._openssl import _configure_context, _verify_peercerts_impl +_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes] +_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes] + + class SSLContext(ssl.SSLContext): """SSLContext API that uses system certificates on all platforms""" @@ -84,14 +88,135 @@ def load_verify_locations( cafile=cafile, capath=capath, cadata=cadata ) - def __getattr__(self, name: str) -> Any: - return getattr(self._ctx, name) + def load_cert_chain( + self, + certfile: _StrOrBytesPath, + keyfile: _StrOrBytesPath | None = None, + password: _PasswordType | None = None, + ) -> None: + return self._ctx.load_cert_chain( + certfile=certfile, keyfile=keyfile, password=password + ) + + def load_default_certs( + self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH + ) -> None: + return self._ctx.load_default_certs(purpose) + + def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None: + return self._ctx.set_alpn_protocols(alpn_protocols) + + def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None: + return self._ctx.set_npn_protocols(npn_protocols) + + def set_ciphers(self, __cipherlist: str) -> None: + return self._ctx.set_ciphers(__cipherlist) + + def get_ciphers(self) -> typing.Any: + return self._ctx.get_ciphers() + + def session_stats(self) -> dict[str, int]: + return self._ctx.session_stats() + + def cert_store_stats(self) -> dict[str, int]: + raise NotImplementedError() + + @typing.overload + def get_ca_certs( + self, binary_form: typing.Literal[False] = ... + ) -> list[typing.Any]: + ... + + @typing.overload + def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]: + ... + + @typing.overload + def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: + ... + + def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]: + raise NotImplementedError() + + @property + def check_hostname(self) -> bool: + return self._ctx.check_hostname + + @check_hostname.setter + def check_hostname(self, value: bool) -> None: + self._ctx.check_hostname = value + + @property + def hostname_checks_common_name(self) -> bool: + return self._ctx.hostname_checks_common_name + + @hostname_checks_common_name.setter + def hostname_checks_common_name(self, value: bool) -> None: + self._ctx.hostname_checks_common_name = value + + @property + def keylog_filename(self) -> str: + return self._ctx.keylog_filename + + @keylog_filename.setter + def keylog_filename(self, value: str) -> None: + self._ctx.keylog_filename = value + + @property + def maximum_version(self) -> ssl.TLSVersion: + return self._ctx.maximum_version + + @maximum_version.setter + def maximum_version(self, value: ssl.TLSVersion) -> None: + self._ctx.maximum_version = value + + @property + def minimum_version(self) -> ssl.TLSVersion: + return self._ctx.minimum_version + + @minimum_version.setter + def minimum_version(self, value: ssl.TLSVersion) -> None: + self._ctx.minimum_version = value + + @property + def options(self) -> ssl.Options: + return self._ctx.options + + @options.setter + def options(self, value: ssl.Options) -> None: + self._ctx.options = value + + @property + def post_handshake_auth(self) -> bool: + return self._ctx.post_handshake_auth + + @post_handshake_auth.setter + def post_handshake_auth(self, value: bool) -> None: + self._ctx.post_handshake_auth = value + + @property + def protocol(self) -> ssl._SSLMethod: + return self._ctx.protocol + + @property + def security_level(self) -> int: + return self._ctx.security_level # type: ignore[attr-defined,no-any-return] + + @property + def verify_flags(self) -> ssl.VerifyFlags: + return self._ctx.verify_flags + + @verify_flags.setter + def verify_flags(self, value: ssl.VerifyFlags) -> None: + self._ctx.verify_flags = value + + @property + def verify_mode(self) -> ssl.VerifyMode: + return self._ctx.verify_mode - def __setattr__(self, name: str, value: Any) -> None: - if name == "verify_flags": - self._ctx.verify_flags = value - else: - return super().__setattr__(name, value) + @verify_mode.setter + def verify_mode(self, value: ssl.VerifyMode) -> None: + self._ctx.verify_mode = value def _verify_peercerts( diff --git a/tests/test_sslcontext.py b/tests/test_sslcontext.py new file mode 100644 index 0000000..5241abc --- /dev/null +++ b/tests/test_sslcontext.py @@ -0,0 +1,36 @@ +import json +import ssl + +import pytest +import urllib3 +from urllib3.exceptions import InsecureRequestWarning + +import truststore + + +def test_minimum_maximum_version(): + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + with urllib3.PoolManager(ssl_context=ctx) as http: + + resp = http.request("GET", "https://howsmyssl.com/a/check") + data = json.loads(resp.data) + assert data["tls_version"] == "TLS 1.2" + + assert ctx.minimum_version in ( + ssl.TLSVersion.TLSv1_2, + ssl.TLSVersion.MINIMUM_SUPPORTED, + ) + assert ctx.maximum_version == ssl.TLSVersion.TLSv1_2 + + +def test_disable_verification(): + ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + with urllib3.PoolManager(ssl_context=ctx) as http, pytest.warns( + InsecureRequestWarning + ) as w: + http.request("GET", "https://expired.badssl.com/") + assert len(w) == 1