Skip to content

Commit 91d90b9

Browse files
authored
chore: check for cached token and exception type before retrying (#902)
1 parent 400f826 commit 91d90b9

File tree

4 files changed

+71
-4
lines changed

4 files changed

+71
-4
lines changed

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
9898
region
9999
)
100100

101-
token_info = FederatedAuthPlugin._token_cache.get(cache_key)
101+
token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)
102102

103103
if token_info is not None and not token_info.is_expired():
104104
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
@@ -110,7 +110,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
110110

111111
try:
112112
return connect_func()
113-
except Exception:
113+
except Exception as e:
114+
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
115+
raise e
116+
114117
self._update_authentication_token(host_info, props, user, region, cache_key)
115118

116119
try:

aws_advanced_python_wrapper/okta_plugin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
9494
region
9595
)
9696

97-
token_info = OktaAuthPlugin._token_cache.get(cache_key)
97+
token_info: Optional[TokenInfo] = OktaAuthPlugin._token_cache.get(cache_key)
9898

9999
if token_info is not None and not token_info.is_expired():
100100
logger.debug("OktaAuthPlugin.UseCachedToken", token_info.token)
@@ -106,7 +106,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
106106

107107
try:
108108
return connect_func()
109-
except Exception:
109+
except Exception as e:
110+
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
111+
raise e
112+
110113
self._update_authentication_token(host_info, props, user, region, cache_key)
111114

112115
try:

tests/unit/test_federated_auth_plugin.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,37 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
173173
assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN
174174

175175

176+
@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
177+
def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect,
178+
mock_credentials_provider_factory):
179+
test_props: Properties = Properties(
180+
{"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
181+
WrapperProperties.DB_USER.set(test_props, _DB_USER)
182+
183+
exception_message = "generic exception"
184+
mock_func.side_effect = Exception(exception_message)
185+
186+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory,
187+
mock_session)
188+
with pytest.raises(Exception) as e_info:
189+
target_plugin.connect(
190+
target_driver_func=mocker.MagicMock(),
191+
driver_dialect=mock_dialect,
192+
host_info=_PG_HOST_INFO,
193+
props=test_props,
194+
is_initial_connection=False,
195+
connect_func=mock_func)
196+
197+
mock_client.generate_db_auth_token.assert_called_with(
198+
DBHostname="pg.testdb.us-east-2.rds.amazonaws.com",
199+
Port=5432,
200+
DBUsername="postgresqlUser"
201+
)
202+
203+
assert e_info.type == Exception
204+
assert str(e_info.value) == exception_message
205+
206+
176207
@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
177208
def test_connect_with_specified_iam_host_port_region(mocker,
178209
mock_plugin_service,

tests/unit/test_okta_plugin.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,36 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
170170
assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN
171171

172172

173+
@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache)
174+
def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client,
175+
mock_dialect, mock_credentials_provider_factory):
176+
test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
177+
WrapperProperties.DB_USER.set(test_props, _DB_USER)
178+
179+
exception_message = "generic exception"
180+
mock_func.side_effect = Exception(exception_message)
181+
182+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
183+
184+
with pytest.raises(Exception) as e_info:
185+
target_plugin.connect(
186+
target_driver_func=mocker.MagicMock(),
187+
driver_dialect=mock_dialect,
188+
host_info=_PG_HOST_INFO,
189+
props=test_props,
190+
is_initial_connection=False,
191+
connect_func=mock_func)
192+
193+
mock_client.generate_db_auth_token.assert_called_with(
194+
DBHostname="pg.testdb.us-east-2.rds.amazonaws.com",
195+
Port=5432,
196+
DBUsername="postgresqlUser"
197+
)
198+
199+
assert e_info.type == Exception
200+
assert str(e_info.value) == exception_message
201+
202+
173203
@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache)
174204
def test_connect_with_specified_iam_host_port_region(mocker,
175205
mock_plugin_service,

0 commit comments

Comments
 (0)