12
12
13
13
logger = logging .getLogger (__name__ )
14
14
15
- create_table_query = 'CREATE TABLE IF NOT EXISTS users(id serial primary key, \
16
- name varchar(40) NOT NULL, email varchar(40) NOT NULL)'
17
-
18
- create_proc_query = """
19
- CREATE PROCEDURE test_proc(IN t VARCHAR(255))
20
- BEGIN
21
- SELECT name FROM users WHERE name = t;
22
- END
23
- """
24
-
25
- db = pymysql .connect (host = testenv ['mysql_host' ], port = testenv ['mysql_port' ],
26
- user = testenv ['mysql_user' ], passwd = testenv ['mysql_pw' ],
27
- db = testenv ['mysql_db' ])
28
-
29
- cursor = db .cursor ()
30
- cursor .execute (create_table_query )
31
-
32
- while cursor .nextset () is not None :
33
- pass
34
-
35
- cursor .execute ('DROP PROCEDURE IF EXISTS test_proc' )
36
-
37
- while cursor .nextset () is not None :
38
- pass
39
-
40
- cursor .execute (create_proc_query )
41
-
42
- while cursor .nextset () is not None :
43
- pass
44
-
45
- cursor .close ()
46
- db .close ()
47
-
48
15
49
16
class TestPyMySQL (unittest .TestCase ):
50
17
def setUp (self ):
51
18
self .db = pymysql .connect (host = testenv ['mysql_host' ], port = testenv ['mysql_port' ],
52
19
user = testenv ['mysql_user' ], passwd = testenv ['mysql_pw' ],
53
20
db = testenv ['mysql_db' ])
21
+ database_setup_query = """
22
+ DROP TABLE IF EXISTS users; |
23
+ CREATE TABLE users(
24
+ id serial primary key,
25
+ name varchar(40) NOT NULL,
26
+ email varchar(40) NOT NULL
27
+ ); |
28
+ INSERT INTO users(name, email) VALUES('kermit', '[email protected] '); |
29
+ DROP PROCEDURE IF EXISTS test_proc; |
30
+ CREATE PROCEDURE test_proc(IN t VARCHAR(255))
31
+ BEGIN
32
+ SELECT name FROM users WHERE name = t;
33
+ END
34
+ """
35
+ setup_cursor = self .db .cursor ()
36
+ for s in database_setup_query .split ('|' ):
37
+ setup_cursor .execute (s )
38
+
54
39
self .cursor = self .db .cursor ()
55
40
self .recorder = tracer .recorder
56
41
self .recorder .clear_spans ()
57
42
tracer .cur_ctx = None
58
43
59
44
def tearDown (self ):
60
- """ Do nothing for now """
61
- return None
45
+ if self .cursor and self .cursor .connection .open :
46
+ self .cursor .close ()
47
+ if self .db and self .db .open :
48
+ self .db .close ()
62
49
63
50
def test_vanilla_query (self ):
64
- self .cursor .execute ("""SELECT * from users""" )
51
+ affected_rows = self .cursor .execute ("""SELECT * from users""" )
52
+ self .assertEqual (1 , affected_rows )
65
53
result = self .cursor .fetchone ()
66
54
self .assertEqual (3 , len (result ))
67
55
@@ -70,10 +58,11 @@ def test_vanilla_query(self):
70
58
71
59
def test_basic_query (self ):
72
60
with tracer .start_active_span ('test' ):
73
- result = self .cursor .execute ("""SELECT * from users""" )
74
- self .cursor .fetchone ()
61
+ affected_rows = self .cursor .execute ("""SELECT * from users""" )
62
+ result = self .cursor .fetchone ()
75
63
76
- self .assertTrue (result >= 0 )
64
+ self .assertEqual (1 , affected_rows )
65
+ self .assertEqual (3 , len (result ))
77
66
78
67
spans = self .recorder .queued_spans ()
79
68
self .assertEqual (2 , len (spans ))
@@ -95,10 +84,11 @@ def test_basic_query(self):
95
84
96
85
def test_query_with_params (self ):
97
86
with tracer .start_active_span ('test' ):
98
- result = self .cursor .execute ("""SELECT * from users where id=1""" )
99
- self .cursor .fetchone ()
87
+ affected_rows = self .cursor .execute ("""SELECT * from users where id=1""" )
88
+ result = self .cursor .fetchone ()
100
89
101
- self .assertTrue (result >= 0 )
90
+ self .assertEqual (1 , affected_rows )
91
+ self .assertEqual (3 , len (result ))
102
92
103
93
spans = self .recorder .queued_spans ()
104
94
self .assertEqual (2 , len (spans ))
@@ -120,11 +110,11 @@ def test_query_with_params(self):
120
110
121
111
def test_basic_insert (self ):
122
112
with tracer .start_active_span ('test' ):
123
- result = self .cursor .execute (
113
+ affected_rows = self .cursor .execute (
124
114
"""INSERT INTO users(name, email) VALUES(%s, %s)""" ,
125
115
126
116
127
- self .assertEqual (1 , result )
117
+ self .assertEqual (1 , affected_rows )
128
118
129
119
spans = self .recorder .queued_spans ()
130
120
self .assertEqual (2 , len (spans ))
@@ -146,11 +136,11 @@ def test_basic_insert(self):
146
136
147
137
def test_executemany (self ):
148
138
with tracer .start_active_span ('test' ):
149
- result = self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
139
+ affected_rows = self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
150
140
151
141
self .db .commit ()
152
142
153
- self .assertEqual (2 , result )
143
+ self .assertEqual (2 , affected_rows )
154
144
155
145
spans = self .recorder .queued_spans ()
156
146
self .assertEqual (2 , len (spans ))
@@ -172,9 +162,9 @@ def test_executemany(self):
172
162
173
163
def test_call_proc (self ):
174
164
with tracer .start_active_span ('test' ):
175
- result = self .cursor .callproc ('test_proc' , ('beaker' ,))
165
+ callproc_result = self .cursor .callproc ('test_proc' , ('beaker' ,))
176
166
177
- self .assertTrue ( result )
167
+ self .assertIsInstance ( callproc_result , tuple )
178
168
179
169
spans = self .recorder .queued_spans ()
180
170
self .assertEqual (2 , len (spans ))
@@ -195,15 +185,14 @@ def test_call_proc(self):
195
185
self .assertEqual (db_span .data ["mysql" ]["port" ], testenv ['mysql_port' ])
196
186
197
187
def test_error_capture (self ):
198
- result = None
188
+ affected_rows = None
199
189
try :
200
190
with tracer .start_active_span ('test' ):
201
- result = self .cursor .execute ("""SELECT * from blah""" )
202
- self .cursor .fetchone ()
191
+ affected_rows = self .cursor .execute ("""SELECT * from blah""" )
203
192
except Exception :
204
193
pass
205
194
206
- self .assertIsNone (result )
195
+ self .assertIsNone (affected_rows )
207
196
208
197
spans = self .recorder .queued_spans ()
209
198
self .assertEqual (2 , len (spans ))
@@ -228,8 +217,9 @@ def test_connect_cursor_ctx_mgr(self):
228
217
with tracer .start_active_span ("test" ):
229
218
with self .db as connection :
230
219
with connection .cursor () as cursor :
231
- cursor .execute ("""SELECT * from users""" )
220
+ affected_rows = cursor .execute ("""SELECT * from users""" )
232
221
222
+ self .assertEqual (1 , affected_rows )
233
223
spans = self .recorder .queued_spans ()
234
224
self .assertEqual (2 , len (spans ))
235
225
0 commit comments