-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathaurora_connection_tracker_plugin.py
256 lines (200 loc) · 9.86 KB
/
aurora_connection_tracker_plugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from threading import Thread
from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional,
Set, Tuple)
if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.properties import Properties
from _weakrefset import WeakSet
from aws_advanced_python_wrapper.errors import FailoverError
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
# Module-level logger used by OpenedConnectionTracker for its debug output.
logger = Logger(__name__)
class OpenedConnectionTracker:
    """
    Tracks opened connections grouped by the RDS instance endpoint they point to, so that all
    connections to one instance can later be closed together.

    The registry is a class attribute, so every ``OpenedConnectionTracker`` instance shares it.
    """

    # instance endpoint -> WeakSet of connections to that instance. A WeakSet only holds weak
    # references, so tracking a connection does not by itself keep the connection alive.
    _opened_connections: Dict[str, WeakSet] = {}
    _rds_utils = RdsUtils()

    def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
        """
        Add the given connection to the set of tracked connections.

        The connection is tracked under ``host_info.as_alias()`` when that is an RDS instance
        endpoint, otherwise under an RDS instance endpoint found in ``host_info.as_aliases()``.
        If no instance endpoint can be found, the connection is not tracked.

        :param host_info: host information of the given connection.
        :param conn: currently opened connection.
        """
        instance_endpoint: Optional[str] = self._get_instance_endpoint(host_info)
        if not instance_endpoint:
            logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet")
            return

        self._track_connection(instance_endpoint, conn)

    def invalidate_current_connection(self, host_info: HostInfo, conn: Optional[Connection]):
        """
        Stop tracking ``conn`` (called after the connection has been closed).

        The key is resolved exactly as in :py:meth:`populate_opened_connection_set`, so the lookup
        hits the same set the connection was added to.

        :param host_info: host information of the given connection.
        :param conn: the connection to stop tracking. ``None`` is ignored, because a WeakSet cannot
            hold (or discard) ``None``.
        """
        host: Optional[str] = self._get_instance_endpoint(host_info)
        if not host:
            return

        connection_set: Optional[WeakSet] = self._opened_connections.get(host)
        if connection_set is not None:
            self._log_connection_set(host, connection_set)
            if conn is not None:
                connection_set.discard(conn)

    def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: Optional[FrozenSet[str]] = None):
        """
        Invalidates all opened connections pointing to the same host in a daemon thread.

        :param host_info: the :py:class:`HostInfo` object containing the URL of the host. When given,
            ``host`` is ignored.
        :param host: the set of aliases representing a specific host.
        """
        if host_info:
            # Check the primary alias and the full alias set. Both may resolve to the same key, so
            # each distinct key is invalidated once; otherwise two daemon threads would drain one set.
            primary_alias = host_info.as_alias()
            candidates = (
                primary_alias if self._rds_utils.is_rds_instance(primary_alias) else None,
                self._find_instance_endpoint(host_info.as_aliases()),
            )
            for instance_endpoint in dict.fromkeys(candidates):
                if instance_endpoint:
                    self._invalidate_instance_endpoint(instance_endpoint)
            return

        if host is None:
            return

        instance_endpoint = self._find_instance_endpoint(host)
        if instance_endpoint:
            self._invalidate_instance_endpoint(instance_endpoint)

    def _find_instance_endpoint(self, aliases: FrozenSet[str]) -> Optional[str]:
        """Return an alias that is an RDS instance endpoint, or ``None`` when there is none."""
        return next((alias for alias in aliases if self._rds_utils.is_rds_instance(alias)), None)

    def _get_instance_endpoint(self, host_info: HostInfo) -> Optional[str]:
        """
        Resolve the registry key for ``host_info``: the primary alias if it is an RDS instance
        endpoint, otherwise an instance endpoint among all aliases, otherwise ``None``.
        """
        host: str = host_info.as_alias()
        if self._rds_utils.is_rds_instance(host):
            return host
        return self._find_instance_endpoint(host_info.as_aliases())

    def _invalidate_instance_endpoint(self, instance_endpoint: str):
        """Close, on a daemon thread, every tracked connection registered under ``instance_endpoint``."""
        connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
        if connection_set is not None:
            self._log_connection_set(instance_endpoint, connection_set)
            self._invalidate_connections(connection_set)

    def _track_connection(self, instance_endpoint: str, conn: Connection):
        """Register ``conn`` under ``instance_endpoint``, creating the set on first use, then log the registry."""
        connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
        if connection_set is None:
            connection_set = WeakSet()
            self._opened_connections[instance_endpoint] = connection_set
        connection_set.add(conn)
        self.log_opened_connections()

    @staticmethod
    def _task(connection_set: WeakSet):
        """Pop and close every connection in ``connection_set`` until the set is empty."""
        while connection_set:
            try:
                conn = connection_set.pop()
            except KeyError:
                # pop() raises KeyError on an empty set. The set may have been emptied between the
                # length check and the pop (e.g. by another invalidation thread draining it).
                break
            try:
                conn.close()
            except Exception:
                # Swallow this exception, current connection should be useless anyway
                pass

    def _invalidate_connections(self, connection_set: WeakSet):
        """Start a daemon thread that closes every connection in ``connection_set``."""
        invalidate_connection_thread: Thread = Thread(daemon=True, target=self._task,
                                                      args=[connection_set])  # type: ignore
        invalidate_connection_thread.start()

    @staticmethod
    def _format_connection_set(conn_set: WeakSet) -> str:
        """Render each connection of ``conn_set`` on its own indented line."""
        return "".join(f"\n\t\t{item}" for item in list(conn_set))

    def log_opened_connections(self):
        """Log every tracked endpoint together with the connections currently registered under it."""
        # Snapshot the items so a concurrent insertion of a new key cannot break the iteration.
        msg = "".join(f"\t[{key} : {self._format_connection_set(conn_set)}]"
                      for key, conn_set in list(self._opened_connections.items()))
        return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg)

    def _log_connection_set(self, host: Optional[str], conn_set: Optional[WeakSet]):
        """Log the connections about to be invalidated for ``host``. Does nothing when there are none."""
        if host is None or conn_set is None or len(conn_set) == 0:
            return
        msg = host + f"[{self._format_connection_set(conn_set)}\n]"
        logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg)
