2121
2222import pytest
2323
24+ from aws_advanced_python_wrapper .errors import AwsWrapperError
2425from aws_advanced_python_wrapper .hostinfo import HostInfo
2526from aws_advanced_python_wrapper .iam_plugin import IamAuthPlugin , TokenInfo
2627from aws_advanced_python_wrapper .utils .properties import (Properties ,
@@ -350,10 +351,17 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session
350351 mock_dialect .set_password .assert_called_with (expected_props , f"{ _TEST_TOKEN } :{ iam_region } " )
351352
352353
354+ @pytest .mark .parametrize ("iam_host" , [
355+ pytest .param ("foo.testdb.us-east-2.rds.amazonaws.com" ),
356+ pytest .param ("test.cluster-123456789012.us-east-2.rds.amazonaws.com" ),
357+ pytest .param ("test-.cluster-ro-123456789012.us-east-2.rds.amazonaws.com" ),
358+ pytest .param ("test.cluster-custom-123456789012.us-east-2.rds.amazonaws.com" ),
359+ pytest .param ("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.cn" ),
360+ pytest .param ("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.proxy" ),
361+ ])
353362@patch ("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache" , _token_cache )
354- def test_connect_with_specified_host (mocker , mock_plugin_service , mock_session , mock_func , mock_client , mock_dialect ):
363+ def test_connect_with_specified_host (iam_host : str , mocker , mock_plugin_service , mock_session , mock_func , mock_client , mock_dialect ):
355364 test_props : Properties = Properties ({"user" : "postgresqlUser" })
356- iam_host : str = "foo.testdb.us-east-2.rds.amazonaws.com"
357365
358366 test_props [WrapperProperties .IAM_HOST .name ] = iam_host
359367
@@ -365,7 +373,7 @@ def test_connect_with_specified_host(mocker, mock_plugin_service, mock_session,
365373 target_plugin .connect (
366374 target_driver_func = mocker .MagicMock (),
367375 driver_dialect = mock_dialect ,
368- host_info = HostInfo ("pg.testdb.us-east-2.rds.amazonaws .com" ),
376+ host_info = HostInfo ("bar.foo .com" ),
369377 props = test_props ,
370378 is_initial_connection = False ,
371379 connect_func = mock_func )
@@ -376,16 +384,38 @@ def test_connect_with_specified_host(mocker, mock_plugin_service, mock_session,
376384 DBUsername = "postgresqlUser"
377385 )
378386
379- actual_token = _token_cache .get ("us-east-2:foo.testdb.us-east-2.rds.amazonaws.com:5432:postgresqlUser" )
387+ actual_token = _token_cache .get (f"us-east-2:{ iam_host } :5432:postgresqlUser" )
388+ assert actual_token is not None
380389 assert _GENERATED_TOKEN != actual_token .token
381- assert f"{ _TEST_TOKEN } :foo.testdb.us-east-2.rds.amazonaws.com " == actual_token .token
390+ assert f"{ _TEST_TOKEN } :{ iam_host } " == actual_token .token
382391 assert actual_token .is_expired () is False
383392
384- # Assert password has been updated to the value in token cache
385- expected_props = {"iam_host" : "foo.testdb.us-east-2.rds.amazonaws.com" , "user" : "postgresqlUser" }
386- mock_dialect .set_password .assert_called_with (expected_props , f"{ _TEST_TOKEN } :foo.testdb.us-east-2.rds.amazonaws.com" )
387-
388393
389394def test_aws_supported_regions_url_exists ():
390395 url = "https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html"
391396 assert 200 == urllib .request .urlopen (url ).getcode ()
397+
398+
399+ @pytest .mark .parametrize ("host" , [
400+ pytest .param ("<>" ),
401+ pytest .param ("#" ),
402+ pytest .param ("'" ),
403+ pytest .param ("\" " ),
404+ pytest .param ("%" ),
405+ pytest .param ("^" ),
406+ pytest .param ("https://foo.com/abc.html" ),
407+ pytest .param ("foo.boo//" ),
408+ pytest .param ("8.8.8.8" ),
409+ pytest .param ("a.b" ),
410+ ])
411+ def test_invalid_iam_host (host , mocker , mock_plugin_service , mock_session , mock_func , mock_client , mock_dialect ):
412+ test_props : Properties = Properties ({"user" : "postgresqlUser" })
413+ with pytest .raises (AwsWrapperError ):
414+ target_plugin : IamAuthPlugin = IamAuthPlugin (mock_plugin_service , mock_session )
415+ target_plugin .connect (
416+ target_driver_func = mocker .MagicMock (),
417+ driver_dialect = mock_dialect ,
418+ host_info = HostInfo (host ),
419+ props = test_props ,
420+ is_initial_connection = False ,
421+ connect_func = mock_func )
0 commit comments