Skip to content

Commit 36a895f

Browse files
committed
feat: fastest response strategy plugin
1 parent 8ba0909 commit 36a895f

File tree

5 files changed

+382
-4
lines changed

5 files changed

+382
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import abstractmethod
18+
from concurrent.futures import Executor, ThreadPoolExecutor
19+
from copy import copy
20+
from dataclasses import dataclass
21+
from datetime import datetime
22+
from logging import Logger
23+
from threading import Event, Lock
24+
from time import sleep
25+
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, Optional, Set,
26+
Tuple)
27+
28+
from aws_advanced_python_wrapper.errors import AwsWrapperError
29+
from aws_advanced_python_wrapper.hostselector import RandomHostSelector
30+
from aws_advanced_python_wrapper.plugin import Plugin
31+
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
32+
from aws_advanced_python_wrapper.utils.messages import Messages
33+
from aws_advanced_python_wrapper.utils.properties import (Properties,
34+
WrapperProperties)
35+
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
36+
SlidingExpirationCacheWithCleanupThread
37+
38+
if TYPE_CHECKING:
39+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
40+
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
41+
from aws_advanced_python_wrapper.pep249 import Connection
42+
from aws_advanced_python_wrapper.plugin_service import PluginService
43+
from aws_advanced_python_wrapper.utils.notifications import HostEvent
44+
45+
logger = Logger(__name__)
46+
47+
48+
class FastestResponseStrategyPlugin(Plugin):
49+
_FASTEST_RESPONSE_STRATEGY_NAME = "fastest_response"
50+
_SUBSCRIBED_METHODS: Set[str] = {"accepts_strategy",
51+
"get_host_info_by_strategy",
52+
"notify_host_list_changed"}
53+
54+
def __init__(self, plugin_service: PluginService, props: Properties, host_response_time_service: Optional[HostResponseTimeService] = None):
55+
self._plugin_service = plugin_service
56+
self._properties = props
57+
self._host_response_time_service: Optional[HostResponseTimeService] = host_response_time_service
58+
self._cache_expiration_nano = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props)
59+
self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
60+
self._random_host_selector = RandomHostSelector()
61+
self._hosts: Tuple[HostInfo, ...] = ()
62+
63+
@property
64+
def subscribed_methods(self) -> Set[str]:
65+
return self._SUBSCRIBED_METHODS
66+
67+
def connect(
68+
self,
69+
target_driver_func: Callable,
70+
driver_dialect: DriverDialect,
71+
host_info: HostInfo,
72+
props: Properties,
73+
is_initial_connection: bool,
74+
connect_func: Callable) -> Connection:
75+
return self._connect(host_info, props, is_initial_connection, connect_func)
76+
77+
def force_connect(
78+
self,
79+
target_driver_func: Callable,
80+
driver_dialect: DriverDialect,
81+
host_info: HostInfo,
82+
props: Properties,
83+
is_initial_connection: bool,
84+
force_connect_func: Callable) -> Connection:
85+
return self._connect(host_info, props, is_initial_connection, force_connect_func)
86+
87+
def _connect(
88+
self,
89+
host: HostInfo,
90+
properties: Properties,
91+
is_initial_connection: bool,
92+
connect_func: Callable) -> Connection:
93+
conn = connect_func()
94+
95+
if is_initial_connection:
96+
self._plugin_service.refresh_host_list(conn)
97+
98+
return conn
99+
100+
def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
101+
return strategy == FastestResponseStrategyPlugin._FASTEST_RESPONSE_STRATEGY_NAME
102+
103+
def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
104+
if not self.accepts_strategy(role, strategy):
105+
raise AwsWrapperError(Messages.get_formatted("DriverConnectionProvider.UnsupportedStrategy", strategy))
106+
107+
fastest_response_host: Optional[HostInfo] = self._cached_fastest_response_host_by_role.get(role.name)
108+
if fastest_response_host is not None:
109+
110+
# Found a fastest host. Let's find it in the the latest topology.
111+
for host in self._plugin_service.hosts:
112+
if host == fastest_response_host:
113+
# found the fastest host in the topology
114+
return host
115+
# It seems that the fastest cached host isn't in the latest topology.
116+
# Let's ignore cached results and find the fastest host.
117+
118+
# Cached result isn't available. Need to find the fastest response time host.
119+
host_dict = {}
120+
for host in self._plugin_service.hosts:
121+
if role == host.role and self._host_response_time_service is not None:
122+
response_time_tuple = FastestResponseStrategyPlugin.ResponseTimeTuple(host,
123+
self._host_response_time_service.get_response_time(host))
124+
host_dict[response_time_tuple.host_info] = response_time_tuple.response_time
125+
# sort by response time then retrieve the first host
126+
sorted_host_dict = dict(sorted(host_dict.items(), key=lambda item: item[1]))
127+
calculated_fastest_response_host: HostInfo = next(iter(sorted_host_dict.keys()))
128+
if calculated_fastest_response_host is None:
129+
return self._random_host_selector.get_host(self._plugin_service.hosts, role, self._properties)
130+
131+
self._cached_fastest_response_host_by_role.put(role.name, calculated_fastest_response_host, self._cache_expiration_nano)
132+
133+
return calculated_fastest_response_host
134+
135+
def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
136+
self._hosts = self._plugin_service.hosts
137+
if self._host_response_time_service:
138+
self._host_response_time_service.set_hosts(self._hosts)
139+
140+
@dataclass
141+
class ResponseTimeTuple:
142+
host_info: HostInfo
143+
response_time: int
144+
145+
146+
class FastestResponseStrategyPluginFactory:
147+
148+
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
149+
return FastestResponseStrategyPlugin(plugin_service, props)
150+
151+
152+
class NodeResponseTimeMonitor:
153+
154+
_MONITORING_PROPERTY_PREFIX: str = "frt-"
155+
_NUM_OF_MEASURES: int = 5
156+
157+
def __init__(self, plugin_service: PluginService, host_info: HostInfo, props: Properties, interval_ms: int):
158+
self._plugin_service = plugin_service
159+
self._host_info = host_info
160+
self._properties = props
161+
self._interval_ms = interval_ms
162+
163+
self._response_time: int = 0
164+
self._check_timestamp = datetime.now()
165+
self._lock: Lock = Lock()
166+
self._monitoring_conn: Optional[Connection] = None
167+
self._is_stopped: Event = Event()
168+
169+
self._node_id: Optional[str] = self._host_info.host_id
170+
if self._node_id is None or self._node_id == "":
171+
self._node_id = self._host_info.host
172+
173+
self._executor: Executor = ThreadPoolExecutor(thread_name_prefix="NodeResponseTimeMonitorExecutor")
174+
175+
@property
176+
def response_time(self):
177+
return self._response_time
178+
179+
@property
180+
def check_timestamp(self):
181+
return self._check_timestamp
182+
183+
@property
184+
def host_info(self):
185+
return self._host_info
186+
187+
@property
188+
def is_stopped(self):
189+
return self._is_stopped.is_set()
190+
191+
def close(self):
192+
self._is_stopped.set()
193+
194+
self._executor.shutdown(wait=True)
195+
logger.debug("NodeResponseTimeMonitor.Stopped", self._host_info.host)
196+
197+
def _get_current_time(self):
198+
return datetime.now()
199+
200+
def run(self):
201+
try:
202+
while not self.is_stopped:
203+
try:
204+
self._open_connection()
205+
206+
if self._monitoring_conn is not None:
207+
208+
response_time_sum = 0
209+
count = 0
210+
for i in self._NUM_OF_MEASURES:
211+
if self._is_stopped:
212+
break
213+
start_time = self._get_current_time()
214+
if self._plugin_service.driver_dialect.ping(self._monitoring_conn):
215+
response_time = self._get_current_time() - start_time
216+
response_time_sum = response_time_sum + response_time
217+
count = count + 1
218+
if count > 0:
219+
self.response_time = response_time_sum / count / 1000
220+
else:
221+
self.response_time = 0
222+
self.check_timestamp = self._get_current_time()
223+
logger.debug("NodeResponseTimeMonitor.ResponseTime", self._host_info.host, self._response_time)
224+
225+
sleep(self._interval_ms/1000)
226+
227+
except InterruptedError:
228+
# exit thread
229+
logger.debug("NodeResponseTimeMonitor.InterruptedExceptionDuringMonitoring", self._host_info.host)
230+
except Exception as e:
231+
# this should not be reached; log and exit thread
232+
logger.debug("NodeResponseTimeMonitor.ExceptionDuringMonitoringStop",
233+
self._host_info.host,
234+
e) # print full trace stack of the exception.
235+
finally:
236+
self._is_stopped.set()
237+
if self._monitoring_conn is not None:
238+
try:
239+
self._monitoring_conn.close()
240+
except Exception:
241+
# Do nothing
242+
pass
243+
244+
def _open_connection(self):
245+
try:
246+
driver_dialect = self._plugin_service.driver_dialect
247+
if self._monitoring_conn is None or driver_dialect.is_closed(self._monitoring_conn):
248+
monitoring_conn_properties: Properties = copy(self._properties)
249+
for key, value in self._properties.items():
250+
if key.startswith(self._MONITORING_PROPERTY_PREFIX):
251+
monitoring_conn_properties[key[len(self._MONITORING_PROPERTY_PREFIX):len(key)]] = value
252+
monitoring_conn_properties.pop(key, None)
253+
254+
logger.debug("NodeResponseTimeMonitor.OpeningConnection", self._host_info.url)
255+
self._monitoring_conn = self._plugin_service.force_connect(self._host_info, monitoring_conn_properties, None)
256+
logger.debug("NodeResponseTimeMonitor.OpenedConnection", self._host_info.url)
257+
258+
except Exception:
259+
if self._monitoring_conn is not None:
260+
try:
261+
self._monitoring_conn.close()
262+
except Exception:
263+
pass # ignore
264+
265+
self._monitoring_conn = None
266+
267+
268+
class HostResponseTimeService:
269+
"""
270+
Return a response time in milliseconds to the host.
271+
Return HostResponseTimeService._MAX_VALUE if response time is not available.
272+
273+
@param hostSpec the host details
274+
@return response time in milliseconds for a desired host. It should return HostResponseTimeService._MAX_VALUE
275+
if response time couldn't be measured.
276+
"""
277+
278+
@abstractmethod
279+
def get_response_time(self, host_info: HostInfo) -> int:
280+
...
281+
282+
@abstractmethod
283+
def set_hosts(self, hosts: Tuple[HostInfo, ...]) -> None:
284+
...
285+
286+
287+
class HostResponseTimeServiceImpl(HostResponseTimeService):
288+
_MAX_VALUE = 2 ^ 31 - 1
289+
_CACHE_EXPIRATION_NANO: int = 10 * 10 ^ 9
290+
_CACHE_CLEANUP_NANO: int = 1 * 10 ^ 9
291+
_lock: Lock = Lock()
292+
_monitoring_nodes: ClassVar[SlidingExpirationCacheWithCleanupThread[str, NodeResponseTimeMonitor]] = \
293+
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO,
294+
should_dispose_func=lambda monitor: True,
295+
item_disposal_func=lambda monitor: monitor.dispose())
296+
297+
def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int):
298+
self._plugin_service = plugin_service
299+
self._properties = props
300+
self._interval_ms = interval_ms
301+
self._hosts: Tuple[HostInfo, ...] = ()
302+
303+
def get_response_time(self, host_info: HostInfo) -> int:
304+
monitor: Optional[NodeResponseTimeMonitor] = HostResponseTimeServiceImpl._monitoring_nodes.get(host_info.url)
305+
if monitor is None:
306+
return HostResponseTimeServiceImpl._MAX_VALUE
307+
return monitor.response_time
308+
309+
def set_hosts(self, hosts: Tuple[HostInfo, ...]) -> None:
310+
old_hosts_dict = {x.url: x for x in hosts}
311+
self._hosts = hosts
312+
313+
for host in self._hosts:
314+
new_host = host if host.url not in old_hosts_dict else None
315+
if new_host:
316+
with self._lock:
317+
self._monitoring_nodes.compute_if_absent(host.url,
318+
lambda _: NodeResponseTimeMonitor(
319+
self._plugin_service,
320+
new_host,
321+
self._properties,
322+
self._interval_ms), self._CACHE_EXPIRATION_NANO)