class AuroraConnectionTrackerPlugin(Plugin):
    """
    Registers each connection it opens with an :py:class:`OpenedConnectionTracker`. When the
    writer is observed to have changed, it closes the tracked connections to the previous writer.
    """

    _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
    # Writer observed before the current call; compared against the topology afterwards.
    _current_writer: Optional[HostInfo] = None
    # When True, _remember_writer re-reads the writer before the next call.
    # NOTE(review): nothing in this class sets it to True.
    _need_update_current_writer: bool = False
    _METHOD_CLOSE = "Connection.close"

    @property
    def subscribed_methods(self) -> Set[str]:
        """``connect``/``force_connect`` plus every network-bound method known to the plugin service."""
        return AuroraConnectionTrackerPlugin._SUBSCRIBED_METHODS.union(self._plugin_service.network_bound_methods)

    def __init__(self,
                 plugin_service: PluginService,
                 props: Properties,
                 rds_utils: RdsUtils = RdsUtils(),
                 tracker: OpenedConnectionTracker = OpenedConnectionTracker()):
        # NOTE: the default rds_utils/tracker objects are created once, when this function is
        # defined, and are shared by every plugin built without explicit arguments. The tracker's
        # registry is a class attribute of OpenedConnectionTracker, so it is shared regardless.
        self._plugin_service = plugin_service
        self._props = props
        self._rds_utils = rds_utils
        self._tracker = tracker

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """Open a connection via ``connect_func`` and start tracking it."""
        return self._connect(host_info, connect_func)

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        """Open a connection via ``force_connect_func`` and start tracking it."""
        return self._connect(host_info, force_connect_func)

    def _connect(self, host_info: HostInfo, connect_func: Callable):
        """
        Call ``connect_func`` and, if it returns a connection, register it with the tracker.

        For cluster endpoints the aliases of ``host_info`` are reset and re-filled by the plugin
        service from the live connection (presumably so the tracker can find the instance endpoint
        behind the cluster name; confirm against ``PluginService.fill_aliases``).

        :return: the connection returned by ``connect_func``.
        """
        conn = connect_func()

        if conn:
            url_type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host)
            if url_type.is_rds_cluster:
                host_info.reset_aliases()
                self._plugin_service.fill_aliases(conn, host_info)
            self._tracker.populate_opened_connection_set(host_info, conn)
            self._tracker.log_opened_connections()

        return conn

    def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Run ``execute_func`` and keep connection tracking up to date around it.

        - After ``Connection.close`` (with a known current host), stop tracking the current connection.
        - After any other call, re-check the writer if an update was requested.
        - On :py:class:`FailoverError`, re-check the writer (closing connections to a replaced writer)
          and re-raise the original error with its traceback intact.
        """
        self._remember_writer()

        try:
            results = execute_func()
            if method_name == AuroraConnectionTrackerPlugin._METHOD_CLOSE and self._plugin_service.current_host_info is not None:
                self._tracker.invalidate_current_connection(self._plugin_service.current_host_info, self._plugin_service.current_connection)
            elif self._need_update_current_writer:
                self._check_writer_changed()
            return results

        except Exception as e:
            # Check that e is a FailoverError and that the writer has changed
            if isinstance(e, FailoverError):
                self._check_writer_changed()
            # Bare raise re-raises the active exception without adding this frame to its traceback.
            raise

    def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
        """Return the first host whose role is ``HostRole.WRITER``, or ``None`` if there is none."""
        return next((host for host in hosts if host.role == HostRole.WRITER), None)

    def _remember_writer(self):
        """Record the current writer if none is known yet or an update was requested."""
        if self._current_writer is None or self._need_update_current_writer:
            self._current_writer = self._get_writer(self._plugin_service.hosts)
            self._need_update_current_writer = False

    def _check_writer_changed(self):
        """
        Compare the writer in the current topology with the remembered one. If a remembered writer
        has been replaced, close all tracked connections to it and remember the new writer.
        """
        host_info_after_failover = self._get_writer(self._plugin_service.hosts)

        if self._current_writer is None:
            self._current_writer = host_info_after_failover
            self._need_update_current_writer = False
        elif self._current_writer != host_info_after_failover:
            self._tracker.invalidate_all_connections(self._current_writer)
            self._tracker.log_opened_connections()
            self._current_writer = host_info_after_failover
            self._need_update_current_writer = False
class AuroraConnectionTrackerPluginFactory(PluginFactory):
    """Creates :py:class:`AuroraConnectionTrackerPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """
        Build a new connection-tracker plugin.

        :param plugin_service: plugin service handed to the new plugin.
        :param props: connection properties handed to the new plugin.
        :return: a fresh plugin that uses its default RDS utils and tracker.
        """
        new_plugin = AuroraConnectionTrackerPlugin(plugin_service=plugin_service, props=props)
        return new_plugin