Skip to content

Commit eeebefa

Browse files
authored
fix: support both types of cn endpoint patterns (#701)
1 parent d0bce5f commit eeebefa

9 files changed

+267
-96
lines changed

aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,12 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
5050
"""
5151

5252
aliases: FrozenSet[str] = host_info.as_aliases()
53-
host: str = host_info.as_alias()
5453

55-
if self._rds_utils.is_rds_instance(host):
56-
self._track_connection(host, conn)
54+
if self._rds_utils.is_rds_instance(host_info.host):
55+
self._track_connection(host_info.as_alias(), conn)
5756
return
5857

59-
instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(alias)),
58+
instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))),
6059
None)
6160
if not instance_endpoint:
6261
logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet")
@@ -82,7 +81,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
8281
return
8382

8483
for instance in host:
85-
if instance is not None and self._rds_utils.is_rds_instance(instance):
84+
if instance is not None and self._rds_utils.is_rds_instance(self._rds_utils.remove_port(instance)):
8685
instance_endpoint = instance
8786
break
8887

aws_advanced_python_wrapper/host_list_provider.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def _initialize(self):
199199
else:
200200
self._cluster_instance_template = HostInfo(
201201
host=self._rds_utils.get_rds_instance_host_pattern(self._initial_host_info.host),
202+
host_id=self._initial_host_info.host_id,
203+
port=self._initial_host_info.port,
202204
host_availability_strategy=host_availability_strategy)
203205
self._validate_host_pattern(self._cluster_instance_template.host)
204206

@@ -216,14 +218,15 @@ def _initialize(self):
216218
self._cluster_id = cluster_id_suggestion.cluster_id
217219
self._is_primary_cluster_id = cluster_id_suggestion.is_primary_cluster_id
218220
else:
219-
cluster_url = self._rds_utils.get_rds_cluster_host_url(self._initial_host_info.url)
221+
cluster_url = self._rds_utils.get_rds_cluster_host_url(self._initial_host_info.host)
220222
if cluster_url is not None:
221-
self._cluster_id = cluster_url
223+
self._cluster_id = f"{cluster_url}:{self._cluster_instance_template.port}" \
224+
if self._cluster_instance_template.is_port_specified() else cluster_url
222225
self._is_primary_cluster_id = True
223226
self._is_primary_cluster_id_cache.put(self._cluster_id, True,
224227
self._suggested_cluster_id_refresh_ns)
225228

226-
self._is_initialized = True
229+
self._is_initialized = True
227230

228231
def _validate_host_pattern(self, host: str):
229232
if not self._rds_utils.is_dns_pattern_valid(host):

aws_advanced_python_wrapper/sql_alchemy_connection_provider.py

+4
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ def __init__(
5757
self,
5858
pool_configurator: Optional[Callable] = None,
5959
pool_mapping: Optional[Callable] = None,
60+
accept_url_func: Optional[Callable] = None,
6061
pool_expiration_check_ns: int = -1,
6162
pool_cleanup_interval_ns: int = -1):
6263
self._pool_configurator = pool_configurator
6364
self._pool_mapping = pool_mapping
65+
self._accept_url_func = accept_url_func
6466

6567
if pool_expiration_check_ns > -1:
6668
SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS = pool_expiration_check_ns
@@ -80,6 +82,8 @@ def keys(self):
8082
return self._database_pools.keys()
8183

8284
def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
85+
if self._accept_url_func:
86+
return self._accept_url_func(host_info, props)
8387
url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host)
8488
return RdsUrlType.RDS_INSTANCE == url_type
8589

aws_advanced_python_wrapper/utils/rdsutils.py

+129-75
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from re import search, sub
16-
from typing import Optional
15+
from __future__ import annotations
16+
17+
from re import Match, search, sub
18+
from typing import Dict, Optional
1719

1820
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
1921

@@ -58,135 +60,156 @@ class RdsUtils:
5860
Example: test-postgres-instance-1.123456789012.rds.cn-northwest-1.amazonaws.com.cn
5961
"""
6062

61-
AURORA_DNS_PATTERN = r"(?P<instance>.+)\." \
62-
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-)?" \
63+
AURORA_DNS_PATTERN = r"^(?P<instance>.+)\." \
64+
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
6365
r"(?P<domain>[a-zA-Z0-9]+\." \
64-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
65-
AURORA_INSTANCE_PATTERN = r"(?P<instance>.+)\." \
66+
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
67+
AURORA_INSTANCE_PATTERN = r"^(?P<instance>.+)\." \
6668
r"(?P<domain>[a-zA-Z0-9]+\." \
67-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
68-
AURORA_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
69+
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
70+
AURORA_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
6971
r"(?P<dns>cluster-|cluster-ro-)+" \
7072
r"(?P<domain>[a-zA-Z0-9]+\." \
71-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
72-
AURORA_CUSTOM_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
73+
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
74+
AURORA_CUSTOM_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
7375
r"(?P<dns>cluster-custom-)+" \
7476
r"(?P<domain>[a-zA-Z0-9]+\." \
75-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
76-
AURORA_PROXY_DNS_PATTERN = r"(?P<instance>.+)\." \
77+
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
78+
AURORA_PROXY_DNS_PATTERN = r"^(?P<instance>.+)\." \
7779
r"(?P<dns>proxy-)+" \
7880
r"(?P<domain>[a-zA-Z0-9]+\." \
79-
r"(?P<region>[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
80-
AURORA_CHINA_DNS_PATTERN = r"(?P<instance>.+)\." \
81-
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-)?" \
81+
r"(?P<region>[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
82+
AURORA_OLD_CHINA_DNS_PATTERN = r"^(?P<instance>.+)\." \
83+
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
84+
r"(?P<domain>[a-zA-Z0-9]+\." \
85+
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$"
86+
AURORA_CHINA_DNS_PATTERN = r"^(?P<instance>.+)\." \
87+
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
8288
r"(?P<domain>[a-zA-Z0-9]+\." \
83-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
84-
AURORA_CHINA_INSTANCE_PATTERN = r"(?P<instance>.+)\." \
85-
r"(?P<domain>[a-zA-Z0-9]+\." \
86-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
87-
AURORA_CHINA_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
89+
r"rds\.(?P<region>[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$"
90+
AURORA_OLD_CHINA_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
91+
r"(?P<dns>cluster-|cluster-ro-)+" \
92+
r"(?P<domain>[a-zA-Z0-9]+\." \
93+
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$"
94+
AURORA_CHINA_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
8895
r"(?P<dns>cluster-|cluster-ro-)+" \
8996
r"(?P<domain>[a-zA-Z0-9]+\." \
90-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
91-
AURORA_CHINA_CUSTOM_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
92-
r"(?P<dns>cluster-custom-)+" \
93-
r"(?P<domain>[a-zA-Z0-9]+\." \
94-
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
95-
AURORA_CHINA_PROXY_DNS_PATTERN = r"(?P<instance>.+)\." \
96-
r"(?P<dns>proxy-)+" \
97-
r"(?P<domain>[a-zA-Z0-9]+\." \
98-
r"(?P<region>[a-zA-Z0-9\-])+\.rds\.amazonaws\.com\.cn)"
97+
r"rds\.(?P<region>[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$"
98+
AURORA_GOV_DNS_PATTERN = r"^(?P<instance>.+)\." \
99+
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
100+
r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \
101+
r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$"
102+
AURORA_GOV_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
103+
r"(?P<dns>cluster-|cluster-ro-)+" \
104+
r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \
105+
r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$"
106+
ELB_PATTERN = r"^(?<instance>.+)\.elb\.((?<region>[a-zA-Z0-9\-]+)\.amazonaws\.com)$"
99107

