-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
309 additions
and
16 deletions.
There are no files selected for viewing
173 changes: 159 additions & 14 deletions
173
python/samples/getting_started/third_party/postgres-memory.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
1 change: 1 addition & 0 deletions
1
python/semantic_kernel/connectors/memory/azure_db_for_postgres/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Copyright (c) Microsoft. All rights reserved. |
57 changes: 57 additions & 0 deletions
57
...mantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_collection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
from typing import TypeVar | ||
|
||
from psycopg_pool import AsyncConnectionPool | ||
|
||
from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_settings import ( | ||
AzureDBForPostgresSettings, | ||
) | ||
from semantic_kernel.connectors.memory.postgres.constants import DEFAULT_SCHEMA | ||
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection | ||
from semantic_kernel.data.vector_store_model_definition import VectorStoreRecordDefinition | ||
|
||
TKey = TypeVar("TKey", str, int) | ||
TModel = TypeVar("TModel") | ||
|
||
|
||
class AzureDBForPostgresCollection(PostgresCollection[TKey, TModel]): | ||
"""AzureDBForPostgresCollection class.""" | ||
|
||
def __init__( | ||
self, | ||
collection_name: str, | ||
data_model_type: type[TModel], | ||
data_model_definition: VectorStoreRecordDefinition | None = None, | ||
connection_pool: AsyncConnectionPool | None = None, | ||
db_schema: str = DEFAULT_SCHEMA, | ||
env_file_path: str | None = None, | ||
env_file_encoding: str | None = None, | ||
settings: AzureDBForPostgresSettings | None = None, | ||
): | ||
"""Initialize the collection. | ||
Args: | ||
collection_name: The name of the collection, which corresponds to the table name. | ||
data_model_type (type[TModel]): The type of the data model. | ||
data_model_definition: The data model definition. | ||
connection_pool: The connection pool. | ||
db_schema: The database schema. | ||
env_file_path (str): Use the environment settings file as a fallback to environment variables. | ||
env_file_encoding (str): The encoding of the environment settings file. | ||
settings: The settings for the Azure DB for Postgres connection. If not provided, the settings will be | ||
created from the environment. | ||
""" | ||
# If the connection pool or settings were not provided, create the settings from the environment. | ||
# Passing this to the super class will enforce using Azure DB settings. | ||
if not connection_pool and not settings: | ||
settings = AzureDBForPostgresSettings.create( | ||
env_file_path=env_file_path, env_file_encoding=env_file_encoding | ||
) | ||
super().__init__( | ||
collection_name=collection_name, | ||
data_model_type=data_model_type, | ||
data_model_definition=data_model_definition, | ||
connection_pool=connection_pool, | ||
db_schema=db_schema, | ||
settings=settings, | ||
) |
43 changes: 43 additions & 0 deletions
43
...semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_settings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
import sys | ||
from typing import Any | ||
|
||
if sys.version_info >= (3, 12): | ||
from typing import override # pragma: no cover | ||
else: | ||
from typing_extensions import override # pragma: no cover | ||
|
||
from azure.core.credentials import TokenCredential | ||
from azure.core.credentials_async import AsyncTokenCredential | ||
from azure.identity import DefaultAzureCredential | ||
from psycopg.conninfo import conninfo_to_dict | ||
|
||
from semantic_kernel.connectors.memory.azure_db_for_postgres.utils import get_entra_token, get_entra_token_aysnc | ||
from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings | ||
|
||
|
||
class AzureDBForPostgresSettings(PostgresSettings): | ||
"""Azure DB for Postgres model settings. | ||
This is the same as PostgresSettings, but does not a require a password. | ||
If a password is not supplied, then Entra will use the Azure AD token. | ||
You can also supply an Azure credential directly. | ||
""" | ||
|
||
credential: AsyncTokenCredential | TokenCredential | None = None | ||
|
||
@override | ||
def get_connection_args(self, **kwargs) -> dict[str, Any]: | ||
"""Get connection arguments.""" | ||
password: Any = self.password.get_secret_value() if self.password else None | ||
if not password and self.connection_string: | ||
password = conninfo_to_dict(self.connection_string.get_secret_value()).get("password") | ||
|
||
if not password: | ||
self.credential = self.credential or DefaultAzureCredential() | ||
if isinstance(self.credential, AsyncTokenCredential): | ||
password = get_entra_token_aysnc(self.credential) | ||
else: | ||
password = get_entra_token(self.credential) | ||
|
||
return super().get_connection_args(password=password) |
9 changes: 9 additions & 0 deletions
9
...on/semantic_kernel/connectors/memory/azure_db_for_postgres/azure_db_for_postgres_store.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
from semantic_kernel.connectors.memory.postgres.postgres_store import PostgresStore | ||
|
||
|
||
class AzureDBForPostgresStore(PostgresStore): | ||
"""AzureDBForPostgresStore class.""" | ||
|
||
pass |
3 changes: 3 additions & 0 deletions
3
python/semantic_kernel/connectors/memory/azure_db_for_postgres/constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" |
25 changes: 25 additions & 0 deletions
25
python/semantic_kernel/connectors/memory/azure_db_for_postgres/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
import logging | ||
|
||
from azure.core.credentials import TokenCredential | ||
from azure.core.credentials_async import AsyncTokenCredential | ||
|
||
from semantic_kernel.connectors.memory.azure_db_for_postgres.constants import AZURE_DB_FOR_POSTGRES_SCOPE | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def get_entra_token_aysnc(credential: AsyncTokenCredential) -> str: | ||
"""Get the password from Entra using the provided credential.""" | ||
logger.info("Acquiring Entra token for postgres password") | ||
|
||
async with credential: | ||
cred = await credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE) | ||
return cred.token | ||
|
||
|
||
def get_entra_token(credential: TokenCredential) -> str: | ||
"""Get the password from Entra using the provided credential.""" | ||
logger.info("Acquiring Entra token for postgres password") | ||
|
||
return credential.get_token(AZURE_DB_FOR_POSTGRES_SCOPE).token |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters