Skip to content

Commit 89b92e4

Browse files
authored
feat: configure target driver dialect before connection (#61)
1 parent c28a062 commit 89b92e4

22 files changed

+603
-185
lines changed

aws_wrapper/connection_provider.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414

1515
from __future__ import annotations
1616

17+
from logging import getLogger
1718
from threading import Lock
18-
from typing import TYPE_CHECKING
19+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Protocol
1920

2021
if TYPE_CHECKING:
2122
from aws_wrapper.hostinfo import HostInfo, HostRole
2223
from aws_wrapper.pep249 import Connection
23-
from aws_wrapper.utils.properties import Properties
24-
25-
from typing import Callable, Dict, List, Optional, Protocol
24+
from aws_wrapper.target_driver_dialect import TargetDriverDialect
2625

2726
from aws_wrapper.errors import AwsWrapperError
2827
from aws_wrapper.hostselector import HostSelector, RandomHostSelector
2928
from aws_wrapper.plugin import CanReleaseResources
3029
from aws_wrapper.utils.messages import Messages
30+
from aws_wrapper.utils.properties import Properties, PropertiesUtils
31+
32+
logger = getLogger(__name__)
3133

3234

3335
class ConnectionProvider(Protocol):
@@ -101,11 +103,11 @@ def release_resources():
101103

102104

103105
class DriverConnectionProvider(ConnectionProvider):
104-
105106
_accepted_strategies: Dict[str, HostSelector] = {"random": RandomHostSelector()}
106107

107-
def __init__(self, connect_func: Callable):
108+
def __init__(self, connect_func: Callable, target_driver_dialect: TargetDriverDialect):
108109
self._connect_func = connect_func
110+
self._target_driver_dialect = target_driver_dialect
109111

110112
def accepts_host_info(self, host_info: HostInfo, properties: Properties) -> bool:
111113
return True
@@ -122,12 +124,8 @@ def get_host_info_by_strategy(self, hosts: List[HostInfo], role: HostRole, strat
122124
return host_selector.get_host(hosts, role)
123125

124126
def connect(self, host_info: HostInfo, properties: Properties) -> Connection:
125-
# TODO: Behavior based on dialects
126-
prop_copy = properties.copy()
127-
128-
prop_copy["host"] = host_info.host
129-
130-
if host_info.is_port_specified():
131-
prop_copy["port"] = str(host_info.port)
127+
prepared_properties = self._target_driver_dialect.prepare_connect_info(host_info, properties)
128+
logger.debug(
129+
f"Connecting to {host_info.host} with properties: {PropertiesUtils.log_properties(prepared_properties)}")
132130

133-
return self._connect_func(**prop_copy)
131+
return self._connect_func(**prepared_properties)

aws_wrapper/dialect.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
WrapperProperties)
2828
from aws_wrapper.utils.rdsutils import RdsUtils
2929
from .exceptions import ExceptionHandler, PgExceptionHandler
30+
from .target_driver_dialect import TargetDriverDialectCodes
3031
from .utils.cache_map import CacheMap
3132
from .utils.messages import Messages
3233

@@ -55,7 +56,7 @@ def from_string(value: str) -> DialectCode:
5556
raise AwsWrapperError(Messages.get_formatted("DialectCode.InvalidStringValue", value))
5657

5758

58-
class DatabaseType(Enum):
59+
class TargetDriverType(Enum):
5960
MYSQL = auto()
6061
POSTGRES = auto()
6162
MARIADB = auto()
@@ -121,7 +122,7 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
121122

122123

123124
class DialectProvider(Protocol):
124-
def get_dialect(self, props: Properties) -> Optional[Dialect]:
125+
def get_dialect(self, driver_dialect: str, props: Properties) -> Optional[Dialect]:
125126
"""
126127
Returns the dialect identified by analyzing the AwsWrapperProperties.DIALECT property (if set) or the target
127128
driver method
@@ -470,7 +471,7 @@ def reset_custom_dialect(self):
470471
def reset_endpoint_cache(self):
471472
self._known_endpoint_dialects.clear()
472473

473-
def get_dialect(self, props: Properties) -> Optional[Dialect]:
474+
def get_dialect(self, driver_dialect: str, props: Properties) -> Optional[Dialect]:
474475
self._can_update = False
475476
self._dialect = None
476477

@@ -499,8 +500,8 @@ def get_dialect(self, props: Properties) -> Optional[Dialect]:
499500
raise AwsWrapperError(Messages.get_formatted("Dialect.UnknownDialectCode", str(dialect_code)))
500501

501502
host: str = props["host"]
502-
database_type: DatabaseType = self._get_database_type()
503-
if database_type is DatabaseType.MYSQL:
503+
target_driver_type: TargetDriverType = self._get_target_driver_type(driver_dialect)
504+
if target_driver_type is TargetDriverType.MYSQL:
504505
rds_type = self._rds_helper.identify_rds_type(host)
505506
if rds_type.is_rds_cluster:
506507
self._dialect_code = DialectCode.AURORA_MYSQL
@@ -518,7 +519,7 @@ def get_dialect(self, props: Properties) -> Optional[Dialect]:
518519
self._log_current_dialect()
519520
return self._dialect
520521

521-
if database_type is DatabaseType.POSTGRES:
522+
if target_driver_type is TargetDriverType.POSTGRES:
522523
rds_type = self._rds_helper.identify_rds_type(host)
523524
if rds_type.is_rds_cluster:
524525
self._dialect_code = DialectCode.AURORA_PG
@@ -536,7 +537,14 @@ def get_dialect(self, props: Properties) -> Optional[Dialect]:
536537
self._log_current_dialect()
537538
return self._dialect
538539

539-
if database_type is DatabaseType.MARIADB:
540+
if target_driver_type is TargetDriverType.MARIADB:
541+
rds_type = self._rds_helper.identify_rds_type(host)
542+
if rds_type.is_rds_cluster:
543+
# Aurora MariaDB doesn't exist.
544+
# If this is a cluster endpoint then user is trying to connect to AMS via the MariaDB driver.
545+
self._dialect_code = DialectCode.AURORA_MYSQL
546+
self._dialect = self._known_dialects_by_code.get(DialectCode.AURORA_MYSQL)
547+
return self._dialect
540548
self._can_update = True
541549
self._dialect_code = DialectCode.MARIADB
542550
self._dialect = self._known_dialects_by_code.get(DialectCode.MARIADB)
@@ -549,9 +557,15 @@ def get_dialect(self, props: Properties) -> Optional[Dialect]:
549557
self._log_current_dialect()
550558
return self._dialect
551559

552-
def _get_database_type(self) -> DatabaseType:
553-
# TODO: Add logic to identify database based on target driver connect info
554-
return DatabaseType.POSTGRES
560+
def _get_target_driver_type(self, driver_dialect: str) -> TargetDriverType:
561+
if driver_dialect == TargetDriverDialectCodes.PSYCOPG:
562+
return TargetDriverType.POSTGRES
563+
if driver_dialect == TargetDriverDialectCodes.MYSQL_CONNECTOR_PYTHON:
564+
return TargetDriverType.MYSQL
565+
if driver_dialect == TargetDriverDialectCodes.MARIADB_CONNECTOR_PYTHON:
566+
return TargetDriverType.MARIADB
567+
568+
return TargetDriverType.CUSTOM
555569

556570
def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Connection) -> Optional[Dialect]:
557571
if not self._can_update:

aws_wrapper/failover_plugin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,14 @@
4141
from aws_wrapper.utils.properties import Properties, WrapperProperties
4242
from aws_wrapper.utils.rds_url_type import RdsUrlType
4343
from aws_wrapper.utils.rdsutils import RdsUtils
44-
from aws_wrapper.utils.utils import SubscribedMethodUtils
4544
from aws_wrapper.writer_failover_handler import (WriterFailoverHandler,
4645
WriterFailoverHandlerImpl)
4746

4847
logger = getLogger(__name__)
4948

5049

5150
class FailoverPlugin(Plugin):
52-
_SUBSCRIBED_METHODS: Set[str] = {*SubscribedMethodUtils.NETWORK_BOUND_METHODS,
53-
"init_host_provider",
51+
_SUBSCRIBED_METHODS: Set[str] = {"init_host_provider",
5452
"connect",
5553
"force_connect",
5654
"notify_host_list_changed"}
@@ -69,6 +67,7 @@ def __init__(self, plugin_service: PluginService, props: Properties):
6967
self._last_exception: Optional[Exception] = None
7068
self._rds_utils = RdsUtils()
7169
self._rds_url_type: RdsUrlType = self._rds_utils.identify_rds_type(self._properties.get("host"))
70+
FailoverPlugin._SUBSCRIBED_METHODS.add(*self._plugin_service.network_bound_methods)
7271

7372
def init_host_provider(
7473
self,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import ABC, abstractmethod
18+
from typing import TYPE_CHECKING, Callable, Set
19+
20+
from aws_wrapper.target_driver_dialect_codes import TargetDriverDialectCodes
21+
from aws_wrapper.utils.properties import Properties, PropertiesUtils
22+
23+
if TYPE_CHECKING:
24+
from aws_wrapper.hostinfo import HostInfo
25+
26+
27+
class TargetDriverDialect(ABC):
28+
_dialect_code: str = TargetDriverDialectCodes.GENERIC
29+
_network_bound_methods: Set[str] = {"*"}
30+
31+
@property
32+
def dialect_code(self) -> str:
33+
return self._dialect_code
34+
35+
@property
36+
def network_bound_methods(self) -> Set[str]:
37+
return self._network_bound_methods
38+
39+
@abstractmethod
40+
def is_dialect(self, conn: Callable) -> bool:
41+
pass
42+
43+
@abstractmethod
44+
def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties:
45+
pass
46+
47+
48+
class GenericTargetDriverDialect(TargetDriverDialect):
49+
50+
def is_dialect(self, conn: Callable) -> bool:
51+
return True
52+
53+
def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties:
54+
prop_copy: Properties = Properties(props.copy())
55+
56+
prop_copy["host"] = host_info.host
57+
58+
if host_info.is_port_specified():
59+
prop_copy["port"] = str(host_info.port)
60+
61+
PropertiesUtils.remove_wrapper_props(prop_copy)
62+
return props

aws_wrapper/host_monitoring_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from aws_wrapper.utils.properties import Properties, WrapperProperties
4141
from aws_wrapper.utils.rdsutils import RdsUtils
4242
from aws_wrapper.utils.timeout import timeout
43-
from aws_wrapper.utils.utils import QueueUtils, SubscribedMethodUtils
43+
from aws_wrapper.utils.utils import QueueUtils
4444

4545
logger = getLogger(__name__)
4646

@@ -92,7 +92,7 @@ def execute(self, target: object, method_name: str, execute_func: Callable, *arg
9292
raise AwsWrapperError(Messages.get_formatted("HostMonitoringPlugin.NullHostInfoForMethod", method_name))
9393

9494
is_enabled = WrapperProperties.FAILURE_DETECTION_ENABLED.get_bool(self._props)
95-
if not is_enabled or method_name not in SubscribedMethodUtils.NETWORK_BOUND_METHODS:
95+
if not is_enabled or not self._plugin_service.is_network_bound_method(method_name):
9696
return execute_func()
9797

9898
failure_detection_time_ms = WrapperProperties.FAILURE_DETECTION_TIME_MS.get_int(self._props)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from inspect import signature
18+
from typing import Callable, Set
19+
20+
from aws_wrapper.generic_target_driver_dialect import \
21+
GenericTargetDriverDialect
22+
from aws_wrapper.target_driver_dialect_codes import TargetDriverDialectCodes
23+
24+
25+
class MariaDBTargetDriverDialect(GenericTargetDriverDialect):
26+
TARGET_DRIVER = "Mariadb"
27+
28+
_dialect_code: str = TargetDriverDialectCodes.MARIADB_CONNECTOR_PYTHON
29+
_network_bound_methods: Set[str] = {
30+
"Connection.commit()",
31+
"Connection.rollback()",
32+
"Connection.cursor()",
33+
"Cursor.callproc()",
34+
"Cursor.execute()",
35+
"Cursor.fetchone()",
36+
"Cursor.fetchmany()",
37+
"Cursor.fetchall()",
38+
"Cursor.nextset()",
39+
}
40+
41+
def is_dialect(self, conn: Callable) -> bool:
42+
return MariaDBTargetDriverDialect.TARGET_DRIVER in str(signature(conn))
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from inspect import signature
18+
from typing import Callable, Set
19+
20+
from aws_wrapper.generic_target_driver_dialect import \
21+
GenericTargetDriverDialect
22+
from aws_wrapper.target_driver_dialect_codes import TargetDriverDialectCodes
23+
24+
25+
class MySQLTargetDriverDialect(GenericTargetDriverDialect):
26+
TARGET_DRIVER = "MySQL"
27+
28+
_dialect_code: str = TargetDriverDialectCodes.MYSQL_CONNECTOR_PYTHON
29+
_network_bound_methods: Set[str] = {
30+
"Connection.commit()",
31+
"Connection.rollback()",
32+
"Connection.cursor()",
33+
"Cursor.close()",
34+
"Cursor.execute()",
35+
"Cursor.fetchone()",
36+
"Cursor.fetchmany()",
37+
"Cursor.fetchall()",
38+
}
39+
40+
def is_dialect(self, conn: Callable) -> bool:
41+
return MySQLTargetDriverDialect.TARGET_DRIVER in str(signature(conn))
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from inspect import signature
18+
from typing import Callable, Set
19+
20+
from aws_wrapper.target_driver_dialect import GenericTargetDriverDialect
21+
from aws_wrapper.target_driver_dialect_codes import TargetDriverDialectCodes
22+
23+
24+
class PgTargetDriverDialect(GenericTargetDriverDialect):
25+
TARGET_DRIVER = "psycopg"
26+
27+
_dialect_code: str = TargetDriverDialectCodes.PSYCOPG
28+
_network_bound_methods: Set[str] = {
29+
"Connection.commit()",
30+
"Connection.rollback()",
31+
"Connection.cursor()",
32+
"Cursor.callproc()",
33+
"Cursor.execute()",
34+
"Cursor.fetchone()",
35+
"Cursor.fetchmany()",
36+
"Cursor.fetchall()",
37+
}
38+
39+
def is_dialect(self, conn: Callable) -> bool:
40+
return PgTargetDriverDialect.TARGET_DRIVER in str(signature(conn))

0 commit comments

Comments
 (0)