Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Replace aiohttp.ClientSession with AlloyDBAdminAsyncClient #416

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
9 changes: 4 additions & 5 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
from google.cloud.alloydb.connector.types import CacheTypes
from google.cloud.alloydb.connector.utils import generate_keys
from google.cloud.alloydb.connector.utils import strip_http_prefix

if TYPE_CHECKING:
from google.auth.credentials import Credentials
Expand All @@ -51,7 +52,7 @@ class AsyncConnector:
billing purposes.
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
Expand All @@ -66,7 +67,7 @@ def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
alloydb_api_endpoint: str = "alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
Expand All @@ -75,7 +76,7 @@ def __init__(
self._cache: dict[str, CacheTypes] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint)
self._enable_iam_auth = enable_iam_auth
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
Expand Down Expand Up @@ -235,5 +236,3 @@ async def close(self) -> None:
"""Helper function to cancel RefreshAheadCaches' tasks
and close client."""
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
if self._client:
await self._client.close()
109 changes: 40 additions & 69 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
import logging
from typing import Optional, TYPE_CHECKING

import aiohttp
from cryptography import x509
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import TokenState
from google.auth.transport import requests
import google.cloud.alloydb_v1beta as v1beta
from google.protobuf import duration_pb2

from google.cloud.alloydb.connector.connection_info import ConnectionInfo
from google.cloud.alloydb.connector.version import __version__ as version
Expand Down Expand Up @@ -55,7 +58,7 @@ def __init__(
alloydb_api_endpoint: str,
quota_project: Optional[str],
credentials: Credentials,
client: Optional[aiohttp.ClientSession] = None,
client: Optional[v1beta.AlloyDBAdminAsyncClient] = None,
driver: Optional[str] = None,
user_agent: Optional[str] = None,
) -> None:
Expand All @@ -72,23 +75,28 @@ def __init__(
A credentials object created from the google-auth Python library.
Must have the AlloyDB Admin scopes. For more info check out
https://google-auth.readthedocs.io/en/latest/.
client (aiohttp.ClientSession): Async client used to make requests to
AlloyDB APIs.
client (v1beta.AlloyDBAdminAsyncClient): Async client used to make
requests to AlloyDB APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
"""
user_agent = _format_user_agent(driver, user_agent)
headers = {
"x-goog-api-client": user_agent,
"User-Agent": user_agent,
"Content-Type": "application/json",
}
if quota_project:
headers["x-goog-user-project"] = quota_project

self._client = client if client else aiohttp.ClientSession(headers=headers)
self._client = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this still need to be lazy initialized? This was required for aiohttp client to make sure an async event loop was present during initialization. Is the same still needed for AlloyDBAdminAsyncClient?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by "lazy initialized"? How is self._client being lazy initialized?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rhatgadkar-goog In the Connector initialization, the client is first set to None and lazy initialized during the first connect call. Is this still needed for the AlloyDBAdminAsyncClient?

self._client: Optional[AlloyDBClient] = None

If it is not then you can improve performance of the first connect call by properly initializing client in Connector init.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe the AlloyDBAdminAsyncClient needs to be initialized in an async context. The code snippet here shows that the constructor to AlloyDBAdminAsyncClient can be called without await.

Why do you think that aiohttp.ClientSession needs to be initialized in an async context? The ClientSession constructor is called without await here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The aiohttp.ClientSession attaches itself to the current event loop, so it needs to be initialized after the Connector's async entrypoint (connect_async) has been called. Otherwise the client will be attached to a different event loop than the one running in the background thread used for refreshes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Yes, the aiohttp.ClientSession constructor is using the current event loop here.

But I don't see anything similar to this in the AlloyDBAdminAsyncClient constructor.

So I'll initialize the client in the Connector init.

Copy link
Collaborator Author

@rhatgadkar-goog rhatgadkar-goog Mar 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need access to the driver. And this is being passed in the connect() function.

So we can't initialize the AlloyDBClient in the connector's constructor unless the driver is passed as an argument to the connector's __init__() function.

We could move the driver as an argument of the connector's __init__() function. But this would be a breaking change, right? Because when a customer installs the latest AlloyDB Python connector, their client programs will break, because the connect() function won't take driver as an argument anymore. And they will need to pass driver into the __init__() function now.

Is it alright to move driver into the __init__() function? Do we need to somehow communicate this change to customers?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do NOT want to move driver into the Connector.__init__. The whole point of the Connector class is to share information that is not driver or instance specific. This is more relevant in Cloud SQL as we support more drivers, but for instance, a single connector = Connector() can be used to connect to MySQL, Postgres, and SQL Server drivers via connector.connect(). The same should be true for AlloyDB so that when more drivers are supported in the future, they do not require a new Connector

Copy link
Collaborator Author

@rhatgadkar-goog rhatgadkar-goog Mar 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see. For the same argument of not having driver in Connector.__init__, can we also remove driver from AlloyDBClient.__init__? The AlloyDBClient isn't instance or driver specific either. To do this, we'll need to make the following changes:

What do you think about this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aiohttp.ClientSession has a way to pass in the user agent in the method argument. So we can do something like:

session = aiohttp.ClientSession()
headers = {"User-Agent": 'hello'}
response = await session.get(url, headers=headers)

Need to check if something similar to this is possible with AlloyDBAdminAsyncClient

client
if client
else v1beta.AlloyDBAdminAsyncClient(
credentials=credentials,
client_options=ClientOptions(
api_endpoint=alloydb_api_endpoint,
quota_project_id=quota_project,
),
client_info=ClientInfo(
user_agent=user_agent,
),
)
)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
# asyncpg does not currently support using metadata exchange
# only use metadata exchange for pg8000 driver
self._use_metadata = True if driver == "pg8000" else False
Expand Down Expand Up @@ -118,35 +126,21 @@ async def _get_metadata(
Returns:
dict: IP addresses of the AlloyDB instance.
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
parent = (
f"projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}"
)

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo"

resp = await self._client.get(url, headers=headers)
# try to get response json for better error message
try:
resp_dict = await resp.json()
if resp.status >= 400:
# if detailed error message is in json response, use as error message
message = resp_dict.get("error", {}).get("message")
if message:
resp.reason = message
# skip, raise_for_status will catch all errors in finally block
except Exception:
pass
finally:
resp.raise_for_status()
req = v1beta.GetConnectionInfoRequest(parent=parent)
resp = await self._client.get_connection_info(request=req)

# Remove trailing period from PSC DNS name.
psc_dns = resp_dict.get("pscDnsName")
psc_dns = resp.psc_dns_name
if psc_dns:
psc_dns = psc_dns.rstrip(".")

return {
"PRIVATE": resp_dict.get("ipAddress"),
"PUBLIC": resp_dict.get("publicIpAddress"),
"PRIVATE": resp.ip_address,
"PUBLIC": resp.public_ip_address,
"PSC": psc_dns,
}

Expand Down Expand Up @@ -175,34 +169,17 @@ async def _get_client_certificate(
tuple[str, list[str]]: tuple containing the CA certificate
and certificate chain for the AlloyDB instance.
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"

data = {
"publicKey": pub_key,
"certDuration": "3600s",
"useMetadataExchange": self._use_metadata,
}

resp = await self._client.post(url, headers=headers, json=data)
# try to get response json for better error message
try:
resp_dict = await resp.json()
if resp.status >= 400:
# if detailed error message is in json response, use as error message
message = resp_dict.get("error", {}).get("message")
if message:
resp.reason = message
# skip, raise_for_status will catch all errors in finally block
except Exception:
pass
finally:
resp.raise_for_status()

return (resp_dict["caCert"], resp_dict["pemCertificateChain"])
parent = f"projects/{project}/locations/{region}/clusters/{cluster}"
dur = duration_pb2.Duration()
dur.seconds = 3600
req = v1beta.GenerateClientCertificateRequest(
parent=parent,
cert_duration=dur,
public_key=pub_key,
use_metadata_exchange=self._use_metadata,
)
resp = await self._client.generate_client_certificate(request=req)
return (resp.ca_cert, resp.pem_certificate_chain)

async def get_connection_info(
self,
Expand Down Expand Up @@ -267,9 +244,3 @@ async def get_connection_info(
ip_addrs,
expiration,
)

async def close(self) -> None:
"""Close AlloyDBClient gracefully."""
logger.debug("Waiting for connector's http client to close")
await self._client.close()
logger.debug("Closed connector's http client")
9 changes: 4 additions & 5 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.types import CacheTypes
from google.cloud.alloydb.connector.utils import generate_keys
from google.cloud.alloydb.connector.utils import strip_http_prefix
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

if TYPE_CHECKING:
Expand Down Expand Up @@ -64,7 +65,7 @@ class Connector:
billing purposes.
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
Expand All @@ -85,7 +86,7 @@ def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
alloydb_api_endpoint: str = "alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
Expand All @@ -99,7 +100,7 @@ def __init__(
self._cache: dict[str, CacheTypes] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint)
self._enable_iam_auth = enable_iam_auth
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
Expand Down Expand Up @@ -392,5 +393,3 @@ async def close_async(self) -> None:
"""Helper function to cancel RefreshAheadCaches' tasks
and close client."""
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
if self._client:
await self._client.close()
12 changes: 12 additions & 0 deletions google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

import re

import aiofiles
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
Expand Down Expand Up @@ -58,3 +60,13 @@ async def generate_keys() -> tuple[rsa.RSAPrivateKey, str]:
.decode("UTF-8")
)
return (priv_key, pub_key)


def strip_http_prefix(url: str) -> str:
    """Return ``url`` with a leading 'http://' or 'https://' removed.

    The scheme is matched case-insensitively (URL schemes are not case
    sensitive), so 'HTTPS://host' is handled the same as 'https://host'.
    A value without either prefix is returned unchanged, and a bare scheme
    such as 'https://' yields an empty string.

    Args:
        url (str): Endpoint that may start with an HTTP(S) scheme.

    Returns:
        str: The endpoint without its HTTP(S) scheme prefix.
    """
    # Anchored at the start so only a leading scheme is removed; any later
    # occurrence of "http://" inside the value is left untouched.
    return re.sub(r"^https?://", "", url, flags=re.IGNORECASE)
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ ignore_missing_imports = True

[mypy-asyncpg]
ignore_missing_imports = True

[mypy-google.cloud.alloydb_v1beta]
ignore_missing_imports = True

[mypy-google.api_core.*]
ignore_missing_imports = True
38 changes: 34 additions & 4 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud import alloydb_v1beta
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

Expand Down Expand Up @@ -232,7 +233,6 @@ def __init__(
self, instance: Optional[FakeInstance] = None, driver: str = "pg8000"
) -> None:
self.instance = FakeInstance() if instance is None else instance
self.closed = False
self._user_agent = f"test-user-agent+{driver}"
self._credentials = FakeCredentials()

Expand Down Expand Up @@ -317,9 +317,6 @@ async def get_connection_info(
expiration,
)

async def close(self) -> None:
self.closed = True


def metadata_exchange(sock: ssl.SSLSocket) -> None:
"""
Expand Down Expand Up @@ -448,3 +445,36 @@ def write_static_info(i: FakeInstance) -> io.StringIO:
"pscInstanceConfig": {"pscDnsName": i.ip_addrs["PSC"]},
}
return io.StringIO(json.dumps(static))


