Skip to content

Commit 6dc9877

Browse files
chore: improve tests with mocks and coverage (GoogleCloudPlatform#238)
* chore: update refresh_utils tests to use mocks * chore: add basic tests for Connector
1 parent 8eaae43 commit 6dc9877

File tree

5 files changed

+175
-34
lines changed

5 files changed

+175
-34
lines changed

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def lint(session):
3939
"sqlalchemy-stubs==0.4",
4040
"types-pkg-resources==0.1.3",
4141
"types-PyMySQL==1.0.6",
42+
"types-mock==4.0.5",
4243
"twine==3.7.1",
4344
)
4445
session.install("-r", "requirements.txt")

tests/system/test_connector_object.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
""""
2+
Copyright 2021 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import os
17+
import pymysql
18+
import sqlalchemy
19+
import logging
20+
import google.auth
21+
from google.cloud.sql.connector import connector
22+
23+
24+
def init_connection_engine(
25+
custom_connector: connector.Connector,
26+
) -> sqlalchemy.engine.Engine:
27+
def getconn() -> pymysql.connections.Connection:
28+
conn = custom_connector.connect(
29+
os.environ["MYSQL_CONNECTION_NAME"],
30+
"pymysql",
31+
user=os.environ["MYSQL_USER"],
32+
password=os.environ["MYSQL_PASS"],
33+
db=os.environ["MYSQL_DB"],
34+
)
35+
return conn
36+
37+
engine = sqlalchemy.create_engine(
38+
"mysql+pymysql://",
39+
creator=getconn,
40+
)
41+
return engine
42+
43+
44+
def test_connector_with_credentials() -> None:
45+
"""Test Connector object connection with credentials loaded from file."""
46+
credentials, project = google.auth.load_credentials_from_file(
47+
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
48+
)
49+
custom_connector = connector.Connector(credentials=credentials)
50+
try:
51+
pool = init_connection_engine(custom_connector)
52+
53+
with pool.connect() as conn:
54+
conn.execute("SELECT 1")
55+
56+
except Exception as e:
57+
logging.exception("Failed to connect with credentials from file!", e)
58+
59+
60+
def test_multiple_connectors() -> None:
61+
"""Test that same Cloud SQL instance can connect with two Connector objects."""
62+
first_connector = connector.Connector()
63+
second_connector = connector.Connector()
64+
try:
65+
pool = init_connection_engine(first_connector)
66+
pool2 = init_connection_engine(second_connector)
67+
68+
with pool.connect() as conn:
69+
conn.execute("SELECT 1")
70+
71+
with pool2.connect() as conn:
72+
conn.execute("SELECT 1")
73+
74+
instance_connection_string = os.environ["MYSQL_CONNECTION_NAME"]
75+
assert instance_connection_string in first_connector._instances
76+
assert instance_connection_string in second_connector._instances
77+
assert (
78+
first_connector._instances[instance_connection_string]
79+
!= second_connector._instances[instance_connection_string]
80+
)
81+
except Exception as e:
82+
logging.exception("Failed to connect with multiple Connector objects!", e)

tests/unit/test_connector.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest # noqa F401 Needed to run the tests
1818
from google.cloud.sql.connector.instance_connection_manager import (
19+
IPTypes,
1920
InstanceConnectionManager,
2021
)
2122
from google.cloud.sql.connector import connector
@@ -53,3 +54,12 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None:
5354
"pymysql",
5455
timeout=timeout,
5556
)
57+
58+
59+
def test_default_Connector_Init() -> None:
60+
"""Test that default Connector __init__ sets properties properly."""
61+
default_connector = connector.Connector()
62+
assert default_connector._ip_type == IPTypes.PUBLIC
63+
assert default_connector._enable_iam_auth is False
64+
assert default_connector._timeout == 30
65+
assert default_connector._credentials is None

tests/unit/test_instance_connection_manager.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,18 @@
2929

3030

3131
@pytest.fixture
32-
def mock_credentials() -> Credentials:
32+
def credentials() -> Credentials:
3333
return Mock(spec=Credentials)
3434

3535

3636
@pytest.fixture
3737
def icm(
38-
async_loop: asyncio.AbstractEventLoop, connect_string: str
38+
async_loop: asyncio.AbstractEventLoop,
3939
) -> InstanceConnectionManager:
4040
keys = asyncio.run_coroutine_threadsafe(generate_keys(), async_loop)
41-
icm = InstanceConnectionManager(connect_string, "pymysql", keys, async_loop)
42-
41+
icm = InstanceConnectionManager(
42+
"my-project:my-region:my-instance", "pymysql", keys, async_loop
43+
)
4344
return icm
4445

4546

@@ -155,6 +156,12 @@ async def test_perform_refresh_replaces_invalid_result(
155156
# allow more frequent refreshes for tests
156157
setattr(icm, "_refresh_rate_limiter", test_rate_limiter)
157158

159+
# set current to valid MockMetadata instance
160+
setattr(icm, "_get_instance_data", _get_metadata_success)
161+
icm._current = asyncio.run_coroutine_threadsafe(
162+
icm._perform_refresh(), icm._loop
163+
).result(timeout=10)
164+
158165
# stub _get_instance_data to throw an error
159166
setattr(icm, "_get_instance_data", _get_metadata_error)
160167
icm._current = asyncio.run_coroutine_threadsafe(
@@ -197,7 +204,7 @@ async def test_force_refresh_cancels_pending_refresh(
197204

198205

199206
def test_auth_init_with_credentials_object(
200-
icm: InstanceConnectionManager, mock_credentials: Credentials
207+
icm: InstanceConnectionManager, credentials: Credentials
201208
) -> None:
202209
"""
203210
Test that InstanceConnectionManager's _auth_init initializes _credentials
@@ -207,22 +214,22 @@ def test_auth_init_with_credentials_object(
207214
with patch(
208215
"google.cloud.sql.connector.instance_connection_manager.with_scopes_if_required"
209216
) as mock_auth:
210-
mock_auth.return_value = mock_credentials
211-
icm._auth_init(credentials=mock_credentials)
217+
mock_auth.return_value = credentials
218+
icm._auth_init(credentials=credentials)
212219
assert isinstance(icm._credentials, Credentials)
213220
mock_auth.assert_called_once()
214221

215222

216223
def test_auth_init_with_default_credentials(
217-
icm: InstanceConnectionManager, mock_credentials: Credentials
224+
icm: InstanceConnectionManager, credentials: Credentials
218225
) -> None:
219226
"""
220227
Test that InstanceConnectionManager's _auth_init initializes _credentials
221228
with application default credentials when credentials are not specified.
222229
"""
223230
setattr(icm, "_credentials", None)
224231
with patch("google.auth.default") as mock_auth:
225-
mock_auth.return_value = mock_credentials, None
232+
mock_auth.return_value = credentials, None
226233
icm._auth_init(credentials=None)
227234
assert isinstance(icm._credentials, Credentials)
228235
mock_auth.assert_called_once()

tests/unit/test_refresh_utils.py

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,75 @@
1616
from typing import Any
1717

1818
import aiohttp
19-
import google.auth
19+
from google.auth.credentials import Credentials
20+
import json
2021
import pytest # noqa F401 Needed to run the tests
22+
from mock import AsyncMock, Mock, patch
2123

2224
from google.cloud.sql.connector.refresh_utils import _get_ephemeral, _get_metadata
2325
from google.cloud.sql.connector.utils import generate_keys
2426

2527

28+
class FakeClientSessionGet:
29+
"""Helper class to return mock data for get request."""
30+
31+
async def text(self) -> str:
32+
response = {
33+
"kind": "sql#connectSettings",
34+
"serverCaCert": {
35+
"kind": "sql#sslCert",
36+
"certSerialNumber": "0",
37+
"cert": "-----BEGIN CERTIFICATE-----\nabc123\n-----END CERTIFICATE-----",
38+
"commonName": "Google",
39+
"sha1Fingerprint": "abc",
40+
"instance": "my-instance",
41+
"createTime": "2021-10-18T18:48:03.785Z",
42+
"expirationTime": "2031-10-16T18:49:03.785Z",
43+
},
44+
"ipAddresses": [
45+
{"type": "PRIMARY", "ipAddress": "0.0.0.0"},
46+
{"type": "PRIVATE", "ipAddress": "1.0.0.0"},
47+
],
48+
"region": "my-region",
49+
"databaseVersion": "MYSQL_8_0",
50+
"backendType": "SECOND_GEN",
51+
}
52+
return json.dumps(response)
53+
54+
55+
class FakeClientSessionPost:
56+
"""Helper class to return mock data for post request."""
57+
58+
async def text(self) -> str:
59+
response = {
60+
"ephemeralCert": {
61+
"kind": "sql#sslCert",
62+
"certSerialNumber": "",
63+
"cert": "-----BEGIN CERTIFICATE-----\nabc123\n-----END CERTIFICATE-----",
64+
}
65+
}
66+
return json.dumps(response)
67+
68+
69+
@pytest.fixture
70+
def credentials() -> Credentials:
71+
credentials = Mock(spec=Credentials)
72+
credentials.valid = True
73+
credentials.token = "12345"
74+
return credentials
75+
76+
2677
@pytest.mark.asyncio
27-
async def test_get_ephemeral(connect_string: str) -> None:
78+
@patch("aiohttp.ClientSession.post", new_callable=AsyncMock)
79+
async def test_get_ephemeral(mock_post: AsyncMock, credentials: Credentials) -> None:
2880
"""
29-
Test to check whether _get_ephemeral runs without problems given a valid
30-
connection string.
81+
Test to check whether _get_ephemeral runs without problems given valid
82+
parameters.
3183
"""
84+
mock_post.return_value = FakeClientSessionPost()
3285

33-
project = connect_string.split(":")[0]
34-
instance = connect_string.split(":")[2]
35-
36-
credentials, project = google.auth.default(
37-
scopes=[
38-
"https://www.googleapis.com/auth/sqlservice.admin",
39-
"https://www.googleapis.com/auth/cloud-platform",
40-
]
41-
)
86+
project = "my-project"
87+
instance = "my-instance"
4288

4389
_, pub_key = await generate_keys()
4490

@@ -56,21 +102,16 @@ async def test_get_ephemeral(connect_string: str) -> None:
56102

57103

58104
@pytest.mark.asyncio
59-
async def test_get_metadata(connect_string: str) -> None:
105+
@patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
106+
async def test_get_metadata(mock_get: AsyncMock, credentials: Credentials) -> None:
60107
"""
61-
Test to check whether _get_metadata runs without problems given a valid
62-
connection string.
108+
Test to check whether _get_metadata runs without problems given valid
109+
parameters.
63110
"""
111+
mock_get.return_value = FakeClientSessionGet()
64112

65-
project = connect_string.split(":")[0]
66-
instance = connect_string.split(":")[2]
67-
68-
credentials, project = google.auth.default(
69-
scopes=[
70-
"https://www.googleapis.com/auth/sqlservice.admin",
71-
"https://www.googleapis.com/auth/cloud-platform",
72-
]
73-
)
113+
project = "my-project"
114+
instance = "my-instance"
74115

75116
async with aiohttp.ClientSession() as client_session:
76117
result = await _get_metadata(client_session, credentials, project, instance)

0 commit comments

Comments
 (0)