100108
IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \
101-
r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$"
102-
IP_V6 = r"^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}$"
103-
IP_V6_COMPRESSED = r"^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$"
109+
r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])"
110+
IP_V6 = r"^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}"
111+
IP_V6_COMPRESSED = r"^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)"
104112

105113
DNS_GROUP = "dns"
106114
DOMAIN_GROUP = "domain"
107115
INSTANCE_GROUP = "instance"
108116
REGION_GROUP = "region"
109117

118+
CACHE_DNS_PATTERNS: Dict[str, Match[str]] = {}
119+
CACHE_PATTERNS: Dict[str, str] = {}
120+
110121
def is_rds_cluster_dns(self, host: str) -> bool:
111-
return self._contains(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN])
122+
dns_group = self._get_dns_group(host)
123+
return dns_group is not None and dns_group.casefold() in ["cluster-", "cluster-ro-"]
112124

113125
def is_rds_custom_cluster_dns(self, host: str) -> bool:
114-
return self._contains(host, [self.AURORA_CUSTOM_CLUSTER_PATTERN, self.AURORA_CHINA_CUSTOM_CLUSTER_PATTERN])
126+
dns_group = self._get_dns_group(host)
127+
return dns_group is not None and dns_group.casefold() == "cluster-custom-"
115128

