Skip to content

Commit

Permalink
Change the inner SSLContext configuration to undo post-handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Jan 25, 2023
1 parent 1df4895 commit 0d7f8a2
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 25 deletions.
36 changes: 20 additions & 16 deletions src/truststore/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class SSLContext(ssl.SSLContext):

def __init__(self, protocol: int = ssl.PROTOCOL_TLS) -> None:
self._ctx = ssl.SSLContext(protocol)
_configure_context(self._ctx)

class TruststoreSSLObject(ssl.SSLObject):
# This object exists because wrap_bio() doesn't
Expand All @@ -46,14 +45,18 @@ def wrap_socket(
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLSocket:
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
# Use a context manager here because the
# inner SSLContext holds on to our state
# but also does the actual handshake.
with _configure_context(self._ctx):
ssl_sock = self._ctx.wrap_socket(
sock,
server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
session=session,
)
try:
_verify_peercerts(ssl_sock, server_hostname=server_hostname)
except ssl.SSLError:
Expand All @@ -69,13 +72,14 @@ def wrap_bio(
server_hostname: str | None = None,
session: ssl.SSLSession | None = None,
) -> ssl.SSLObject:
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
with _configure_context(self._ctx):
ssl_obj = self._ctx.wrap_bio(
incoming,
outgoing,
server_hostname=server_hostname,
server_side=server_side,
session=session,
)
return ssl_obj

def load_verify_locations(
Expand Down
13 changes: 10 additions & 3 deletions src/truststore/_macos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import ctypes
import platform
import ssl
import typing
from ctypes import (
CDLL,
POINTER,
Expand All @@ -13,7 +15,6 @@
c_void_p,
)
from ctypes.util import find_library
from typing import Any

_mac_version = platform.mac_ver()[0]
_mac_version_info = tuple(map(int, _mac_version.split(".")))
Expand Down Expand Up @@ -201,7 +202,7 @@ def _load_cdll(name: str, macos10_16_path: str) -> CDLL:
raise ImportError("Error initializing ctypes") from None


def _handle_osstatus(result: OSStatus, _: Any, args: Any) -> Any:
def _handle_osstatus(result: OSStatus, _: typing.Any, args: typing.Any) -> typing.Any:
"""
Raises an error if the OSStatus value is non-zero.
"""
Expand Down Expand Up @@ -338,9 +339,15 @@ def _der_certs_to_cf_cert_array(certs: list[bytes]) -> CFMutableArrayRef: # typ
return cf_array # type: ignore[no-any-return]


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
values = ctx.check_hostname, ctx.verify_mode
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
try:
yield
finally:
ctx.check_hostname, ctx.verify_mode = values


def _verify_peercerts_impl(
Expand Down
8 changes: 5 additions & 3 deletions src/truststore/_openssl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import os
import re
import ssl
import typing

# candidates based on https://github.com/tiran/certifi-system-store by Christian Heimes
_CA_FILE_CANDIDATES = [
Expand All @@ -17,7 +19,8 @@
_HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$")


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
# First, check whether the default locations from OpenSSL
# seem like they will give us a usable set of CA certs.
# ssl.get_default_verify_paths already takes care of:
Expand All @@ -40,8 +43,7 @@ def _configure_context(ctx: ssl.SSLContext) -> None:
ctx.load_verify_locations(cafile=cafile)
break

ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = True
yield


def _capath_contains_certs(capath: str) -> bool:
Expand Down
43 changes: 42 additions & 1 deletion src/truststore/_windows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import contextlib
import ssl
import typing
from ctypes import WinDLL # type: ignore
from ctypes import WinError # type: ignore
from ctypes import (
Expand Down Expand Up @@ -199,11 +201,33 @@ class CERT_CHAIN_ENGINE_CONFIG(Structure):
OID_PKIX_KP_SERVER_AUTH = c_char_p(b"1.3.6.1.5.5.7.3.1")
CERT_CHAIN_REVOCATION_CHECK_END_CERT = 0x10000000
CERT_CHAIN_REVOCATION_CHECK_CHAIN = 0x20000000
CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS = 0x00000007
CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG = 0x00000008
CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG = 0x00000010
CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG = 0x00000040
CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG = 0x00000020
CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG = 0x00000080
CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS = 0x00000F00
CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG = 0x00008000
CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG = 0x00004000
AUTHTYPE_SERVER = 2
CERT_CHAIN_POLICY_SSL = 4
FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000
FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200

# Flags to set for SSLContext.verify_mode=CERT_NONE
CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS = (
CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS
| CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG
| CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG
| CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG
| CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG
| CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG
| CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS
| CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG
| CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG
)

wincrypt = WinDLL("crypt32.dll")
kernel32 = WinDLL("kernel32.dll")

Expand Down Expand Up @@ -341,6 +365,7 @@ def _verify_peercerts_impl(
# First attempt to verify using the default Windows system trust roots
# (default chain engine).
_get_and_verify_cert_chain(
ssl_context,
None,
hIntermediateCertStore,
pCertContext,
Expand All @@ -358,6 +383,7 @@ def _verify_peercerts_impl(
)
if custom_ca_certs:
_verify_using_custom_ca_certs(
ssl_context,
custom_ca_certs,
hIntermediateCertStore,
pCertContext,
Expand All @@ -374,6 +400,7 @@ def _verify_peercerts_impl(


def _get_and_verify_cert_chain(
ssl_context: ssl.SSLContext,
hChainEngine: HCERTCHAINENGINE | None,
hIntermediateCertStore: HCERTSTORE,
pPeerCertContext: c_void_p,
Expand Down Expand Up @@ -406,11 +433,17 @@ def _get_and_verify_cert_chain(
ssl_extra_cert_chain_policy_para.fdwChecks = 0
if server_hostname:
ssl_extra_cert_chain_policy_para.pwszServerName = c_wchar_p(server_hostname)

chain_policy = CERT_CHAIN_POLICY_PARA()
chain_policy.pvExtraPolicyPara = cast(
pointer(ssl_extra_cert_chain_policy_para), c_void_p
)
if ssl_context.verify_mode == ssl.CERT_NONE:
chain_policy.dwFlags |= CERT_CHAIN_POLICY_VERIFY_MODE_NONE_FLAGS
if not ssl_context.check_hostname:
chain_policy.dwFlags |= CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG
chain_policy.cbSize = sizeof(chain_policy)

pPolicyPara = pointer(chain_policy)
policy_status = CERT_CHAIN_POLICY_STATUS()
policy_status.cbSize = sizeof(policy_status)
Expand Down Expand Up @@ -456,6 +489,7 @@ def _get_and_verify_cert_chain(


def _verify_using_custom_ca_certs(
ssl_context: ssl.SSLContext,
custom_ca_certs: list[bytes],
hIntermediateCertStore: HCERTSTORE,
pPeerCertContext: c_void_p,
Expand Down Expand Up @@ -492,6 +526,7 @@ def _verify_using_custom_ca_certs(

# Get and verify a cert chain using the custom chain engine
_get_and_verify_cert_chain(
ssl_context,
hChainEngine,
hIntermediateCertStore,
pPeerCertContext,
Expand All @@ -505,6 +540,12 @@ def _verify_using_custom_ca_certs(
CertCloseStore(hRootCertStore, 0)


def _configure_context(ctx: ssl.SSLContext) -> None:
@contextlib.contextmanager
def _configure_context(ctx: ssl.SSLContext) -> typing.Iterator[None]:
values = ctx.check_hostname, ctx.verify_mode
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
try:
yield
finally:
ctx.check_hostname, ctx.verify_mode = values
22 changes: 20 additions & 2 deletions tests/test_sslcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import pytest
import urllib3
from urllib3.exceptions import InsecureRequestWarning
from urllib3.exceptions import InsecureRequestWarning, SSLError

import truststore


def test_minimum_maximum_version():
ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.maximum_version = ssl.TLSVersion.TLSv1_2

with urllib3.PoolManager(ssl_context=ctx) as http:

resp = http.request("GET", "https://howsmyssl.com/a/check")
Expand All @@ -24,11 +25,28 @@ def test_minimum_maximum_version():
assert ctx.maximum_version == ssl.TLSVersion.TLSv1_2


def test_disable_verification():
def test_check_hostname_false():
ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
assert ctx.check_hostname is True
assert ctx.verify_mode == ssl.CERT_REQUIRED

with urllib3.PoolManager(ssl_context=ctx, retries=False) as http:
with pytest.raises(SSLError) as e:
http.request("GET", "https://wrong.host.badssl.com/")
assert "match" in str(e.value)


def test_verify_mode_cert_none():
ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
assert ctx.check_hostname is True
assert ctx.verify_mode == ssl.CERT_REQUIRED

ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

assert ctx.check_hostname is False
assert ctx.verify_mode == ssl.CERT_NONE

with urllib3.PoolManager(ssl_context=ctx) as http, pytest.warns(
InsecureRequestWarning
) as w:
Expand Down

0 comments on commit 0d7f8a2

Please sign in to comment.