Skip to content

Commit edf3372

Browse files
Don't import cryptography.x509
1 parent 32cb0b5 commit edf3372

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

tests/unit/mocks.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,19 @@
2323
import struct
2424
from typing import Any, Callable, Literal, Optional
2525

26+
from cryptography.x509 import (
27+
CertificateBuilder as x509_CertificateBuilder,
28+
NameAttribute as x509_NameAttribute,
29+
random_serial_number as x509_random_serial_number,
30+
SubjectAlternativeName as x509_SubjectAlternativeName,
31+
IPAddress as x509_IPAddress,
32+
DNSName as x509_DNSName,
33+
load_pem_x509_certificate as x509_load_pem_x509_certificate,
34+
Name as x509_Name,
35+
)
2636
from cryptography.hazmat.primitives import hashes
2737
from cryptography.hazmat.primitives import serialization
2838
from cryptography.hazmat.primitives.asymmetric import rsa
29-
import cryptography.x509 as x509
3039
from cryptography.x509.oid import NameOID
3140
from google.auth.credentials import _helpers
3241
from google.auth.credentials import TokenState
@@ -89,7 +98,7 @@ def token_state(
8998

9099
def generate_cert(
91100
common_name: str, expires_in: int = 60, server_cert: bool = False
92-
) -> tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]:
101+
) -> tuple[x509_CertificateBuilder, rsa.RSAPrivateKey]:
93102
"""
94103
Generate a private key and cert object to be used in testing.
95104
@@ -99,40 +108,40 @@ def generate_cert(
99108
server_cert (bool): Whether it is a server certificate.
100109
101110
Returns:
102-
tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]
111+
tuple[x509_CertificateBuilder, rsa.RSAPrivateKey]
103112
"""
104113
# generate private key
105114
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
106115
# calculate expiry time
107116
now = datetime.now(timezone.utc)
108117
expiration = now + timedelta(minutes=expires_in)
109118
# configure cert subject
110-
subject = issuer = x509.Name(
119+
subject = issuer = x509_Name(
111120
[
112-
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
113-
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
114-
x509.NameAttribute(NameOID.LOCALITY_NAME, "Mountain View"),
115-
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Google Inc"),
116-
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
121+
x509_NameAttribute(NameOID.COUNTRY_NAME, "US"),
122+
x509_NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
123+
x509_NameAttribute(NameOID.LOCALITY_NAME, "Mountain View"),
124+
x509_NameAttribute(NameOID.ORGANIZATION_NAME, "Google Inc"),
125+
x509_NameAttribute(NameOID.COMMON_NAME, common_name),
117126
]
118127
)
119128
# build cert
120129
cert = (
121-
x509.CertificateBuilder()
130+
x509_CertificateBuilder()
122131
.subject_name(subject)
123132
.issuer_name(issuer)
124133
.public_key(key.public_key())
125-
.serial_number(x509.random_serial_number())
134+
.serial_number(x509_random_serial_number())
126135
.not_valid_before(now)
127136
.not_valid_after(expiration)
128137
)
129138
if server_cert:
130139
cert = cert.add_extension(
131-
x509.SubjectAlternativeName(
140+
x509_SubjectAlternativeName(
132141
general_names=[
133-
x509.IPAddress(ipaddress.ip_address("127.0.0.1")),
134-
x509.IPAddress(ipaddress.ip_address("10.0.0.1")),
135-
x509.DNSName("x.y.alloydb.goog."),
142+
x509_IPAddress(ipaddress.ip_address("127.0.0.1")),
143+
x509_IPAddress(ipaddress.ip_address("10.0.0.1")),
144+
x509_DNSName("x.y.alloydb.goog."),
136145
]
137146
),
138147
critical=False,
@@ -206,11 +215,11 @@ def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]:
206215
)
207216
# build client cert
208217
client_cert = (
209-
x509.CertificateBuilder()
218+
x509_CertificateBuilder()
210219
.subject_name(self.intermediate_cert.subject)
211220
.issuer_name(self.intermediate_cert.issuer)
212221
.public_key(pub_key_bytes)
213-
.serial_number(x509.random_serial_number())
222+
.serial_number(x509_random_serial_number())
214223
.not_valid_before(self.cert_before)
215224
.not_valid_after(self.cert_expiry)
216225
)
@@ -253,11 +262,11 @@ async def _get_client_certificate(
253262
)
254263
# build client cert
255264
client_cert = (
256-
x509.CertificateBuilder()
265+
x509_CertificateBuilder()
257266
.subject_name(self.instance.intermediate_cert.subject)
258267
.issuer_name(self.instance.intermediate_cert.issuer)
259268
.public_key(pub_key_bytes)
260-
.serial_number(x509.random_serial_number())
269+
.serial_number(x509_random_serial_number())
261270
.not_valid_before(self.instance.cert_before)
262271
.not_valid_after(self.instance.cert_expiry)
263272
)
@@ -306,7 +315,7 @@ async def get_connection_info(
306315
# unpack certs
307316
ca_cert, cert_chain = certs
308317
# get expiration from client certificate
309-
cert_obj = x509.load_pem_x509_certificate(cert_chain[0].encode("UTF-8"))
318+
cert_obj = x509_load_pem_x509_certificate(cert_chain[0].encode("UTF-8"))
310319
expiration = cert_obj.not_valid_after_utc
311320

312321
return ConnectionInfo(

0 commit comments

Comments
 (0)