Skip to content

Commit

Permalink
Add SSLContext methods and properties
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Jan 22, 2023
1 parent b6db4be commit 1df4895
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 8 deletions.
141 changes: 133 additions & 8 deletions src/truststore/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import platform
import socket
import ssl
from typing import Any
import typing

from _ssl import ENCODING_DER # type: ignore[import]

Expand All @@ -14,6 +14,10 @@
from ._openssl import _configure_context, _verify_peercerts_impl


_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]


class SSLContext(ssl.SSLContext):
"""SSLContext API that uses system certificates on all platforms"""

Expand Down Expand Up @@ -84,14 +88,135 @@ def load_verify_locations(
cafile=cafile, capath=capath, cadata=cadata
)

def __getattr__(self, name: str) -> Any:
return getattr(self._ctx, name)
def load_cert_chain(
self,
certfile: _StrOrBytesPath,
keyfile: _StrOrBytesPath | None = None,
password: _PasswordType | None = None,
) -> None:
return self._ctx.load_cert_chain(
certfile=certfile, keyfile=keyfile, password=password
)

def load_default_certs(
self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
) -> None:
return self._ctx.load_default_certs(purpose)

def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_alpn_protocols(alpn_protocols)

def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
return self._ctx.set_npn_protocols(npn_protocols)

def set_ciphers(self, __cipherlist: str) -> None:
return self._ctx.set_ciphers(__cipherlist)

def get_ciphers(self) -> typing.Any:
return self._ctx.get_ciphers()

def session_stats(self) -> dict[str, int]:
return self._ctx.session_stats()

def cert_store_stats(self) -> dict[str, int]:
raise NotImplementedError()

@typing.overload
def get_ca_certs(
self, binary_form: typing.Literal[False] = ...
) -> list[typing.Any]:
...

@typing.overload
def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]:
...

@typing.overload
def get_ca_certs(self, binary_form: bool = ...) -> typing.Any:
...

def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
raise NotImplementedError()

@property
def check_hostname(self) -> bool:
return self._ctx.check_hostname

@check_hostname.setter
def check_hostname(self, value: bool) -> None:
self._ctx.check_hostname = value

@property
def hostname_checks_common_name(self) -> bool:
return self._ctx.hostname_checks_common_name

@hostname_checks_common_name.setter
def hostname_checks_common_name(self, value: bool) -> None:
self._ctx.hostname_checks_common_name = value

@property
def keylog_filename(self) -> str:
return self._ctx.keylog_filename

@keylog_filename.setter
def keylog_filename(self, value: str) -> None:
self._ctx.keylog_filename = value

@property
def maximum_version(self) -> ssl.TLSVersion:
return self._ctx.maximum_version

@maximum_version.setter
def maximum_version(self, value: ssl.TLSVersion) -> None:
self._ctx.maximum_version = value

@property
def minimum_version(self) -> ssl.TLSVersion:
return self._ctx.minimum_version

@minimum_version.setter
def minimum_version(self, value: ssl.TLSVersion) -> None:
self._ctx.minimum_version = value

@property
def options(self) -> ssl.Options:
return self._ctx.options

@options.setter
def options(self, value: ssl.Options) -> None:
self._ctx.options = value

@property
def post_handshake_auth(self) -> bool:
return self._ctx.post_handshake_auth

@post_handshake_auth.setter
def post_handshake_auth(self, value: bool) -> None:
self._ctx.post_handshake_auth = value

@property
def protocol(self) -> ssl._SSLMethod:
return self._ctx.protocol

@property
def security_level(self) -> int:
return self._ctx.security_level # type: ignore[attr-defined,no-any-return]

@property
def verify_flags(self) -> ssl.VerifyFlags:
return self._ctx.verify_flags

@verify_flags.setter
def verify_flags(self, value: ssl.VerifyFlags) -> None:
self._ctx.verify_flags = value

@property
def verify_mode(self) -> ssl.VerifyMode:
return self._ctx.verify_mode

def __setattr__(self, name: str, value: Any) -> None:
if name == "verify_flags":
self._ctx.verify_flags = value
else:
return super().__setattr__(name, value)
@verify_mode.setter
def verify_mode(self, value: ssl.VerifyMode) -> None:
self._ctx.verify_mode = value


def _verify_peercerts(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_sslcontext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
import ssl

import pytest
import urllib3
from urllib3.exceptions import InsecureRequestWarning

import truststore


def test_minimum_maximum_version():
ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.maximum_version = ssl.TLSVersion.TLSv1_2
with urllib3.PoolManager(ssl_context=ctx) as http:

resp = http.request("GET", "https://howsmyssl.com/a/check")
data = json.loads(resp.data)
assert data["tls_version"] == "TLS 1.2"

assert ctx.minimum_version in (
ssl.TLSVersion.TLSv1_2,
ssl.TLSVersion.MINIMUM_SUPPORTED,
)
assert ctx.maximum_version == ssl.TLSVersion.TLSv1_2


def test_disable_verification():
ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

with urllib3.PoolManager(ssl_context=ctx) as http, pytest.warns(
InsecureRequestWarning
) as w:
http.request("GET", "https://expired.badssl.com/")
assert len(w) == 1

0 comments on commit 1df4895

Please sign in to comment.