Skip to content

Commit 432c1e3

Browse files
committed
test: Fix pymysql tests to make database content predicatable
Signed-off-by: Ferenc Géczi <[email protected]>
1 parent f0eed1e commit 432c1e3

File tree

1 file changed

+43
-53
lines changed

1 file changed

+43
-53
lines changed

tests/clients/test_pymysql.py

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,56 +12,44 @@
1212

1313
logger = logging.getLogger(__name__)
1414

15-
create_table_query = 'CREATE TABLE IF NOT EXISTS users(id serial primary key, \
16-
name varchar(40) NOT NULL, email varchar(40) NOT NULL)'
17-
18-
create_proc_query = """
19-
CREATE PROCEDURE test_proc(IN t VARCHAR(255))
20-
BEGIN
21-
SELECT name FROM users WHERE name = t;
22-
END
23-
"""
24-
25-
db = pymysql.connect(host=testenv['mysql_host'], port=testenv['mysql_port'],
26-
user=testenv['mysql_user'], passwd=testenv['mysql_pw'],
27-
db=testenv['mysql_db'])
28-
29-
cursor = db.cursor()
30-
cursor.execute(create_table_query)
31-
32-
while cursor.nextset() is not None:
33-
pass
34-
35-
cursor.execute('DROP PROCEDURE IF EXISTS test_proc')
36-
37-
while cursor.nextset() is not None:
38-
pass
39-
40-
cursor.execute(create_proc_query)
41-
42-
while cursor.nextset() is not None:
43-
pass
44-
45-
cursor.close()
46-
db.close()
47-
4815

4916
class TestPyMySQL(unittest.TestCase):
5017
def setUp(self):
5118
self.db = pymysql.connect(host=testenv['mysql_host'], port=testenv['mysql_port'],
5219
user=testenv['mysql_user'], passwd=testenv['mysql_pw'],
5320
db=testenv['mysql_db'])
21+
database_setup_query = """
22+
DROP TABLE IF EXISTS users; |
23+
CREATE TABLE users(
24+
id serial primary key,
25+
name varchar(40) NOT NULL,
26+
email varchar(40) NOT NULL
27+
); |
28+
INSERT INTO users(name, email) VALUES('kermit', '[email protected]'); |
29+
DROP PROCEDURE IF EXISTS test_proc; |
30+
CREATE PROCEDURE test_proc(IN t VARCHAR(255))
31+
BEGIN
32+
SELECT name FROM users WHERE name = t;
33+
END
34+
"""
35+
setup_cursor = self.db.cursor()
36+
for s in database_setup_query.split('|'):
37+
setup_cursor.execute(s)
38+
5439
self.cursor = self.db.cursor()
5540
self.recorder = tracer.recorder
5641
self.recorder.clear_spans()
5742
tracer.cur_ctx = None
5843

5944
def tearDown(self):
60-
""" Do nothing for now """
61-
return None
45+
if self.cursor and self.cursor.connection.open:
46+
self.cursor.close()
47+
if self.db and self.db.open:
48+
self.db.close()
6249

6350
def test_vanilla_query(self):
64-
self.cursor.execute("""SELECT * from users""")
51+
affected_rows = self.cursor.execute("""SELECT * from users""")
52+
self.assertEqual(1, affected_rows)
6553
result = self.cursor.fetchone()
6654
self.assertEqual(3, len(result))
6755

@@ -70,10 +58,11 @@ def test_vanilla_query(self):
7058

7159
def test_basic_query(self):
7260
with tracer.start_active_span('test'):
73-
result = self.cursor.execute("""SELECT * from users""")
74-
self.cursor.fetchone()
61+
affected_rows = self.cursor.execute("""SELECT * from users""")
62+
result = self.cursor.fetchone()
7563

76-
self.assertTrue(result >= 0)
64+
self.assertEqual(1, affected_rows)
65+
self.assertEqual(3, len(result))
7766

7867
spans = self.recorder.queued_spans()
7968
self.assertEqual(2, len(spans))
@@ -95,10 +84,11 @@ def test_basic_query(self):
9584

9685
def test_query_with_params(self):
9786
with tracer.start_active_span('test'):
98-
result = self.cursor.execute("""SELECT * from users where id=1""")
99-
self.cursor.fetchone()
87+
affected_rows = self.cursor.execute("""SELECT * from users where id=1""")
88+
result = self.cursor.fetchone()
10089

101-
self.assertTrue(result >= 0)
90+
self.assertEqual(1, affected_rows)
91+
self.assertEqual(3, len(result))
10292

10393
spans = self.recorder.queued_spans()
10494
self.assertEqual(2, len(spans))
@@ -120,11 +110,11 @@ def test_query_with_params(self):
120110

121111
def test_basic_insert(self):
122112
with tracer.start_active_span('test'):
123-
result = self.cursor.execute(
113+
affected_rows = self.cursor.execute(
124114
"""INSERT INTO users(name, email) VALUES(%s, %s)""",
125115
('beaker', '[email protected]'))
126116

127-
self.assertEqual(1, result)
117+
self.assertEqual(1, affected_rows)
128118

129119
spans = self.recorder.queued_spans()
130120
self.assertEqual(2, len(spans))
@@ -146,11 +136,11 @@ def test_basic_insert(self):
146136

147137
def test_executemany(self):
148138
with tracer.start_active_span('test'):
149-
result = self.cursor.executemany("INSERT INTO users(name, email) VALUES(%s, %s)",
139+
affected_rows = self.cursor.executemany("INSERT INTO users(name, email) VALUES(%s, %s)",
150140
[('beaker', '[email protected]'), ('beaker', '[email protected]')])
151141
self.db.commit()
152142

153-
self.assertEqual(2, result)
143+
self.assertEqual(2, affected_rows)
154144

155145
spans = self.recorder.queued_spans()
156146
self.assertEqual(2, len(spans))
@@ -172,9 +162,9 @@ def test_executemany(self):
172162

173163
def test_call_proc(self):
174164
with tracer.start_active_span('test'):
175-
result = self.cursor.callproc('test_proc', ('beaker',))
165+
callproc_result = self.cursor.callproc('test_proc', ('beaker',))
176166

177-
self.assertTrue(result)
167+
self.assertIsInstance(callproc_result, tuple)
178168

179169
spans = self.recorder.queued_spans()
180170
self.assertEqual(2, len(spans))
@@ -195,15 +185,14 @@ def test_call_proc(self):
195185
self.assertEqual(db_span.data["mysql"]["port"], testenv['mysql_port'])
196186

197187
def test_error_capture(self):
198-
result = None
188+
affected_rows = None
199189
try:
200190
with tracer.start_active_span('test'):
201-
result = self.cursor.execute("""SELECT * from blah""")
202-
self.cursor.fetchone()
191+
affected_rows = self.cursor.execute("""SELECT * from blah""")
203192
except Exception:
204193
pass
205194

206-
self.assertIsNone(result)
195+
self.assertIsNone(affected_rows)
207196

208197
spans = self.recorder.queued_spans()
209198
self.assertEqual(2, len(spans))
@@ -228,8 +217,9 @@ def test_connect_cursor_ctx_mgr(self):
228217
with tracer.start_active_span("test"):
229218
with self.db as connection:
230219
with connection.cursor() as cursor:
231-
cursor.execute("""SELECT * from users""")
220+
affected_rows = cursor.execute("""SELECT * from users""")
232221

222+
self.assertEqual(1, affected_rows)
233223
spans = self.recorder.queued_spans()
234224
self.assertEqual(2, len(spans))
235225

0 commit comments

Comments
 (0)