116129
def is_rds_dns(self, host: str) -> bool:
117-
return self._contains(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN])
130+
if not host or not host.strip():
131+
return False
132+
133+
pattern = self._find(host, [RdsUtils.AURORA_DNS_PATTERN,
134+
RdsUtils.AURORA_CHINA_DNS_PATTERN,
135+
RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN,
136+
RdsUtils.AURORA_GOV_DNS_PATTERN])
137+
group = self._get_regex_group(pattern, RdsUtils.DNS_GROUP)
138+
139+
if group:
140+
RdsUtils.CACHE_PATTERNS[host] = group
141+
142+
return pattern is not None
118143

119144
def is_rds_instance(self, host: str) -> bool:
120-
return (self._contains(host, [self.AURORA_INSTANCE_PATTERN, self.AURORA_CHINA_INSTANCE_PATTERN])
121-
and self.is_rds_dns(host))
145+
return self._get_dns_group(host) is None and self.is_rds_dns(host)
122146

123147
def is_rds_proxy_dns(self, host: str) -> bool:
124-
return self._contains(host, [self.AURORA_PROXY_DNS_PATTERN, self.AURORA_CHINA_PROXY_DNS_PATTERN])
148+
dns_group = self._get_dns_group(host)
149+
return dns_group is not None and dns_group.casefold() == "proxy-"
125150

126151
def get_rds_instance_host_pattern(self, host: str) -> str:
127152
if not host or not host.strip():
128153
return "?"
129154

130-
match = self._find(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN])
155+
match = self._get_group(host, RdsUtils.DOMAIN_GROUP)
131156
if match:
132-
return f"?.{match.group(self.DOMAIN_GROUP)}"
157+
return f"?.{match}"
133158

134159
return "?"
135160

136161
def get_rds_region(self, host: Optional[str]):
137162
if not host or not host.strip():
138163
return None
139164

140-
match = self._find(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN])
141-
if match:
142-
return match.group(self.REGION_GROUP)
165+
group = self._get_group(host, RdsUtils.REGION_GROUP)
166+
if group:
167+
return group
143168

169+
elb_matcher = search(RdsUtils.ELB_PATTERN, host)
170+
if elb_matcher:
171+
return elb_matcher.group(RdsUtils.REGION_GROUP)
144172
return None
145173

146174
def is_writer_cluster_dns(self, host: str) -> bool:
147-
if not host or not host.strip():
148-
return False
149-
150-
match = self._find(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN])
151-
if match:
152-
return "cluster-".casefold() == match.group(self.DNS_GROUP).casefold()
153-
154-
return False
175+
dns_group = self._get_dns_group(host)
176+
return dns_group is not None and dns_group.casefold() == "cluster-"
155177

156178
def is_reader_cluster_dns(self, host: str) -> bool:
157-
match = self._find(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN])
158-
if match:
159-
return "cluster-ro-".casefold() == match.group(self.DNS_GROUP).casefold()
160-
161-
return False
179+
dns_group = self._get_dns_group(host)
180+
return dns_group is not None and dns_group.casefold() == "cluster-ro-"
162181

163182
def get_rds_cluster_host_url(self, host: str):
164183
if not host or not host.strip():
165184
return None
166185

167-
if search(self.AURORA_CLUSTER_PATTERN, host):
168-
return sub(self.AURORA_CLUSTER_PATTERN, r"\g<instance>.cluster-\g<domain>", host)
169-
170-
if search(self.AURORA_CHINA_CLUSTER_PATTERN, host):
171-
return sub(self.AURORA_CHINA_CLUSTER_PATTERN, r"\g<instance>.cluster-\g<domain>", host)
186+
for pattern in [RdsUtils.AURORA_DNS_PATTERN,
187+
RdsUtils.AURORA_CHINA_DNS_PATTERN,
188+
RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN,
189+
RdsUtils.AURORA_GOV_DNS_PATTERN]:
190+
if m := search(pattern, host):
191+
group = self._get_regex_group(m, RdsUtils.DNS_GROUP)
192+
if group is not None:
193+
return sub(pattern, r"\g<instance>.cluster-\g<domain>", host)
194+
return None
172195

173196
return None
174197

