Skip to content

Commit 88f5bcd

Browse files
feat: format user argument for automatic IAM authn (GoogleCloudPlatform#449)
1 parent 9a4b251 commit 88f5bcd

File tree

5 files changed

+81
-2
lines changed

5 files changed

+81
-2
lines changed

google/cloud/sql/connector/connector.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import google.cloud.sql.connector.pg8000 as pg8000
2525
import google.cloud.sql.connector.pytds as pytds
2626
import google.cloud.sql.connector.asyncpg as asyncpg
27-
from google.cloud.sql.connector.utils import generate_keys
27+
from google.cloud.sql.connector.utils import generate_keys, format_database_user
2828
from google.cloud.sql.connector.exceptions import ConnectorLoopError
2929
from google.auth.credentials import Credentials
3030
from threading import Thread
@@ -235,6 +235,18 @@ async def connect_async(
235235
# helper function to wrap in timeout
236236
async def get_connection() -> Any:
237237
instance_data, ip_address = await instance.connect_info(ip_type)
238+
239+
# format `user` param for automatic IAM database authn
240+
if enable_iam_auth:
241+
formatted_user = format_database_user(
242+
instance_data.database_version, kwargs["user"]
243+
)
244+
if formatted_user != kwargs["user"]:
245+
logger.debug(
246+
f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
247+
)
248+
kwargs["user"] = formatted_user
249+
238250
# async drivers are unblocking and can be awaited directly
239251
if driver in ASYNC_DRIVERS:
240252
return await connector(ip_address, instance_data.context, **kwargs)

google/cloud/sql/connector/instance.py

+4
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,21 @@ class IPTypes(Enum):
6464
class InstanceMetadata:
6565
ip_addrs: Dict[str, Any]
6666
context: ssl.SSLContext
67+
database_version: str
6768
expiration: datetime.datetime
6869

6970
def __init__(
7071
self,
7172
ephemeral_cert: str,
73+
database_version: str,
7274
ip_addrs: Dict[str, Any],
7375
private_key: bytes,
7476
server_ca_cert: str,
7577
expiration: datetime.datetime,
7678
enable_iam_auth: bool,
7779
) -> None:
7880
self.ip_addrs = ip_addrs
81+
self.database_version = database_version
7982
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS)
8083

8184
# verify OpenSSL version supports TLSv1.3
@@ -370,6 +373,7 @@ async def _perform_refresh(self) -> InstanceMetadata:
370373

371374
return InstanceMetadata(
372375
ephemeral_cert,
376+
metadata["database_version"],
373377
metadata["ip_addresses"],
374378
priv_key,
375379
metadata["server_ca_cert"],

google/cloud/sql/connector/utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,28 @@ def write_to_file(
7777
priv_out.write(priv_key)
7878

7979
return (ca_filename, cert_filename, key_filename)
80+
81+
82+
def format_database_user(database_version: str, user: str) -> str:
83+
"""
84+
Format database `user` param for Cloud SQL automatic IAM authentication.
85+
86+
:type database_version: str
87+
:param database_version
88+
Cloud SQL database version. (i.e. POSTGRES_14, MYSQL8_0, etc.)
89+
90+
:type user: str
91+
:param user
92+
Database username to connect to Cloud SQL database with.
93+
"""
94+
# remove suffix for Postgres service accounts
95+
if database_version.startswith("POSTGRES"):
96+
suffix = ".gserviceaccount.com"
97+
user = user[: -len(suffix)] if user.endswith(suffix) else user
98+
return user
99+
100+
# remove everything after and including the @ for MySQL
101+
if database_version.startswith("MYSQL") and "@" in user:
102+
return user.split("@")[0]
103+
104+
return user

tests/system/test_connector_object.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import logging
2222
import google.auth
2323
from google.cloud.sql.connector import Connector
24-
from google.cloud.sql.connector.instance import AutoIAMAuthNotSupported
24+
from google.cloud.sql.connector.exceptions import (
25+
AutoIAMAuthNotSupported,
26+
)
2527
import datetime
2628
import concurrent.futures
2729
from threading import Thread

tests/unit/test_utils.py

+36
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,39 @@ async def test_generate_keys_returns_bytes_and_str() -> None:
3939

4040
res1, res2 = await utils.generate_keys()
4141
assert isinstance(res1, bytes) and (isinstance(res2, str))
42+
43+
44+
def test_format_database_user_postgres() -> None:
45+
"""
46+
Test that format_database_user properly formats Postgres IAM database users.
47+
"""
48+
service_account = utils.format_database_user(
49+
"POSTGRES_14", "[email protected]"
50+
)
51+
service_account2 = utils.format_database_user(
52+
"POSTGRES_14", "[email protected]"
53+
)
54+
assert service_account == "[email protected]"
55+
assert service_account2 == "[email protected]"
56+
user = utils.format_database_user("POSTGRES_14", "[email protected]")
57+
assert user == "[email protected]"
58+
59+
60+
def test_format_database_user_mysql() -> None:
61+
"""
62+
Test that format_database _user properly formats MySQL IAM database users.
63+
"""
64+
service_account = utils.format_database_user(
65+
"MYSQL_8_0", "[email protected]"
66+
)
67+
service_account2 = utils.format_database_user(
68+
"MYSQL_8_0", "[email protected]"
69+
)
70+
service_account3 = utils.format_database_user("MYSQL_8_0", "service-account")
71+
assert service_account == "service-account"
72+
assert service_account2 == "service-account"
73+
assert service_account3 == "service-account"
74+
user = utils.format_database_user("MYSQL_8_0", "[email protected]")
75+
user2 = utils.format_database_user("MYSQL_8_0", "test")
76+
assert user == "test"
77+
assert user2 == "test"

0 commit comments

Comments
 (0)