diff --git a/python/tests/integration/memory/vector_stores/postgres/test_postgres_int.py b/python/tests/integration/memory/vector_stores/postgres/test_postgres_int.py index fb280e047a39..49078b92e6b2 100644 --- a/python/tests/integration/memory/vector_stores/postgres/test_postgres_int.py +++ b/python/tests/integration/memory/vector_stores/postgres/test_postgres_int.py @@ -10,6 +10,10 @@ import pytest_asyncio from pydantic import BaseModel +from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_settings import ( + AzureDBForPostgresSettings, +) +from semantic_kernel.connectors.memory.azure_db_for_postgres.azure_db_for_postgres_store import AzureDBForPostgresStore from semantic_kernel.connectors.memory.postgres import PostgresSettings, PostgresStore from semantic_kernel.data import ( DistanceFunction, @@ -40,8 +44,8 @@ connection_params_present = False pytestmark = pytest.mark.skipif( - not (psycopg_pool_installed or connection_params_present), - reason="psycopg_pool is not installed" if not psycopg_pool_installed else "No connection parameters provided", + not psycopg_pool_installed, + reason="psycopg_pool is not installed", ) @@ -85,15 +89,33 @@ def DataModelPandas(record) -> tuple: return definition, df -@pytest_asyncio.fixture -async def vector_store() -> AsyncGenerator[PostgresStore, None]: +@pytest_asyncio.fixture( + # Parametrize over all Postgres stores. + params=["PostgresStore", "AzureDBForPostgresStore"] +) +async def vector_store(request) -> AsyncGenerator[PostgresStore, None]: + store_type = request.param + if store_type == "PostgresStore": + settings = PostgresSettings.create() + elif store_type == "AzureDBForPostgresStore": + settings = AzureDBForPostgresSettings.create() + + try: + connection_params_present = any(settings.get_connection_args().values()) + except MemoryConnectorInitializationError: + connection_params_present = False + + if not connection_params_present: + pytest.skip(f"No connection parameters provided for {store_type}") + try: - async with await pg_settings.create_connection_pool() as pool: - yield PostgresStore(connection_pool=pool) + async with await settings.create_connection_pool() as pool: + if store_type == "PostgresStore": + yield PostgresStore(connection_pool=pool) + elif store_type == "AzureDBForPostgresStore": + yield AzureDBForPostgresStore(connection_pool=pool) except MemoryConnectorConnectionException: - pytest.skip("Postgres connection not available") - yield None - return + pytest.skip(f"{store_type} connection not available") @asynccontextmanager