175198
def get_instance_id(self, host: str) -> Optional[str]:
176-
if not host or not host.strip():
177-
return None
178-
179-
match = self._find(host, [self.AURORA_INSTANCE_PATTERN, self.AURORA_CHINA_INSTANCE_PATTERN])
180-
if match:
181-
return match.group(self.INSTANCE_GROUP)
199+
if self._get_dns_group(host) is None:
200+
return self._get_group(host, self.INSTANCE_GROUP)
182201

183202
return None
184203

185204
def is_ipv4(self, host: str) -> bool:
186-
return self._contains(host, [self.IP_V4])
205+
if host is None or not host.strip():
206+
return False
207+
return search(RdsUtils.IP_V4, host) is not None
187208

188209
def is_ipv6(self, host: str) -> bool:
189-
return self._contains(host, [self.IP_V6, self.IP_V6_COMPRESSED])
210+
if host is None or not host.strip():
211+
return False
212+
return search(RdsUtils.IP_V6_COMPRESSED, host) is not None or search(RdsUtils.IP_V6, host) is not None
190213

191214
def is_dns_pattern_valid(self, host: str) -> bool:
192215
return "?" in host
@@ -210,17 +233,48 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType:
210233

211234
return RdsUrlType.OTHER
212235

213-
def _contains(self, host: str, patterns: list) -> bool:
214-
if not host or not host.strip():
215-
return False
216-
217-
return len([pattern for pattern in patterns if search(pattern, host)]) > 0
218-
219236
def _find(self, host: str, patterns: list):
220237
if not host or not host.strip():
221238
return None
222239

223240
for pattern in patterns:
241+
match = RdsUtils.CACHE_DNS_PATTERNS.get(host)
242+
if match:
243+
return match
244+
224245
match = search(pattern, host)
225246
if match:
247+
RdsUtils.CACHE_DNS_PATTERNS[host] = match
226248
return match
249+
250+
return None
251+
252+
def _get_regex_group(self, pattern: Match[str], group_name: str):
253+
if pattern is None:
254+
return None
255+
return pattern.group(group_name)
256+
257+
def _get_group(self, host: str, group: str):
258+
if not host or not host.strip():
259+
return None
260+
261+
pattern = self._find(host, [RdsUtils.AURORA_DNS_PATTERN,
262+
RdsUtils.AURORA_CHINA_DNS_PATTERN,
263+
RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN,
264+
RdsUtils.AURORA_GOV_DNS_PATTERN])
265+
return self._get_regex_group(pattern, group)
266+
267+
def _get_dns_group(self, host: str):
268+
return self._get_group(host, RdsUtils.DNS_GROUP)
269+
270+
def remove_port(self, url: str):
271+
if not url or not url.strip():
272+
return None
273+
if ":" in url:
274+
return url.split(":")[0]
275+
return url
276+
277+
@staticmethod
278+
def clear_cache():
279+
RdsUtils.CACHE_PATTERNS.clear()
280+
RdsUtils.CACHE_DNS_PATTERNS.clear()

tests/integration/container/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider
2929
from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl
3030
from aws_advanced_python_wrapper.utils.log import Logger
31+
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
3132

3233
if TYPE_CHECKING:
3334
from .utils.test_driver import TestDriver
@@ -124,6 +125,7 @@ def pytest_runtest_setup(item):
124125

125126
assert cluster_ip == writer_ip
126127

128+
RdsUtils.clear_cache()
127129
RdsHostListProvider._topology_cache.clear()
128130
RdsHostListProvider._is_primary_cluster_id_cache.clear()
129131
RdsHostListProvider._cluster_ids_to_update.clear()

tests/integration/container/test_autoscaling.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def test_pooled_connection_auto_scaling__set_read_only_on_old_connection(
101101
provider = SqlAlchemyPooledConnectionProvider(
102102
lambda _, __: {"pool_size": original_cluster_size},
103103
None,
104+
None,
104105
120000000000, # 2 minutes
105106
180000000000) # 3 minutes
106107
ConnectionProviderManager.set_connection_provider(provider)
@@ -167,6 +168,7 @@ def test_pooled_connection_auto_scaling__failover_from_deleted_reader(
167168
provider = SqlAlchemyPooledConnectionProvider(
168169
lambda _, __: {"pool_size": len(instances) * 5},
169170
None,
171+
None,
170172
120000000000, # 2 minutes
171173
180000000000) # 3 minutes
172174
ConnectionProviderManager.set_connection_provider(provider)

0 commit comments

Comments
 (0)