class FakeAlloyDBAdminAsyncClient:
    """In-memory stand-in for the AlloyDB Admin async client used in tests.

    Responses are canned; the only request data consulted is the instance
    name at the end of ``request.parent`` in ``get_connection_info``.
    """

    async def get_connection_info(
        self, request: alloydb_v1beta.GetConnectionInfoRequest
    ) -> alloydb_v1beta.types.resources.ConnectionInfo:
        """Return canned connection info chosen by the instance name.

        - "test-instance": private IP only.
        - "public-instance": private and public IP, no PSC DNS name.
        - anything else: PSC DNS name only.
        """
        # The instance name is the last path segment of the parent resource.
        instance_name = request.parent.split("/")[-1]

        private_ip = "10.0.0.1"
        public_ip = "127.0.0.1"
        psc_dns_name = "x.y.alloydb.goog"
        if instance_name == "test-instance":
            public_ip = ""
            psc_dns_name = ""
        elif instance_name == "public-instance":
            psc_dns_name = ""
        else:
            private_ip = ""
            public_ip = ""

        info = alloydb_v1beta.types.resources.ConnectionInfo()
        info.ip_address = private_ip
        info.public_ip_address = public_ip
        info.instance_uid = "123456789"
        info.psc_dns_name = psc_dns_name
        return info

    async def generate_client_certificate(
        self, request: alloydb_v1beta.GenerateClientCertificateRequest
    ) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
        """Return a fixed CA cert and a three-entry certificate chain."""
        response = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
        response.ca_cert = "This is the CA cert"
        # Order matters: client cert first, then intermediate, then root.
        for pem in (
            "This is the client cert",
            "This is the intermediate cert",
            "This is the root cert",
        ):
            response.pem_certificate_chain.append(pem)
        return response
Loading