Skip to content

feat: fastest response strategy plugin #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions aws_advanced_python_wrapper/driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DriverDialect(ABC):
"""
Driver dialects help the driver-agnostic AWS Python Driver interface with the driver-specific functionality of the underlying Python Driver.
"""
_QUERY = "SELECT 1"

_executor: ClassVar[Executor] = ThreadPoolExecutor()
_dialect_code: str = DriverDialectCodes.GENERIC
Expand Down Expand Up @@ -152,3 +153,13 @@ def unwrap_connection(self, conn_obj: object) -> Any:

def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
    """
    Hook for copying session state from one connection to another.

    This implementation does nothing and returns None.

    :param from_conn: the connection whose session state would be read.
    :param to_conn: the connection that would receive the session state.
    """
    return None

def ping(self, conn: Connection) -> bool:
    """
    Check whether a trivial query can be executed on the given connection.

    The query is run through this dialect's ``execute`` with a 10 second execution timeout,
    and one row is fetched from the cursor.

    :param conn: the connection to probe.
    :return: True if the query ran and the fetch completed without raising; False if any
        exception occurred along the way.
    """
    try:
        probe_sql = DriverDialect._QUERY
        with conn.cursor() as cur:
            self.execute("Cursor.execute", lambda: cur.execute(probe_sql), probe_sql, exec_timeout=10)
            cur.fetchone()
    except Exception:
        # Any failure (cursor creation, execution, timeout, fetch) means the ping failed.
        return False
    else:
        return True
345 changes: 345 additions & 0 deletions aws_advanced_python_wrapper/fastest_response_strategy_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from copy import copy
from dataclasses import dataclass
from datetime import datetime
from threading import Event, Lock, Thread
from time import sleep
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional,
Set, Tuple)

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.hostselector import RandomHostSelector
from aws_advanced_python_wrapper.plugin import Plugin
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
TelemetryContext, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel)

if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.notifications import HostEvent

# Module-level logger shared by every class in this file.
logger = Logger(__name__)

# Sentinel response time meaning "no measurement available". It is the initial value of a
# monitor's response time, the value stored when no ping succeeds, and the value returned
# when no monitor exists for a host.
MAX_VALUE = 2147483647


