23
23
import struct
24
24
from typing import Any , Callable , Literal , Optional
25
25
26
+ from cryptography .x509 import (
27
+ CertificateBuilder as x509_CertificateBuilder ,
28
+ NameAttribute as x509_NameAttribute ,
29
+ random_serial_number as x509_random_serial_number ,
30
+ SubjectAlternativeName as x509_SubjectAlternativeName ,
31
+ IPAddress as x509_IPAddress ,
32
+ DNSName as x509_DNSName ,
33
+ load_pem_x509_certificate as x509_load_pem_x509_certificate ,
34
+ Name as x509_Name ,
35
+ )
26
36
from cryptography .hazmat .primitives import hashes
27
37
from cryptography .hazmat .primitives import serialization
28
38
from cryptography .hazmat .primitives .asymmetric import rsa
29
- import cryptography .x509 as x509
30
39
from cryptography .x509 .oid import NameOID
31
40
from google .auth .credentials import _helpers
32
41
from google .auth .credentials import TokenState
@@ -89,7 +98,7 @@ def token_state(
89
98
90
99
def generate_cert (
91
100
common_name : str , expires_in : int = 60 , server_cert : bool = False
92
- ) -> tuple [x509 . CertificateBuilder , rsa .RSAPrivateKey ]:
101
+ ) -> tuple [x509_CertificateBuilder , rsa .RSAPrivateKey ]:
93
102
"""
94
103
Generate a private key and cert object to be used in testing.
95
104
@@ -99,40 +108,40 @@ def generate_cert(
99
108
server_cert (bool): Whether it is a server certificate.
100
109
101
110
Returns:
102
- tuple[x509.CertificateBuilder , rsa.RSAPrivateKey]
111
+ tuple[x509_CertificateBuilder , rsa.RSAPrivateKey]
103
112
"""
104
113
# generate private key
105
114
key = rsa .generate_private_key (public_exponent = 65537 , key_size = 2048 )
106
115
# calculate expiry time
107
116
now = datetime .now (timezone .utc )
108
117
expiration = now + timedelta (minutes = expires_in )
109
118
# configure cert subject
110
- subject = issuer = x509 . Name (
119
+ subject = issuer = x509_Name (
111
120
[
112
- x509 . NameAttribute (NameOID .COUNTRY_NAME , "US" ),
113
- x509 . NameAttribute (NameOID .STATE_OR_PROVINCE_NAME , "California" ),
114
- x509 . NameAttribute (NameOID .LOCALITY_NAME , "Mountain View" ),
115
- x509 . NameAttribute (NameOID .ORGANIZATION_NAME , "Google Inc" ),
116
- x509 . NameAttribute (NameOID .COMMON_NAME , common_name ),
121
+ x509_NameAttribute (NameOID .COUNTRY_NAME , "US" ),
122
+ x509_NameAttribute (NameOID .STATE_OR_PROVINCE_NAME , "California" ),
123
+ x509_NameAttribute (NameOID .LOCALITY_NAME , "Mountain View" ),
124
+ x509_NameAttribute (NameOID .ORGANIZATION_NAME , "Google Inc" ),
125
+ x509_NameAttribute (NameOID .COMMON_NAME , common_name ),
117
126
]
118
127
)
119
128
# build cert
120
129
cert = (
121
- x509 . CertificateBuilder ()
130
+ x509_CertificateBuilder ()
122
131
.subject_name (subject )
123
132
.issuer_name (issuer )
124
133
.public_key (key .public_key ())
125
- .serial_number (x509 . random_serial_number ())
134
+ .serial_number (x509_random_serial_number ())
126
135
.not_valid_before (now )
127
136
.not_valid_after (expiration )
128
137
)
129
138
if server_cert :
130
139
cert = cert .add_extension (
131
- x509 . SubjectAlternativeName (
140
+ x509_SubjectAlternativeName (
132
141
general_names = [
133
- x509 . IPAddress (ipaddress .ip_address ("127.0.0.1" )),
134
- x509 . IPAddress (ipaddress .ip_address ("10.0.0.1" )),
135
- x509 . DNSName ("x.y.alloydb.goog." ),
142
+ x509_IPAddress (ipaddress .ip_address ("127.0.0.1" )),
143
+ x509_IPAddress (ipaddress .ip_address ("10.0.0.1" )),
144
+ x509_DNSName ("x.y.alloydb.goog." ),
136
145
]
137
146
),
138
147
critical = False ,
@@ -206,11 +215,11 @@ def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]:
206
215
)
207
216
# build client cert
208
217
client_cert = (
209
- x509 . CertificateBuilder ()
218
+ x509_CertificateBuilder ()
210
219
.subject_name (self .intermediate_cert .subject )
211
220
.issuer_name (self .intermediate_cert .issuer )
212
221
.public_key (pub_key_bytes )
213
- .serial_number (x509 . random_serial_number ())
222
+ .serial_number (x509_random_serial_number ())
214
223
.not_valid_before (self .cert_before )
215
224
.not_valid_after (self .cert_expiry )
216
225
)
@@ -253,11 +262,11 @@ async def _get_client_certificate(
253
262
)
254
263
# build client cert
255
264
client_cert = (
256
- x509 . CertificateBuilder ()
265
+ x509_CertificateBuilder ()
257
266
.subject_name (self .instance .intermediate_cert .subject )
258
267
.issuer_name (self .instance .intermediate_cert .issuer )
259
268
.public_key (pub_key_bytes )
260
- .serial_number (x509 . random_serial_number ())
269
+ .serial_number (x509_random_serial_number ())
261
270
.not_valid_before (self .instance .cert_before )
262
271
.not_valid_after (self .instance .cert_expiry )
263
272
)
@@ -306,7 +315,7 @@ async def get_connection_info(
306
315
# unpack certs
307
316
ca_cert , cert_chain = certs
308
317
# get expiration from client certificate
309
- cert_obj = x509 . load_pem_x509_certificate (cert_chain [0 ].encode ("UTF-8" ))
318
+ cert_obj = x509_load_pem_x509_certificate (cert_chain [0 ].encode ("UTF-8" ))
310
319
expiration = cert_obj .not_valid_after_utc
311
320
312
321
return ConnectionInfo (
0 commit comments