aws_advanced_python_wrapper/plugin_service.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from typing import TYPE_CHECKING, ClassVar, List, Type
1818

19+
from aws_advanced_python_wrapper.fastest_response_strategy_plugin import \
20+
FastestResponseStrategyPluginFactory
1921
from aws_advanced_python_wrapper.federated_plugin import \
2022
FederatedAuthPluginFactory
2123

@@ -571,6 +573,7 @@ class PluginManager(CanReleaseResources):
571573
"host_monitoring": HostMonitoringPluginFactory,
572574
"failover": FailoverPluginFactory,
573575
"read_write_splitting": ReadWriteSplittingPluginFactory,
576+
"fastest_response_strategy": FastestResponseStrategyPluginFactory,
574577
"stale_dns": StaleDnsPluginFactory,
575578
"connect_time": ConnectTimePluginFactory,
576579
"execute_time": ExecuteTimePluginFactory,
@@ -589,8 +592,9 @@ class PluginManager(CanReleaseResources):
589592
ReadWriteSplittingPluginFactory: 300,
590593
FailoverPluginFactory: 400,
591594
HostMonitoringPluginFactory: 500,
592-
IamAuthPluginFactory: 600,
593-
AwsSecretsManagerPluginFactory: 700,
595+
FastestResponseStrategyPluginFactory: 600,
596+
IamAuthPluginFactory: 700,
597+
AwsSecretsManagerPluginFactory: 800,
594598
ConnectTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN,
595599
ExecuteTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN,
596600
DeveloperPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN,

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

