Skip to content

Commit c721c6f

Browse files
committed
feat: implement configuration profiles
1 parent 6274e73 commit c721c6f

21 files changed

+568
-126
lines changed

aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py

+45-17
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
6464

6565
self._track_connection(instance_endpoint, conn)
6666

67+
def invalidate_current_connection(self, host_info: HostInfo, conn: Optional[Connection]):
68+
host: Optional[str] = host_info.as_alias() \
69+
if self._rds_utils.is_rds_instance(host_info.host) \
70+
else next(alias for alias in host_info.aliases if self._rds_utils.is_rds_instance(alias))
71+
72+
if not host:
73+
return
74+
75+
connection_set: Optional[WeakSet] = self._opened_connections.get(host)
76+
if connection_set is not None:
77+
self._log_connection_set(host, connection_set)
78+
connection_set.discard(conn)
79+
6780
def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: Optional[FrozenSet[str]] = None):
6881
"""
6982
Invalidates all opened connections pointing to the same host in a daemon thread.
@@ -77,14 +90,10 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
7790
self.invalidate_all_connections(host=host_info.as_aliases())
7891
return
7992

80-
instance_endpoint: Optional[str] = None
8193
if host is None:
8294
return
8395

84-
for instance in host:
85-
if instance is not None and self._rds_utils.is_rds_instance(instance):
86-
instance_endpoint = instance
87-
break
96+
instance_endpoint = next(instance for instance in host if self._rds_utils.is_rds_instance(instance))
8897

8998
if not instance_endpoint:
9099
return
@@ -135,8 +144,8 @@ def log_opened_connections(self):
135144

136145
return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg)
137146

138-
def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
139-
if conn_set is None or len(conn_set) == 0:
147+
def _log_connection_set(self, host: Optional[str], conn_set: Optional[WeakSet]):
148+
if host is None or conn_set is None or len(conn_set) == 0:
140149
return
141150

142151
conn = ""
@@ -148,13 +157,14 @@ def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
148157

149158

150159
class AuroraConnectionTrackerPlugin(Plugin):
151-
_SUBSCRIBED_METHODS: Set[str] = {"*"}
160+
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
152161
_current_writer: Optional[HostInfo] = None
153162
_need_update_current_writer: bool = False
163+
_METHOD_CLOSE = "Connection.close"
154164

155165
@property
156166
def subscribed_methods(self) -> Set[str]:
157-
return self._SUBSCRIBED_METHODS
167+
return AuroraConnectionTrackerPlugin._SUBSCRIBED_METHODS.union(self._plugin_service.network_bound_methods)
158168

159169
def __init__(self,
160170
plugin_service: PluginService,
@@ -201,19 +211,20 @@ def _connect(self, host_info: HostInfo, connect_func: Callable):
201211
return conn
202212

203213
def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
204-
if self._current_writer is None or self._need_update_current_writer:
205-
self._current_writer = self._get_writer(self._plugin_service.hosts)
206-
self._need_update_current_writer = False
214+
self._remember_writer()
207215

208216
try:
209-
return execute_func()
217+
results = execute_func()
218+
if method_name == AuroraConnectionTrackerPlugin._METHOD_CLOSE and self._plugin_service.current_host_info is not None:
219+
self._tracker.invalidate_current_connection(self._plugin_service.current_host_info, self._plugin_service.current_connection)
220+
elif self._need_update_current_writer:
221+
self._check_writer_changed()
222+
return results
210223

211224
except Exception as e:
212225
# Check that e is a FailoverError and that the writer has changed
213-
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.hosts) != self._current_writer:
214-
self._tracker.invalidate_all_connections(host_info=self._current_writer)
215-
self._tracker.log_opened_connections()
216-
self._need_update_current_writer = True
226+
if isinstance(e, FailoverError):
227+
self._check_writer_changed()
217228
raise e
218229

219230
def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
@@ -222,6 +233,23 @@ def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
222233
return host
223234
return None
224235

236+
def _remember_writer(self):
237+
if self._current_writer is None or self._need_update_current_writer:
238+
self._current_writer = self._get_writer(self._plugin_service.hosts)
239+
self._need_update_current_writer = False
240+
241+
def _check_writer_changed(self):
242+
host_info_after_failover = self._get_writer(self._plugin_service.hosts)
243+
244+
if self._current_writer is None:
245+
self._current_writer = host_info_after_failover
246+
self._need_update_current_writer = False
247+
elif self._current_writer != host_info_after_failover:
248+
self._tracker.invalidate_all_connections(self._current_writer)
249+
self._tracker.log_opened_connections()
250+
self._current_writer = host_info_after_failover
251+
self._need_update_current_writer = False
252+
225253

226254
class AuroraConnectionTrackerPluginFactory(PluginFactory):
227255
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import time
18+
from typing import TYPE_CHECKING, Callable, Optional, Set
19+
20+
if TYPE_CHECKING:
21+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
22+
from aws_advanced_python_wrapper.pep249 import Connection
23+
from aws_advanced_python_wrapper.plugin_service import PluginService
24+
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
25+
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
26+
27+
from aws_advanced_python_wrapper.errors import AwsWrapperError
28+
from aws_advanced_python_wrapper.host_availability import HostAvailability
29+
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
30+
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
31+
from aws_advanced_python_wrapper.utils.messages import Messages
32+
from aws_advanced_python_wrapper.utils.properties import (Properties,
33+
WrapperProperties)
34+
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
35+
36+
37+
class AuroraInitialConnectionStrategyPlugin(Plugin):
38+
_plugin_service: PluginService
39+
_host_list_provider_service: HostListProviderService
40+
_rds_utils: RdsUtils
41+
42+
@property
43+
def subscribed_methods(self) -> Set[str]:
44+
return {"init_host_provider", "connect", "force_connect"}
45+
46+
def __init__(self, plugin_service: PluginService, properties: Properties):
47+
self._plugin_service = plugin_service
48+
49+
def init_host_provider(self, props: Properties, host_list_provider_service: HostListProviderService, init_host_provider_func: Callable):
50+
self._host_list_provider_service = host_list_provider_service
51+
if host_list_provider_service.is_static_host_list_provider():
52+
raise AwsWrapperError(Messages.get("AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider"))
53+
init_host_provider_func()
54+
55+
def connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
56+
is_initial_connection: bool, connect_func: Callable) -> Connection:
57+
return self._connect_internal(host_info, props, is_initial_connection, connect_func)
58+
59+
def force_connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
60+
is_initial_connection: bool, force_connect_func: Callable) -> Connection:
61+
return self._connect_internal(host_info, props, is_initial_connection, force_connect_func)
62+
63+
def _connect_internal(self, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection:
64+
type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host)
65+
if not type.is_rds_cluster:
66+
return connect_func()
67+
68+
if type == RdsUrlType.RDS_WRITER_CLUSTER:
69+
writer_candidate_conn = self._get_verified_writer_connection(props, is_initial_connection, connect_func)
70+
if writer_candidate_conn is None:
71+
return connect_func()
72+
return writer_candidate_conn
73+
74+
if type == RdsUrlType.RDS_READER_CLUSTER:
75+
reader_candidate_conn = self._get_verified_reader_connection(props, is_initial_connection, connect_func)
76+
if reader_candidate_conn is None:
77+
return connect_func()
78+
return reader_candidate_conn
79+
80+
# Continue with a normal workflow.
81+
return connect_func()
82+
83+
def _get_verified_writer_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]:
84+
retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props)
85+
end_time_nano = self._get_time() + retry_delay_ms * 1_000_000
86+
87+
writer_candidate_conn: Optional[Connection]
88+
writer_candidate: Optional[HostInfo]
89+
90+
while self._get_time() < end_time_nano:
91+
writer_candidate_conn = None
92+
writer_candidate = None
93+
94+
try:
95+
writer_candidate = self._get_writer()
96+
if writer_candidate_conn is None or self._rds_utils.is_rds_cluster_dns(writer_candidate.host):
97+
writer_candidate_conn = connect_func()
98+
self._plugin_service.force_refresh_host_list(writer_candidate_conn)
99+
writer_candidate = self._plugin_service.identify_connection(writer_candidate_conn)
100+
101+
if writer_candidate is not None and writer_candidate.role != HostRole.WRITER:
102+
# Shouldn't be here. But let's try again.
103+
self._close_connection(writer_candidate_conn)
104+
self._delay(retry_delay_ms)
105+
continue
106+
107+
if is_initial_connection:
108+
self._host_list_provider_service.initial_connection_host_info = writer_candidate
109+
110+
return writer_candidate_conn
111+
112+
writer_candidate_conn = self._plugin_service.connect(writer_candidate, props)
113+
114+
if self._plugin_service.get_host_role(writer_candidate_conn) != HostRole.WRITER:
115+
self._plugin_service.force_refresh_host_list(writer_candidate_conn)
116+
self._close_connection(writer_candidate_conn)
117+
self._delay(retry_delay_ms)
118+
continue
119+
120+
if is_initial_connection:
121+
self._host_list_provider_service.initial_connection_host_info = writer_candidate
122+
return writer_candidate_conn
123+
124+
except Exception as e:
125+
if writer_candidate is not None:
126+
self._plugin_service.set_availability(writer_candidate.as_aliases(), HostAvailability.UNAVAILABLE)
127+
self._close_connection(writer_candidate_conn)
128+
raise e
129+
130+
return None
131+
132+
def _get_verified_reader_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]:
133+
retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_INTERVAL_MS.get_int(props)
134+
end_time_nano = self._get_time() + WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props) * 1_000_000
135+
136+
reader_candidate_conn: Optional[Connection]
137+
reader_candidate: Optional[HostInfo]
138+
139+
while self._get_time() < end_time_nano:
140+
reader_candidate_conn = None
141+
reader_candidate = None
142+
143+
try:
144+
reader_candidate = self._get_reader(props)
145+
if reader_candidate is None or self._rds_utils.is_rds_cluster_dns(reader_candidate.host):
146+
# Reader not found, topology may be outdated
147+
reader_candidate_conn = connect_func()
148+
self._plugin_service.force_refresh_host_list(reader_candidate_conn)
149+
reader_candidate = self._plugin_service.identify_connection(reader_candidate_conn)
150+
151+
if reader_candidate is not None and reader_candidate.role != HostRole.READER:
152+
if self._has_no_readers():
153+
# Cluster has no readers. Simulate Aurora reader cluster endpoint logic
154+
if is_initial_connection and reader_candidate.host is not None:
155+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
156+
return reader_candidate_conn
157+
self._close_connection(reader_candidate_conn)
158+
self._delay(retry_delay_ms)
159+
continue
160+
161+
if reader_candidate is not None and is_initial_connection:
162+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
163+
return reader_candidate_conn
164+
165+
reader_candidate_conn = self._plugin_service.connect(reader_candidate, props)
166+
if self._plugin_service.get_host_role(reader_candidate_conn) != HostRole.READER:
167+
# If the new connection resolves to a writer instance, this means the topology is outdated.
168+
# Force refresh to update the topology.
169+
self._plugin_service.force_refresh_host_list(reader_candidate_conn)
170+
171+
if self._has_no_readers():
172+
# Cluster has no readers. Simulate Aurora reader cluster endpoint logic
173+
if is_initial_connection:
174+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
175+
return reader_candidate_conn
176+
177+
self._close_connection(reader_candidate_conn)
178+
self._delay(retry_delay_ms)
179+
continue
180+
181+
# Reader connection is valid and verified.
182+
if is_initial_connection:
183+
self._host_list_provider_service.initial_connection_host_info = reader_candidate
184+
return reader_candidate_conn
185+
186+
except Exception:
187+
self._close_connection(reader_candidate_conn)
188+
if reader_candidate is not None:
189+
self._plugin_service.set_availability(reader_candidate.as_aliases(), HostAvailability.AVAILABLE)
190+
191+
return None
192+
193+
def _close_connection(self, connection: Optional[Connection]):
194+
if connection is not None:
195+
try:
196+
connection.close()
197+
except Exception:
198+
# ignore
199+
pass
200+
201+
def _delay(self, delay_ms: int):
202+
time.sleep(delay_ms / 1000)
203+
204+
def _get_writer(self) -> Optional[HostInfo]:
205+
return next(host for host in self._plugin_service.hosts if host.role == HostRole.WRITER)
206+
207+
def _get_reader(self, props: Properties) -> Optional[HostInfo]:
208+
strategy: Optional[str] = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props)
209+
if strategy is not None and self._plugin_service.accepts_strategy(HostRole.READER, strategy):
210+
try:
211+
return self._plugin_service.get_host_info_by_strategy(HostRole.READER, strategy)
212+
except Exception:
213+
# Host isn't found
214+
return None
215+
216+
raise AwsWrapperError(Messages.get_formatted("AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy", strategy))
217+
218+
def _has_no_readers(self) -> bool:
219+
if len(self._plugin_service.hosts) == 0:
220+
# Topology inconclusive.
221+
return False
222+
return next(host_info for host_info in self._plugin_service.hosts if host_info.role == HostRole.READER) is None
223+
224+
def _get_time(self):
225+
return time.perf_counter_ns()
226+
227+
228+
class AuroraInitialConnectionStrategyPluginFactory(PluginFactory):
229+
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
230+
return AuroraInitialConnectionStrategyPlugin(plugin_service, props)

aws_advanced_python_wrapper/driver_configuration_profiles.py

-44
This file was deleted.

0 commit comments

Comments
 (0)