Skip to content

Commit

Permalink
Add AzureDBForPostgres connector
Browse files Browse the repository at this point in the history
  • Loading branch information
lossyrob committed Oct 8, 2024
1 parent a4a3dd0 commit bec6917
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 16 deletions.
173 changes: 159 additions & 14 deletions python/samples/getting_started/third_party/postgres-memory.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import TypeVar

from psycopg_pool import AsyncConnectionPool

from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_settings import (
AzureDBForPostgresSettings,
)
from semantic_kernel.connectors.memory.postgres.constants import DEFAULT_SCHEMA
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
from semantic_kernel.data.vector_store_model_definition import VectorStoreRecordDefinition

# Type parameter for record keys; constrained to str or int.
TKey = TypeVar("TKey", str, int)
# Type parameter for the data model handled by the collection; unconstrained.
TModel = TypeVar("TModel")


class AzureDBForPostgresCollection(PostgresCollection[TKey, TModel]):
    """PostgresCollection variant that defaults to Azure DB for Postgres settings."""

    def __init__(
        self,
        collection_name: str,
        data_model_type: type[TModel],
        data_model_definition: VectorStoreRecordDefinition | None = None,
        connection_pool: AsyncConnectionPool | None = None,
        db_schema: str = DEFAULT_SCHEMA,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        settings: AzureDBForPostgresSettings | None = None,
    ):
        """Set up the collection on top of PostgresCollection.

        Args:
            collection_name: Name of the collection; this is also the table name.
            data_model_type: The type of the data model.
            data_model_definition: Optional explicit data model definition.
            connection_pool: Optional connection pool to use.
            db_schema: The database schema.
            env_file_path: Environment settings file, used as a fallback to environment variables.
            env_file_encoding: Encoding of the environment settings file.
            settings: Azure DB for Postgres settings. Built from the environment when
                neither a connection pool nor settings are given.
        """
        # Only when the caller supplied neither a pool nor settings do we build the
        # settings here; handing an AzureDBForPostgresSettings instance to the base
        # class makes it use the Azure DB settings rather than its own default.
        nothing_supplied = not (connection_pool or settings)
        if nothing_supplied:
            settings = AzureDBForPostgresSettings.create(
                env_file_path=env_file_path,
                env_file_encoding=env_file_encoding,
            )

        super().__init__(
            collection_name=collection_name,
            data_model_type=data_model_type,
            data_model_definition=data_model_definition,
            connection_pool=connection_pool,
            db_schema=db_schema,
            settings=settings,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from typing import Any

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover

from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity import DefaultAzureCredential
from psycopg.conninfo import conninfo_to_dict

from semantic_kernel.connectors.memory.azure_db_for_postgres.utils import get_entra_token, get_entra_token_aysnc
from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings


class AzureDBForPostgresSettings(PostgresSettings):
    """Azure DB for Postgres model settings.

    This is the same as PostgresSettings, but does not require a password.
    If a password is not supplied, an Entra token is acquired and used as the password.
    You can also supply an Azure credential directly; otherwise DefaultAzureCredential is used.
    """

    credential: AsyncTokenCredential | TokenCredential | None = None

    @override
    def get_connection_args(self, **kwargs) -> dict[str, Any]:
        """Get connection arguments, filling in the password from Entra when none is configured.

        The password is resolved in this order: an explicit ``password`` keyword argument,
        the ``password`` setting, the password embedded in ``connection_string``, and finally
        an Entra token obtained with ``credential`` (or ``DefaultAzureCredential``).

        Args:
            kwargs: Additional connection arguments. They are forwarded to
                ``PostgresSettings.get_connection_args`` and override the configured values.

        Returns:
            dict[str, Any]: Connection arguments as produced by the parent class, with
                ``password`` set to the resolved value.
        """
        # A password passed explicitly by the caller takes precedence over everything else.
        password: Any = kwargs.pop("password", None)
        if not password and self.password:
            password = self.password.get_secret_value()
        if not password and self.connection_string:
            password = conninfo_to_dict(self.connection_string.get_secret_value()).get("password")

        if not password:
            self.credential = self.credential or DefaultAzureCredential()
            if isinstance(self.credential, AsyncTokenCredential):
                # get_entra_token_aysnc is a coroutine function; calling it without driving
                # the coroutine would hand a coroutine object to the driver as the password.
                import asyncio
                from concurrent.futures import ThreadPoolExecutor

                token_coro = get_entra_token_aysnc(self.credential)
                try:
                    asyncio.get_running_loop()
                except RuntimeError:
                    # No event loop is running in this thread, so run the coroutine directly.
                    password = asyncio.run(token_coro)
                else:
                    # asyncio.run cannot be nested inside a running loop; resolve the token
                    # on a short-lived worker thread that owns its own event loop.
                    with ThreadPoolExecutor(max_workers=1) as executor:
                        password = executor.submit(asyncio.run, token_coro).result()
            else:
                password = get_entra_token(self.credential)

        # Forward the caller's overrides as well; previously they were silently dropped.
        return super().get_connection_args(**kwargs, password=password)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.

from semantic_kernel.connectors.memory.postgres.postgres_store import PostgresStore


class AzureDBForPostgresStore(PostgresStore):
    """Store for Azure DB for Postgres; adds nothing on top of PostgresStore."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft. All rights reserved.

# Scope passed to the credential's get_token() when acquiring the Entra token
# that is used as the postgres password.
AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default"
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft. All rights reserved.
import logging

from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

from semantic_kernel.connectors.memory.azure_db_for_postgres.constants import AZURE_DB_FOR_POSTGRES_SCOPE

logger = logging.getLogger(__name__)


async def get_entra_token_aysnc(credential: AsyncTokenCredential) -> str:
    """Fetch an Entra token to use as the postgres password (async credential).

    The credential is entered as an async context manager for the duration of the
    token request and exited afterwards.

    Args:
        credential: Async Azure credential used to request the token.

    Returns:
        The token string for the Azure DB for Postgres scope.
    """
    logger.info("Acquiring Entra token for postgres password")

    async with credential:
        return (await credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE)).token


def get_entra_token(credential: TokenCredential) -> str:
    """Fetch an Entra token to use as the postgres password (sync credential).

    Args:
        credential: Azure credential used to request the token.

    Returns:
        The token string for the Azure DB for Postgres scope.
    """
    logger.info("Acquiring Entra token for postgres password")

    access_token = credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE)
    return access_token.token
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,16 @@ class PostgresSettings(KernelBaseSettings):
default_dimensionality: int = 100
max_rows_per_transaction: int = 1000

def get_connection_args(self) -> dict[str, Any]:
"""Get connection arguments."""
def get_connection_args(self, **kwargs) -> dict[str, Any]:
"""Get connection arguments.
Args:
kwargs: dict[str, Any] - Additional arguments
Use this to override any connection arguments.
Returns:
dict[str, Any]: Connection arguments that can be passed to psycopg.connect
"""
result = conninfo_to_dict(self.connection_string.get_secret_value()) if self.connection_string else {}

if self.host:
Expand All @@ -86,6 +94,8 @@ def get_connection_args(self) -> dict[str, Any]:
if self.password:
result["password"] = self.password.get_secret_value()

result = {**result, **kwargs}

# Ensure required values
if "host" not in result:
raise MemoryConnectorInitializationError("host is required. Please set PGHOST or connection_string.")
Expand Down

0 comments on commit bec6917

Please sign in to comment.