+9
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ MonitorService.ErrorPopulatingAliases=[MonitorService] An error occurred while p
148148

149149
MultiAzHostListProvider.UnableToParseInstanceName=[MultiAzHostListProvider] The MultiAzHostListProvider was unable to parse the instance name from the endpoint returned by the topology query.
150150

151+
NodeResponseTimeMonitor.ExceptionDuringMonitoringStop=[NodeResponseTimeMonitor] Stopping thread after unhandled exception was thrown in Response time thread for node {}.
152+
NodeResponseTimeMonitor.InterruptedExceptionDuringMonitoring=[NodeResponseTimeMonitor] Response time thread for node {} was interrupted.
153+
NodeResponseTimeMonitor.OpenedConnection=[NodeResponseTimeMonitor] Opened Response time connection: {}.
154+
NodeResponseTimeMonitor.OpeningConnection=[NodeResponseTimeMonitor] Opening a Response time connection to ''{}''.
155+
NodeResponseTimeMonitor.ResponseTime=[NodeResponseTimeMonitor] Response time for ''{}'': {} ms
156+
NodeResponseTimeMonitor.Stopped=[NodeResponseTimeMonitor] Stopped Response time thread for node ''{}''.
157+
151158
OpenedConnectionTracker.OpenedConnectionsTracked=[OpenedConnectionTracker] Opened Connections Tracked: {}
152159
OpenedConnectionTracker.InvalidatingConnections=[OpenedConnectionTracker] Invalidating opened connections to host: {}
153160
OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet=[OpenedConnectionTracker] The driver is unable to track this opened connection because the instance endpoint is unknown.
@@ -243,6 +250,8 @@ RoundRobinHostSelector.ClusterInfoNone=[RoundRobinHostSelector] The round robin
243250
RoundRobinHostSelector.RoundRobinInvalidDefaultWeight=[RoundRobinHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1.
244251
RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs= [RoundRobinHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of <host>:<weight>. Weight values must be an integer greater than or equal to the default weight value of 1.
245252

253+
SlidingExpirationCache.CleaningUp=[SlidingExpirationCache] Cleaning up...
254+
246255
SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None.
247256
SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties.
248257

aws_advanced_python_wrapper/utils/properties.py

+5
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ class WrapperProperties:
282282
DB_USER = WrapperProperty("db_user",
283283
"The database user used to access the database",
284284
None)
285+
# Fastest Response Strategy
286+
287+
RESPONSE_MEASUREMENT_INTERVAL_MILLIS = WrapperProperty("response_measurement_interval_ms",
288+
"Interval in millis between measuring response time to a database node",
289+
30000)
285290

286291
# Telemetry
287292
ENABLE_TELEMETRY = WrapperProperty(

0 commit comments

Comments
 (0)