Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions openfeature/provider/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ def set_provider(self, domain: str, provider: FeatureProvider) -> None:
if domain in providers:
old_provider = providers[domain]
del providers[domain]
if old_provider not in providers.values():
if (
old_provider != self._default_provider
and old_provider not in providers.values()
):
self._shutdown_provider(old_provider)
if provider not in providers.values():
if provider != self._default_provider and provider not in providers.values():
self._initialize_provider(provider)
providers[domain] = provider

Expand All @@ -44,10 +47,15 @@ def get_provider(self, domain: str | None) -> FeatureProvider:
def set_default_provider(self, provider: FeatureProvider) -> None:
if provider is None:
raise GeneralError(error_message="No provider")
if self._default_provider:
if (
self._default_provider
and self._default_provider not in self._providers.values()
):
self._shutdown_provider(self._default_provider)
self._default_provider = provider
self._initialize_provider(provider)

if self._default_provider not in self._providers.values():
self._initialize_provider(provider)

def get_default_provider(self) -> FeatureProvider:
return self._default_provider
Expand Down Expand Up @@ -94,7 +102,7 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None:
try:
if hasattr(provider, "shutdown"):
provider.shutdown()
self._provider_status[provider] = ProviderStatus.NOT_READY
del self._provider_status[provider]
except Exception as err:
self.dispatch_event(
provider,
Expand Down
123 changes: 123 additions & 0 deletions tests/provider/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from openfeature.exception import GeneralError
from openfeature.provider import ProviderStatus
from openfeature.provider._registry import ProviderRegistry
from openfeature.provider.no_op_provider import NoOpProvider

Expand Down Expand Up @@ -105,3 +106,125 @@ def test_setting_default_provider_initializes_it():
registry.set_default_provider(provider)

provider.initialize.assert_called_once()


def test_registering_provider_as_default_then_domain_only_initializes_once():
"""Test that registering the same provider as default and for a domain only initializes it once."""

registry = ProviderRegistry()
provider = Mock()

registry.set_default_provider(provider)
registry.set_provider("domain", provider)

provider.initialize.assert_called_once()


def test_registering_provider_as_domain_then_default_only_initializes_once():
"""Test that registering the same provider as default and for a domain only initializes it once."""

registry = ProviderRegistry()
provider = Mock()

registry.set_provider("domain", provider)
registry.set_default_provider(provider)

provider.initialize.assert_called_once()


def test_replacing_provider_used_as_default_does_not_shutdown():
"""Test that replacing a provider that is also the default does not shut it down twice."""

registry = ProviderRegistry()
provider1 = Mock()
provider2 = Mock()

registry.set_default_provider(provider1)
registry.set_provider("domain", provider1)

registry.set_provider("domain", provider2)

provider1.shutdown.assert_not_called()
provider2.shutdown.assert_not_called()


def test_replacing_default_provider_used_as_domain_does_not_shutdown():
"""Test that replacing a default provider that is also used for a domain does not shut it down twice."""

registry = ProviderRegistry()
provider1 = Mock()
provider2 = Mock()

registry.set_provider("domain", provider1)
registry.set_default_provider(provider1)

registry.set_default_provider(provider2)

provider1.shutdown.assert_not_called()
provider2.shutdown.assert_not_called()


def test_shutting_down_registry_shuts_down_providers_once():
"""Test that shutting down the registry shuts down each provider only once."""

registry = ProviderRegistry()
provider1 = Mock()
provider2 = Mock()

registry.set_default_provider(provider1)
registry.set_provider("domain1", provider1)

registry.set_provider("domain2a", provider2)
registry.set_provider("domain2b", provider2)

registry.shutdown()

provider1.shutdown.assert_called_once()
provider2.shutdown.assert_called_once()


def test_initializing_provider_sets_status_ready():
"""Test that initializing a provider sets its status to READY."""

registry = ProviderRegistry()
provider = Mock()

assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY

registry.set_provider("domain", provider)

provider.initialize.assert_called_once()
assert registry.get_provider_status(provider) == ProviderStatus.READY


def test_shutting_down_provider_sets_status_not_ready():
"""Test that shutting down a provider sets its status to NOT_READY."""

registry = ProviderRegistry()
provider = Mock()

registry.set_provider("domain", provider)
assert registry.get_provider_status(provider) == ProviderStatus.READY

registry.shutdown()
assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY


def test_clearing_registry_resets_providers_and_default():
"""Test that clearing the registry resets all providers and the default provider."""

registry = ProviderRegistry()
provider = Mock()

registry.set_provider("domain", provider)
registry.set_default_provider(provider)

registry.clear_providers()

default_provider = registry.get_default_provider()
assert isinstance(default_provider, NoOpProvider)
assert registry.get_provider("domain") is default_provider
assert registry.get_provider_status(default_provider) == ProviderStatus.READY

provider.initialize.assert_called_once()
provider.shutdown.assert_called_once()