class FastestResponseStrategyPlugin(Plugin):
    """
    Plugin implementing the "fastest_response" host selection strategy.

    Response times come from a :class:`HostResponseTimeService`. The fastest host found for a
    role is stored in a cache keyed by the role name so that later lookups can reuse it while
    it is still part of the current topology.
    """
    _FASTEST_RESPONSE_STRATEGY_NAME = "fastest_response"
    _SUBSCRIBED_METHODS: Set[str] = {"accepts_strategy",
                                     "get_host_info_by_strategy",
                                     "notify_host_list_changed"}

    def __init__(self, plugin_service: PluginService, props: Properties):
        """
        :param plugin_service: service used to read and refresh the host list.
        :param props: connection properties; ``RESPONSE_MEASUREMENT_INTERVAL_MILLIS`` is read from them.
        """
        self._plugin_service = plugin_service
        self._properties = props
        interval_ms: int = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props)
        self._host_response_time_service: HostResponseTimeService = \
            HostResponseTimeService(plugin_service, props, interval_ms)
        # Milliseconds -> nanoseconds. "**" is required here: "10 ^ 6" is a bitwise XOR in Python.
        self._cache_expiration_nanos = interval_ms * 10 ** 6
        self._random_host_selector = RandomHostSelector()
        self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
        self._hosts: Tuple[HostInfo, ...] = ()

    @property
    def subscribed_methods(self) -> Set[str]:
        """The names of the plugin methods this plugin subscribes to."""
        return self._SUBSCRIBED_METHODS

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """Open a connection via ``connect_func``; see :meth:`_connect`."""
        return self._connect(host_info, props, is_initial_connection, connect_func)

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        """Open a connection via ``force_connect_func``; see :meth:`_connect`."""
        return self._connect(host_info, props, is_initial_connection, force_connect_func)

    def _connect(
            self,
            host: HostInfo,
            properties: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """
        Call ``connect_func`` and, for an initial connection, refresh the host list using the
        newly opened connection.

        :return: the connection returned by ``connect_func``.
        """
        conn = connect_func()

        if is_initial_connection:
            self._plugin_service.refresh_host_list(conn)

        return conn

    def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
        """Return True only for the "fastest_response" strategy name; ``role`` is not consulted."""
        return strategy == FastestResponseStrategyPlugin._FASTEST_RESPONSE_STRATEGY_NAME

    def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
        """
        Return the host with the requested role that has the lowest measured response time.

        A cached result for the role is returned when that host is still in the current
        topology. If no host with the role exists, or the best candidate has no measurement
        yet (its response time equals ``MAX_VALUE``), the choice is delegated to the random
        host selector.

        :param role: the role the selected host must have.
        :param strategy: the requested strategy name.
        :return: the selected host.
        :raises AwsWrapperError: if ``strategy`` is not supported by this plugin.
        """
        if not self.accepts_strategy(role, strategy):
            logger.error("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy)
            raise AwsWrapperError(Messages.get_formatted("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy))

        cached_host: Optional[HostInfo] = self._cached_fastest_response_host_by_role.get(role.name)
        if cached_host is not None:
            # Found a cached fastest host. Let's find it in the latest topology.
            for host in self._plugin_service.hosts:
                if host == cached_host:
                    return host
            # The cached host isn't in the latest topology anymore.
            # Ignore the cached result and recompute the fastest host.

        # Snapshot the response time of every host with the requested role.
        eligible_hosts: List[FastestResponseStrategyPlugin.ResponseTimeTuple] = [
            FastestResponseStrategyPlugin.ResponseTimeTuple(
                host, self._host_response_time_service.get_response_time(host))
            for host in self._plugin_service.hosts
            if role == host.role]

        # min() returns the first of several equally fast hosts, like a stable sort would,
        # and default=None covers the case where no host has the requested role.
        fastest: Optional[FastestResponseStrategyPlugin.ResponseTimeTuple] = \
            min(eligible_hosts, key=lambda x: x.response_time, default=None)

        if fastest is None or fastest.response_time == MAX_VALUE:
            # No candidate, or no measurement yet: let the random selector decide
            # (presumably it reports the error if no host is eligible - verify against RandomHostSelector).
            logger.debug("FastestResponseStrategyPlugin.RandomHostSelected")
            return self._random_host_selector.get_host(self._plugin_service.hosts, role, self._properties)

        self._cached_fastest_response_host_by_role.put(role.name,
                                                       fastest.host_info,
                                                       self._cache_expiration_nanos)

        return fastest.host_info

    def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
        """Store the plugin service's current host list and pass it to the response time service."""
        self._hosts = self._plugin_service.hosts
        if self._host_response_time_service is not None:
            self._host_response_time_service.set_hosts(self._hosts)

    @dataclass
    class ResponseTimeTuple:
        """Pairs a host with the response time sampled for it."""
        host_info: HostInfo
        response_time: int


class FastestResponseStrategyPluginFactory:
    """Factory that creates :class:`FastestResponseStrategyPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """
        Build a new plugin from the given plugin service and properties.

        :param plugin_service: passed through to the plugin constructor.
        :param props: passed through to the plugin constructor.
        :return: a newly constructed FastestResponseStrategyPlugin.
        """
        plugin = FastestResponseStrategyPlugin(plugin_service, props)
        return plugin


class HostResponseTimeMonitor:
    """
    Measures the response time of a single host on a background daemon thread.

    The thread is started by the constructor. On every cycle it makes sure a monitoring
    connection is open, pings the host up to ``_NUM_OF_MEASURES`` times, stores the average
    duration of the successful pings (in milliseconds) and then waits ``interval_ms`` before
    the next cycle. ``MAX_VALUE`` is stored when no ping succeeded.
    """

    # Properties whose key starts with this prefix are applied (without the prefix)
    # to the monitoring connection only.
    _MONITORING_PROPERTY_PREFIX: str = "frt-"
    _NUM_OF_MEASURES: int = 5
    _DEFAULT_CONNECT_TIMEOUT_SEC = 10

    def __init__(self, plugin_service: PluginService, host_info: HostInfo, props: Properties, interval_ms: int):
        """
        :param plugin_service: provides the telemetry factory, the driver dialect and ``force_connect``.
        :param host_info: the host to monitor.
        :param props: connection properties used as the base for the monitoring connection.
        :param interval_ms: pause between two measurement cycles, in milliseconds.
        """
        self._plugin_service = plugin_service
        self._host_info = host_info
        self._properties = props
        self._interval_ms = interval_ms

        self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
        # Averages are computed with true division, so the stored value may be a float.
        self._response_time: float = MAX_VALUE
        self._lock: Lock = Lock()
        self._monitoring_conn: Optional[Connection] = None
        self._is_stopped: Event = Event()

        # Fall back to the host name when the host has no id.
        self._host_id: Optional[str] = self._host_info.host_id
        if self._host_id is None or self._host_id == "":
            self._host_id = self._host_info.host

        self._daemon_thread: Thread = Thread(daemon=True, target=self.run)

        # Report current response time (in milliseconds) to telemetry engine.
        # Report -1 if response time couldn't be measured.
        self._response_time_gauge: TelemetryGauge = \
            self._telemetry_factory.create_gauge("frt.response.time." + self._host_id,
                                                 lambda: self._response_time if self._response_time != MAX_VALUE else -1)
        self._daemon_thread.start()

    @property
    def response_time(self):
        """The most recent average response time in milliseconds, or ``MAX_VALUE`` if unknown."""
        return self._response_time

    @response_time.setter
    def response_time(self, response_time: float):
        self._response_time = response_time

    @property
    def host_info(self):
        """The host being monitored."""
        return self._host_info

    @property
    def is_stopped(self):
        """True once the monitor has been asked to stop or its thread has exited."""
        return self._is_stopped.is_set()

    def close(self):
        """Signal the monitoring thread to stop and wait up to 5 seconds for it to finish."""
        self._is_stopped.set()
        self._daemon_thread.join(5)
        logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host)

    def _get_current_time(self):
        """
        Return a timestamp in milliseconds that is only meaningful for computing elapsed time.

        A monotonic clock is used on purpose: ``datetime.now().microsecond`` is merely the
        sub-second part of the wall clock, so differences between two readings wrap around
        (and can become negative) every second.
        """
        # Imported locally because the module-level imports only bring in ``sleep`` from ``time``.
        from time import perf_counter_ns
        return perf_counter_ns() / 1_000_000  # milliseconds

    def _measure_response_time(self):
        """
        Ping the monitored host up to ``_NUM_OF_MEASURES`` times over the monitoring connection.

        :return: the average duration of the successful pings in milliseconds,
            or ``MAX_VALUE`` when no ping succeeded.
        """
        total_ms = 0
        successful_pings = 0
        for _ in range(self._NUM_OF_MEASURES):
            if self.is_stopped:
                break
            start_time = self._get_current_time()
            if self._plugin_service.driver_dialect.ping(self._monitoring_conn):
                total_ms += self._get_current_time() - start_time
                successful_pings += 1

        if successful_pings > 0:
            return total_ms / successful_pings
        return MAX_VALUE

    def run(self):
        """Thread body: measure the response time in a loop until the monitor is stopped."""
        context: TelemetryContext = self._telemetry_factory.open_telemetry_context(
            "node response time thread", TelemetryTraceLevel.TOP_LEVEL)
        context.set_attribute("url", self._host_info.url)
        try:
            while not self.is_stopped:
                self._open_connection()

                # Without a connection the previous measurement is left untouched.
                if self._monitoring_conn is not None:
                    self.response_time = self._measure_response_time()
                    logger.debug("HostResponseTimeMonitor.ResponseTime", self._host_info.host, self._response_time)

                # Wait on the stop event rather than sleeping so that close() wakes
                # the thread immediately instead of waiting out the full interval.
                self._is_stopped.wait(self._interval_ms / 1000)

        except InterruptedError:
            # exit thread
            logger.debug("HostResponseTimeMonitor.InterruptedExceptionDuringMonitoring", self._host_info.host)
        except Exception as e:
            # this should not be reached; log and exit thread
            logger.debug("HostResponseTimeMonitor.ExceptionDuringMonitoringStop",
                         self._host_info.host,
                         e)  # print full trace stack of the exception.
        finally:
            self._is_stopped.set()
            if self._monitoring_conn is not None:
                try:
                    self._monitoring_conn.close()
                except Exception:
                    # Best-effort cleanup; nothing useful can be done if closing fails.
                    pass

            if context is not None:
                context.close_context()

    def _open_connection(self):
        """
        Ensure ``self._monitoring_conn`` refers to an open connection.

        A new connection is opened when there is none or the current one is closed. On any
        failure the existing connection (if any) is closed and ``self._monitoring_conn`` is
        reset to None; no exception is propagated.
        """
        try:
            driver_dialect = self._plugin_service.driver_dialect
            if self._monitoring_conn is None or driver_dialect.is_closed(self._monitoring_conn):
                monitoring_conn_properties: Properties = copy(self._properties)
                prefix = self._MONITORING_PROPERTY_PREFIX
                # Iterate over the original properties while modifying the copy:
                # "frt-<name>" entries are re-added as "<name>" and the prefixed key is dropped.
                for key, value in self._properties.items():
                    if key.startswith(prefix):
                        monitoring_conn_properties[key[len(prefix):]] = value
                        monitoring_conn_properties.pop(key, None)

                # Set a default connect timeout if the user hasn't configured one
                timeout_key = WrapperProperties.CONNECT_TIMEOUT_SEC.name
                if monitoring_conn_properties.get(timeout_key, None) is None:
                    monitoring_conn_properties[timeout_key] = HostResponseTimeMonitor._DEFAULT_CONNECT_TIMEOUT_SEC

                logger.debug("HostResponseTimeMonitor.OpeningConnection", self._host_info.url)
                self._monitoring_conn = self._plugin_service.force_connect(self._host_info, monitoring_conn_properties, None)
                logger.debug("HostResponseTimeMonitor.OpenedConnection", self._host_info.url)

        except Exception:
            if self._monitoring_conn is not None:
                try:
                    self._monitoring_conn.close()
                except Exception:
                    pass  # ignore

            self._monitoring_conn = None


class HostResponseTimeService:
    """
    Keeps one :class:`HostResponseTimeMonitor` per host URL and exposes their measurements.

    The monitors live in a class-level cache, so they are shared by every instance of this
    service. The cache is configured to close a monitor when its entry is disposed.
    """
    # "**" is required for powers of ten: "6 * 10 ^ 11" is a bitwise XOR that evaluates to 55.
    _CACHE_EXPIRATION_NS: int = 6 * 10 ** 11  # 10 minutes
    _CACHE_CLEANUP_NS: int = 6 * 10 ** 10  # 1 minute
    _lock: Lock = Lock()
    _monitoring_nodes: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostResponseTimeMonitor]] = \
        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
                                                should_dispose_func=lambda monitor: True,
                                                item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor))

    def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int):
        """
        :param plugin_service: handed to each monitor and used to obtain the telemetry factory.
        :param props: connection properties handed to each monitor.
        :param interval_ms: measurement interval in milliseconds handed to each monitor.
        """
        self._plugin_service = plugin_service
        self._properties = props
        self._interval_ms = interval_ms
        self._hosts: Tuple[HostInfo, ...] = ()
        self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
        # Gauge reporting how many monitors are currently cached.
        self._host_count_gauge: TelemetryGauge = self._telemetry_factory.create_gauge("frt.nodes.count", lambda: len(self._monitoring_nodes))

    @property
    def hosts(self) -> Tuple[HostInfo, ...]:
        """The host list most recently passed to :meth:`set_hosts` (or assigned directly)."""
        return self._hosts

    @hosts.setter
    def hosts(self, new_hosts: Tuple[HostInfo, ...]):
        self._hosts = new_hosts

    @staticmethod
    def _monitor_close(monitor: HostResponseTimeMonitor):
        """Close a monitor, ignoring any error it raises (best-effort disposal)."""
        try:
            monitor.close()
        except Exception:
            pass

    def get_response_time(self, host_info: HostInfo) -> int:
        """
        Return the latest response time measured for ``host_info``.

        :param host_info: the host to look up; its URL is the cache key.
        :return: the monitor's current response time, or ``MAX_VALUE`` when no monitor is
            cached for the host.
        """
        monitor: Optional[HostResponseTimeMonitor] = HostResponseTimeService._monitoring_nodes.get(host_info.url)
        if monitor is None:
            return MAX_VALUE
        return monitor.response_time

    def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None:
        """
        Replace the known host list and make sure every newly added host has a monitor.

        Hosts whose URL was already in the previous list are left alone.

        :param new_hosts: the current topology.
        """
        previous_urls = {host.url for host in self.hosts}
        self.hosts = new_hosts

        for host in self.hosts:
            if host.url in previous_urls:
                continue
            with self._lock:
                # The host is bound as a default argument so the factory does not depend on
                # the loop variable at the time it is invoked.
                self._monitoring_nodes.compute_if_absent(
                    host.url,
                    lambda _, new_host=host: HostResponseTimeMonitor(
                        self._plugin_service,
                        new_host,
                        self._properties,
                        self._interval_ms),
                    self._CACHE_EXPIRATION_NS)
3 changes: 2 additions & 1 deletion aws_advanced_python_wrapper/host_monitoring_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ class Monitor:
_INACTIVE_SLEEP_MS = 100
_MIN_HOST_CHECK_TIMEOUT_MS = 3000
_MONITORING_PROPERTY_PREFIX = "monitoring-"
_QUERY = "SELECT 1"

def __init__(
self,
Expand Down Expand Up @@ -560,7 +561,7 @@ def _is_host_available(self, conn: Connection, timeout_sec: float) -> bool:
def _execute_conn_check(self, conn: Connection, timeout_sec: float):
    """
    Run the monitor's check query on ``conn`` and fetch one row.

    Exceptions raised while executing or fetching are not handled here and propagate to the caller.

    :param conn: the connection to check.
    :param timeout_sec: passed to the driver dialect as the execution timeout.
    """
    driver_dialect = self._plugin_service.driver_dialect
    with conn.cursor() as cursor:
        # Use the shared class constant only; a preceding literal assignment would be a
        # dead store that is immediately overwritten.
        query = Monitor._QUERY
        driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec)
        cursor.fetchone()

Expand Down
3 changes: 3 additions & 0 deletions aws_advanced_python_wrapper/mysql_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)):
to_conn.autocommit = from_conn.autocommit

def ping(self, conn: Connection) -> bool:
    """
    Report whether ``conn`` is usable, judged solely by this dialect's ``is_closed`` check.

    :param conn: the connection to check.
    :return: True when ``is_closed(conn)`` is falsy, False otherwise.
    """
    connection_closed = self.is_closed(conn)
    return not connection_closed

def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) -> Properties:
driver_props: Properties = Properties(original_props.copy())
PropertiesUtils.remove_wrapper_props(driver_props)
Expand Down
Loading
Loading