Skip to content

Commit 46e2a54

Browse files
committed
feat: implement configuration profiles
1 parent 6274e73 commit 46e2a54

21 files changed

+758
-131
lines changed

aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py

+45-17
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
6464

6565
self._track_connection(instance_endpoint, conn)
6666

67+
def invalidate_current_connection(self, host_info: HostInfo, conn: Optional[Connection]):
68+
host: Optional[str] = host_info.as_alias() \
69+
if self._rds_utils.is_rds_instance(host_info.host) \
70+
else next(alias for alias in host_info.aliases if self._rds_utils.is_rds_instance(alias))
71+
72+
if not host:
73+
return
74+
75+
connection_set: Optional[WeakSet] = self._opened_connections.get(host)
76+
if connection_set is not None:
77+
self._log_connection_set(host, connection_set)
78+
connection_set.discard(conn)
79+
6780
def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: Optional[FrozenSet[str]] = None):
6881
"""
6982
Invalidates all opened connections pointing to the same host in a daemon thread.
@@ -77,14 +90,10 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
7790
self.invalidate_all_connections(host=host_info.as_aliases())
7891
return
7992

80-
instance_endpoint: Optional[str] = None
8193
if host is None:
8294
return
8395

84-
for instance in host:
85-
if instance is not None and self._rds_utils.is_rds_instance(instance):
86-
instance_endpoint = instance
87-
break
96+
instance_endpoint = next(instance for instance in host if self._rds_utils.is_rds_instance(instance))
8897

8998
if not instance_endpoint:
9099
return
@@ -135,8 +144,8 @@ def log_opened_connections(self):
135144

136145
return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg)
137146

138-
def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
139-
if conn_set is None or len(conn_set) == 0:
147+
def _log_connection_set(self, host: Optional[str], conn_set: Optional[WeakSet]):
148+
if host is None or conn_set is None or len(conn_set) == 0:
140149
return
141150

142151
conn = ""
@@ -148,13 +157,14 @@ def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
148157

149158

150159
class AuroraConnectionTrackerPlugin(Plugin):
151-
_SUBSCRIBED_METHODS: Set[str] = {"*"}
160+
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
152161
_current_writer: Optional[HostInfo] = None
153162
_need_update_current_writer: bool = False
163+
_METHOD_CLOSE = "Connection.close"
154164

155165
@property
156166
def subscribed_methods(self) -> Set[str]:
157-
return self._SUBSCRIBED_METHODS
167+
return AuroraConnectionTrackerPlugin._SUBSCRIBED_METHODS.union(self._plugin_service.network_bound_methods)
158168

159169
def __init__(self,
160170
plugin_service: PluginService,
@@ -201,19 +211,20 @@ def _connect(self, host_info: HostInfo, connect_func: Callable):
201211
return conn
202212

203213
def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
204-
if self._current_writer is None or self._need_update_current_writer:
205-
self._current_writer = self._get_writer(self._plugin_service.hosts)
206-
self._need_update_current_writer = False
214+
self._remember_writer()
207215

208216
try:
209-
return execute_func()
217+
results = execute_func()
218+
if method_name == AuroraConnectionTrackerPlugin._METHOD_CLOSE and self._plugin_service.current_host_info is not None:
219+
self._tracker.invalidate_current_connection(self._plugin_service.current_host_info, self._plugin_service.current_connection)
220+
elif self._need_update_current_writer:
221+
self._check_writer_changed()
222+
return results
210223

211224
except Exception as e:
212225
# Check that e is a FailoverError and that the writer has changed
213-
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.hosts) != self._current_writer:
214-
self._tracker.invalidate_all_connections(host_info=self._current_writer)
215-
self._tracker.log_opened_connections()
216-
self._need_update_current_writer = True
226+
if isinstance(e, FailoverError):
227+
self._check_writer_changed()
217228
raise e
218229

219230
def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
@@ -222,6 +233,23 @@ def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
222233
return host
223234
return None
224235

236+
def _remember_writer(self):
237+
if self._current_writer is None or self._need_update_current_writer:
238+
self._current_writer = self._get_writer(self._plugin_service.hosts)
239+
self._need_update_current_writer = False
240+
241+
def _check_writer_changed(self):
242+
host_info_after_failover = self._get_writer(self._plugin_service.hosts)
243+
244+
if self._current_writer is None:
245+
self._current_writer = host_info_after_failover
246+
self._need_update_current_writer = False
247+
elif self._current_writer != host_info_after_failover:
248+
self._tracker.invalidate_all_connections(self._current_writer)
249+
self._tracker.log_opened_connections()
250+
self._current_writer = host_info_after_failover
251+
self._need_update_current_writer = False
252+
225253

226254
class AuroraConnectionTrackerPluginFactory(PluginFactory):
227255
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import time
18+
from typing import TYPE_CHECKING, Callable, Optional, Set
19+
20+
from aws_advanced_python_wrapper.utils.log import Logger
21+
22+
if TYPE_CHECKING:
23+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
24+
from aws_advanced_python_wrapper.pep249 import Connection
25+
from aws_advanced_python_wrapper.plugin_service import PluginService
26+
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
27+
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
28+
29+
from aws_advanced_python_wrapper.errors import AwsWrapperError
30+
from aws_advanced_python_wrapper.host_availability import HostAvailability
31+
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
32+
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
33+
from aws_advanced_python_wrapper.utils.messages import Messages
34+
from aws_advanced_python_wrapper.utils.properties import (Properties,
35+
WrapperProperties)
36+
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
37+
38+
logger = Logger(__name__)
39+
40+
41+
class AuroraInitialConnectionStrategyPlugin(Plugin):
42+
_plugin_service: PluginService
43+
_host_list_provider_service: HostListProviderService
44+
_rds_utils: RdsUtils
45+
46+
@property
47+
def subscribed_methods(self) -> Set[str]:
48+
return {"init_host_provider", "connect", "force_connect"}
49+
50+
def __init__(self, plugin_service: PluginService, properties: Properties):
51+
self._plugin_service = plugin_service
52+
53+
def init_host_provider(self, props: Properties, host_list_provider_service: HostListProviderService, init_host_provider_func: Callable):
54+
self._host_list_provider_service = host_list_provider_service
55+
if host_list_provider_service.is_static_host_list_provider():
56+
msg = Messages.get("AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider")
57+
logger.warning(msg)
58+
raise AwsWrapperError(msg)
59+
init_host_provider_func()
60+
61+
def connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
62+
is_initial_connection: bool, connect_func: Callable) -> Connection:
63+
return self._connect_internal(host_info, props, is_initial_connection, connect_func)
64+
65+
def force_connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
66+
is_initial_connection: bool, force_connect_func: Callable) -> Connection:
67+
return self._connect_internal(host_info, props, is_initial_connection, force_connect_func)
68+
69+
def _connect_internal(self, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection:
70+
urlType: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host)
71+
if not urlType.is_rds_cluster:
72+
return connect_func()
73+
74+
if urlType == RdsUrlType.RDS_WRITER_CLUSTER:
75+
writer_candidate_conn = self._get_verified_writer_connection(props, is_initial_connection, connect_func)
76+
if writer_candidate_conn is None:
77+
return connect_func()
78+
return writer_candidate_conn
79+
80+
if urlType == RdsUrlType.RDS_READER_CLUSTER:
81+
reader_candidate_conn = self._get_verified_reader_connection(props, is_initial_connection, connect_func)
82+
if reader_candidate_conn is None:
83+
return connect_func()
84+
return reader_candidate_conn
85+
86+
# Continue with a normal workflow.
87+
return connect_func()
88+
89+
def _get_verified_writer_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]:
90+
retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props)
91+
end_time_nano = self._get_time() + retry_delay_ms * 1_000_000
92+
93+
writer_candidate_conn: Optional[Connection]
94+
writer_candidate: Optional[HostInfo]
95+
96+
while self._get_time() < end_time_nano:
97+
writer_candidate_conn = None
98+
writer_candidate = None
99+
100+
try:
101+
writer_candidate = self._get_writer()
102+
if writer_candidate_conn is None or self._rds_utils.is_rds_cluster_dns(writer_candidate.host):
103+
writer_candidate_conn = connect_func()
104+
self._plugin_service.force_refresh_host_list(writer_candidate_conn)
105+
writer_candidate = self._plugin_service.identify_connection(writer_candidate_conn)
106+
107+
if writer_candidate is not None and writer_candidate.role != HostRole.WRITER:
108+
# Shouldn't be here. But let's try again.
109+
self._close_connection(writer_candidate_conn)
110+
self._delay(retry_delay_ms)
111+
continue
112+
113+
if is_initial_connection:
114+
self._host_list_provider_service.initial_connection_host_info = writer_candidate
115+
116+
return writer_candidate_conn
117+
118+
writer_candidate_conn = self._plugin_service.connect(writer_candidate, props)
119+
120+
if self._plugin_service.get_host_role(writer_candidate_conn) != HostRole.WRITER:
121+
self._plugin_service.force_refresh_host_list(writer_candidate_conn)
122+
self._close_connection(writer_candidate_conn)
123+
self._delay(retry_delay_ms)
124+
continue
125+
126+
if is_initial_connection:
127+
self._host_list_provider_service.initial_connection_host_info = writer_candidate
128+
return writer_candidate_conn
129+
130+
except Exception as e:
131+
if writer_candidate is not None:
132+
self._plugin_service.set_availability(writer_candidate.as_aliases(), HostAvailability.UNAVAILABLE)
133+
self._close_connection(writer_candidate_conn)
134+
raise e
135+
136+
return None
137+
138+
def _get_verified_reader_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]:
139+
retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_INTERVAL_MS.get_int(props)
140+
end_time_nano = self._get_time() + WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props) * 1_000_000
141+
142+
reader_candidate_conn: Optional[Connection]
143+
reader_candidate: Optional[HostInfo]
144+
145+
while self._get_time() < end_time_nano:
146+
reader_candidate_conn = None
147+
reader_candidate = None
148+
149+
try:
150+
reader_candidate = self._get_reader(props)
151+
if reader_candidate is None or self._rds_utils.is_rds_cluster_dns(reader_candidate.host):
152+
# Reader not found, topology may be outdated
153+
reader_candidate_conn = connect_func()
154+
self._plugin_service.force_refresh_host_list(reader_candidate_conn)
155+
reader_candidate = self._plugin_service.identify_connection(reader_candidate_conn)
156+
157+
if reader_candidate is not None and reader_candidate.role != HostRole.READER:
158+
if self._has_no_readers():
159+
# Cluster has no readers. Simulate Aurora reader cluster endpoint logic
160+
if is_initial_connection and reader_candidate.host is not None:
161+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
162+
return reader_candidate_conn
163+
self._close_connection(reader_candidate_conn)
164+
self._delay(retry_delay_ms)
165+
continue
166+
167+
if reader_candidate is not None and is_initial_connection:
168+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
169+
return reader_candidate_conn
170+
171+
reader_candidate_conn = self._plugin_service.connect(reader_candidate, props)
172+
if self._plugin_service.get_host_role(reader_candidate_conn) != HostRole.READER:
173+
# If the new connection resolves to a writer instance, this means the topology is outdated.
174+
# Force refresh to update the topology.
175+
self._plugin_service.force_refresh_host_list(reader_candidate_conn)
176+
177+
if self._has_no_readers():
178+
# Cluster has no readers. Simulate Aurora reader cluster endpoint logic
179+
if is_initial_connection:
180+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
181+
return reader_candidate_conn
182+
183+
self._close_connection(reader_candidate_conn)
184+
self._delay(retry_delay_ms)
185+
continue
186+
187+
# Reader connection is valid and verified.
188+
if is_initial_connection:
189+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
190+
return reader_candidate_conn
191+
192+
except Exception:
193+
self._close_connection(reader_candidate_conn)
194+
if reader_candidate is not None:
195+
self._plugin_service.set_availability(reader_candidate.as_aliases(), HostAvailability.AVAILABLE)
196+
197+
return None
198+
199+
def _close_connection(self, connection: Optional[Connection]):
200+
if connection is not None:
201+
try:
202+
connection.close()
203+
except Exception:
204+
# ignore
205+
pass
206+
207+
def _delay(self, delay_ms: int):
208+
time.sleep(delay_ms / 1000)
209+
210+
def _get_writer(self) -> Optional[HostInfo]:
211+
return next(host for host in self._plugin_service.hosts if host.role == HostRole.WRITER)
212+
213+
def _get_reader(self, props: Properties) -> Optional[HostInfo]:
214+
strategy: Optional[str] = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props)
215+
if strategy is not None and self._plugin_service.accepts_strategy(HostRole.READER, strategy):
216+
try:
217+
return self._plugin_service.get_host_info_by_strategy(HostRole.READER, strategy)
218+
except Exception:
219+
# Host isn't found
220+
return None
221+
222+
raise AwsWrapperError(Messages.get_formatted("AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy", strategy))
223+
224+
def _has_no_readers(self) -> bool:
225+
if len(self._plugin_service.hosts) == 0:
226+
# Topology inconclusive.
227+
return False
228+
return next(host_info for host_info in self._plugin_service.hosts if host_info.role == HostRole.READER) is None
229+
230+
def _get_time(self):
231+
return time.perf_counter_ns()
232+
233+
234+
class AuroraInitialConnectionStrategyPluginFactory(PluginFactory):
235+
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
236+
return AuroraInitialConnectionStrategyPlugin(plugin_service, props)

aws_advanced_python_wrapper/driver_configuration_profiles.py

-44
This file was deleted.

0 commit comments

Comments
 (0)