Skip to content

feat: limitless plugin implementation #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions aws_advanced_python_wrapper/connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.host_selector import (
HostSelector, RandomHostSelector, RoundRobinHostSelector,
WeightedRandomHostSelector)
HighestWeightHostSelector, HostSelector, RandomHostSelector,
RoundRobinHostSelector, WeightedRandomHostSelector)
from aws_advanced_python_wrapper.plugin import CanReleaseResources
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
Expand Down Expand Up @@ -98,7 +98,8 @@ def connect(
class DriverConnectionProvider(ConnectionProvider):
_accepted_strategies: Dict[str, HostSelector] = {"random": RandomHostSelector(),
"round_robin": RoundRobinHostSelector(),
"weighted_random": WeightedRandomHostSelector()}
"weighted_random": WeightedRandomHostSelector(),
"highest_weight": HighestWeightHostSelector()}

def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
return True
Expand Down
18 changes: 17 additions & 1 deletion aws_advanced_python_wrapper/database_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Protocol, Tuple, runtime_checkable)

from aws_advanced_python_wrapper.driver_info import DriverInfo
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType

if TYPE_CHECKING:
from aws_advanced_python_wrapper.pep249 import Connection
Expand Down Expand Up @@ -98,6 +99,15 @@ def is_reader_query(self) -> str:
return self._IS_READER_QUERY


@runtime_checkable
class AuroraLimitlessDialect(Protocol):
_LIMITLESS_ROUTER_ENDPOINT_QUERY: str

@property
def limitless_router_endpoint_query(self) -> str:
return self._LIMITLESS_ROUTER_ENDPOINT_QUERY


class DatabaseDialect(Protocol):
"""
Database dialects help the AWS Advanced Python Driver determine what kind of underlying database is being used,
Expand Down Expand Up @@ -342,7 +352,7 @@ def get_host_list_provider_supplier(self) -> Callable:
return lambda provider_service, props: RdsHostListProvider(provider_service, props)


class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLimitlessDialect):
_DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (DialectCode.MULTI_AZ_PG,)

_EXTENSIONS_QUERY = "SELECT (setting LIKE '%aurora_stat_utils%') AS aurora_stat_utils " \
Expand All @@ -359,6 +369,7 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):

_HOST_ID_QUERY = "SELECT aurora_db_instance_identifier()"
_IS_READER_QUERY = "SELECT pg_is_in_recovery()"
_LIMITLESS_ROUTER_ENDPOINT_QUERY = "SELECT router_endpoint, load FROM aurora_limitless_router_endpoints()"

@property
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
Expand Down Expand Up @@ -621,6 +632,11 @@ def get_dialect(self, driver_dialect: str, props: Properties) -> DatabaseDialect

if target_driver_type is TargetDriverType.POSTGRES:
rds_type = self._rds_helper.identify_rds_type(host)
if rds_type == RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP:
self._can_update = False
self._dialect_code = DialectCode.AURORA_PG
self._dialect = DatabaseDialectManager._known_dialects_by_code[DialectCode.AURORA_PG]
return self._dialect
if rds_type.is_rds_cluster:
self._can_update = True
self._dialect_code = DialectCode.AURORA_PG
Expand Down
9 changes: 6 additions & 3 deletions aws_advanced_python_wrapper/default_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
from aws_advanced_python_wrapper.connection_provider import (ConnectionProvider,
Expand Down Expand Up @@ -118,7 +118,7 @@ def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
return False
return self._connection_provider_manager.accepts_strategy(role, strategy)

def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo:
if HostRole.UNKNOWN == role:
raise AwsWrapperError(Messages.get("DefaultPlugin.UnknownHosts"))

Expand All @@ -127,7 +127,10 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
if len(hosts) < 1:
raise AwsWrapperError(Messages.get("DefaultPlugin.EmptyHosts"))

return self._connection_provider_manager.get_host_info_by_strategy(hosts, role, strategy, self._plugin_service.props)
if host_list is None:
return self._connection_provider_manager.get_host_info_by_strategy(hosts, role, strategy, self._plugin_service.props)
else:
return self._connection_provider_manager.get_host_info_by_strategy(tuple(host_list), role, strategy, self._plugin_service.props)

@property
def subscribed_methods(self) -> Set[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(self, plugin_service: PluginService, props: Properties):
self._plugin_service = plugin_service
self._properties = props
self._host_response_time_service: HostResponseTimeService = \
HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props))
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props) * 10 ^ 6
HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props))
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 10 ^ 6
self._random_host_selector = RandomHostSelector()
self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
self._hosts: Tuple[HostInfo, ...] = ()
Expand All @@ -86,7 +86,7 @@ def connect(
def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
return strategy == FastestResponseStrategyPlugin._FASTEST_RESPONSE_STRATEGY_NAME

def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo:
if not self.accepts_strategy(role, strategy):
logger.error("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy)
raise AwsWrapperError(Messages.get_formatted("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy))
Expand Down
10 changes: 10 additions & 0 deletions aws_advanced_python_wrapper/host_list_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def get_host_role(self, connection: Connection) -> HostRole:
def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]:
...

def get_cluster_id(self) -> str:
...


@runtime_checkable
class DynamicHostListProvider(HostListProvider, Protocol):
Expand Down Expand Up @@ -519,6 +522,10 @@ def _identify_connection(self, conn: Connection):
cursor.execute(self._dialect.host_id_query)
return cursor.fetchone()

def get_cluster_id(self):
self._initialize()
return self._cluster_id

@dataclass()
class ClusterIdSuggestion:
cluster_id: str
Expand Down Expand Up @@ -646,3 +653,6 @@ def get_host_role(self, connection: Connection) -> HostRole:
def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]:
raise UnsupportedOperationError(
Messages.get_formatted("ConnectionStringHostListProvider.UnsupportedMethod", "identify_connection"))

def get_cluster_id(self):
return "<none>"
12 changes: 12 additions & 0 deletions aws_advanced_python_wrapper/host_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,15 @@ def _update_host_weight_map_from_string(self, props: Optional[Properties] = None
except ValueError:
logger.error(message, pair)
raise AwsWrapperError(Messages.get_formatted(message, pair))


class HighestWeightHostSelector(HostSelector):

def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo:
eligible_hosts: List[HostInfo] = [host for host in hosts if
host.role == role and host.get_availability() == HostAvailability.AVAILABLE]

if len(eligible_hosts) == 0:
raise AwsWrapperError(Messages.get_formatted("HostSelector.NoHostsMatchingRole", role))

return max(eligible_hosts, key=lambda host: host.weight)
Loading