14
14
15
15
logger = logging .getLogger (__name__ )
16
16
17
- create_table_query = """
18
- CREATE TABLE IF NOT EXISTS users(
19
- id serial PRIMARY KEY,
20
- name VARCHAR (50),
21
- password VARCHAR (50),
22
- email VARCHAR (355),
23
- created_on TIMESTAMP,
24
- last_login TIMESTAMP
25
- );
26
- """
27
-
28
- create_proc_query = """\
29
- CREATE OR REPLACE FUNCTION test_proc(candidate VARCHAR(70))
30
- RETURNS text AS $$
31
- BEGIN
32
- RETURN(SELECT name FROM users where email = candidate);
33
- END;
34
- $$ LANGUAGE plpgsql;
35
- """
36
-
37
- drop_proc_query = "DROP FUNCTION IF EXISTS test_proc(VARCHAR(70));"
38
-
39
- db = psycopg2 .connect (host = testenv ['postgresql_host' ], port = testenv ['postgresql_port' ],
40
- user = testenv ['postgresql_user' ], password = testenv ['postgresql_pw' ],
41
- database = testenv ['postgresql_db' ])
42
-
43
- cursor = db .cursor ()
44
- cursor .execute (create_table_query )
45
- cursor .execute (drop_proc_query )
46
- cursor .execute (create_proc_query )
47
- db .commit ()
48
- cursor .close ()
49
- db .close ()
50
-
51
17
52
18
class TestPsycoPG2 (unittest .TestCase ):
53
19
def setUp (self ):
54
20
self .db = psycopg2 .connect (host = testenv ['postgresql_host' ], port = testenv ['postgresql_port' ],
55
21
user = testenv ['postgresql_user' ], password = testenv ['postgresql_pw' ],
56
22
database = testenv ['postgresql_db' ])
23
+
24
+ database_setup_query = """
25
+ DROP TABLE IF EXISTS users;
26
+ CREATE TABLE users(
27
+ id serial PRIMARY KEY,
28
+ name VARCHAR (50),
29
+ password VARCHAR (50),
30
+ email VARCHAR (355),
31
+ created_on TIMESTAMP,
32
+ last_login TIMESTAMP
33
+ );
34
+ INSERT INTO users(name, email) VALUES('kermit', '[email protected] ');
35
+ DROP FUNCTION IF EXISTS test_proc(VARCHAR(70));
36
+ CREATE FUNCTION test_proc(candidate VARCHAR(70))
37
+ RETURNS text AS $$
38
+ BEGIN
39
+ RETURN(SELECT name FROM users where email = candidate);
40
+ END;
41
+ $$ LANGUAGE plpgsql;
42
+ """
43
+ cursor = self .db .cursor ()
44
+ cursor .execute (database_setup_query )
45
+ self .db .commit ()
46
+ cursor .close ()
47
+
48
+
57
49
self .cursor = self .db .cursor ()
58
50
self .recorder = tracer .recorder
59
51
self .recorder .clear_spans ()
60
52
tracer .cur_ctx = None
61
53
62
54
def tearDown (self ):
63
- """ Do nothing for now """
64
- return None
55
+ if self .cursor and not self .cursor .connection .closed :
56
+ self .cursor .close ()
57
+ if self .db and not self .db .closed :
58
+ self .db .close ()
65
59
66
60
def test_vanilla_query (self ):
67
61
self .assertTrue (psycopg2 .extras .register_uuid (None , self .db ))
68
62
self .assertTrue (psycopg2 .extras .register_uuid (None , self .db .cursor ()))
69
63
70
64
self .cursor .execute ("""SELECT * from users""" )
65
+ affected_rows = self .cursor .rowcount
66
+ self .assertEqual (1 , affected_rows )
71
67
result = self .cursor .fetchone ()
72
68
73
69
self .assertEqual (6 , len (result ))
@@ -78,9 +74,13 @@ def test_vanilla_query(self):
78
74
def test_basic_query (self ):
79
75
with tracer .start_active_span ('test' ):
80
76
self .cursor .execute ("""SELECT * from users""" )
81
- self .cursor .fetchone ()
77
+ affected_rows = self .cursor .rowcount
78
+ result = self .cursor .fetchone ()
82
79
self .db .commit ()
83
80
81
+ self .assertEqual (1 , affected_rows )
82
+ self .assertEqual (6 , len (result ))
83
+
84
84
spans = self .recorder .queued_spans ()
85
85
self .assertEqual (2 , len (spans ))
86
86
@@ -102,6 +102,9 @@ def test_basic_query(self):
102
102
def test_basic_insert (self ):
103
103
with tracer .start_active_span ('test' ):
104
104
self .
cursor .
execute (
"""INSERT INTO users(name, email) VALUES(%s, %s)""" , (
'beaker' ,
'[email protected] ' ))
105
+ affected_rows = self .cursor .rowcount
106
+
107
+ self .assertEqual (1 , affected_rows )
105
108
106
109
spans = self .recorder .queued_spans ()
107
110
self .assertEqual (2 , len (spans ))
@@ -123,10 +126,13 @@ def test_basic_insert(self):
123
126
124
127
def test_executemany (self ):
125
128
with tracer .start_active_span ('test' ):
126
- result = self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
127
-
129
+ self .cursor .executemany ("INSERT INTO users(name, email) VALUES(%s, %s)" ,
130
+
131
+ affected_rows = self .cursor .rowcount
128
132
self .db .commit ()
129
133
134
+ self .assertEqual (2 , affected_rows )
135
+
130
136
spans = self .recorder .queued_spans ()
131
137
self .assertEqual (2 , len (spans ))
132
138
@@ -147,9 +153,9 @@ def test_executemany(self):
147
153
148
154
def test_call_proc (self ):
149
155
with tracer .start_active_span ('test' ):
150
- result = self .cursor .callproc ('test_proc' , ('beaker' ,))
156
+ callproc_result = self .cursor .callproc ('test_proc' , ('beaker' ,))
151
157
152
- self .assertIsInstance (result , tuple )
158
+ self .assertIsInstance (callproc_result , tuple )
153
159
154
160
spans = self .recorder .queued_spans ()
155
161
self .assertEqual (2 , len (spans ))
@@ -170,14 +176,16 @@ def test_call_proc(self):
170
176
self .assertEqual (db_span .data ["pg" ]["port" ], testenv ['postgresql_port' ])
171
177
172
178
def test_error_capture (self ):
173
- result = None
179
+ affected_rows = result = None
174
180
try :
175
181
with tracer .start_active_span ('test' ):
176
- result = self .cursor .execute ("""SELECT * from blah""" )
182
+ self .cursor .execute ("""SELECT * from blah""" )
183
+ affected_rows = self .cursor .rowcount
177
184
self .cursor .fetchone ()
178
185
except Exception :
179
186
pass
180
187
188
+ self .assertIsNone (affected_rows )
181
189
self .assertIsNone (result )
182
190
183
191
spans = self .recorder .queued_spans ()
@@ -246,6 +254,11 @@ def test_connect_cursor_ctx_mgr(self):
246
254
with self .db as connection :
247
255
with connection .cursor () as cursor :
248
256
cursor .execute ("""SELECT * from users""" )
257
+ affected_rows = cursor .rowcount
258
+ result = cursor .fetchone ()
259
+
260
+ self .assertEqual (1 , affected_rows )
261
+ self .assertEqual (6 , len (result ))
249
262
250
263
spans = self .recorder .queued_spans ()
251
264
self .assertEqual (2 , len (spans ))
@@ -270,6 +283,11 @@ def test_connect_ctx_mgr(self):
270
283
with self .db as connection :
271
284
cursor = connection .cursor ()
272
285
cursor .execute ("""SELECT * from users""" )
286
+ affected_rows = cursor .rowcount
287
+ result = cursor .fetchone ()
288
+
289
+ self .assertEqual (1 , affected_rows )
290
+ self .assertEqual (6 , len (result ))
273
291
274
292
spans = self .recorder .queued_spans ()
275
293
self .assertEqual (2 , len (spans ))
@@ -294,6 +312,11 @@ def test_cursor_ctx_mgr(self):
294
312
connection = self .db
295
313
with connection .cursor () as cursor :
296
314
cursor .execute ("""SELECT * from users""" )
315
+ affected_rows = cursor .rowcount
316
+ result = cursor .fetchone ()
317
+
318
+ self .assertEqual (1 , affected_rows )
319
+ self .assertEqual (6 , len (result ))
297
320
298
321
spans = self .recorder .queued_spans ()
299
322
self .assertEqual (2 , len (spans ))
0 commit comments