Skip to content

Commit 304a8fc

Browse files
committed
Fix some tests.
1 parent e4b3b21 commit 304a8fc

File tree

9 files changed

+202
-184
lines changed

9 files changed

+202
-184
lines changed

Diff for: runtests.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#!/usr/bin/env python
22
import unittest
33

4-
from tornado_mysql import PYPY, JYTHON, IRONPYTHON
4+
from tornado_mysql._compat import PYPY, JYTHON, IRONPYTHON
5+
import tornado_mysql.tests
56

67
if not (PYPY or JYTHON or IRONPYTHON):
78
import atexit
@@ -22,5 +23,4 @@ def report_uncollectable():
2223
print("referrer:", ref)
2324
print('---')
2425

25-
import pymysql.tests
26-
unittest.main(pymysql.tests)
26+
unittest.main(tornado_mysql.tests)

Diff for: setup.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
#!/usr/bin/env python
22
from setuptools import setup, find_packages
33

4-
version_tuple = __import__('tornado_mysql').VERSION
5-
6-
if version_tuple[3] is not None:
7-
version = "%d.%d.%d_%s" % version_tuple
8-
else:
9-
version = "%d.%d.%d" % version_tuple[:3]
4+
version = "0.1dev1"
105

116
try:
127
with open('README.rst') as f:
@@ -15,13 +10,11 @@
1510
readme = ''
1611

1712
setup(
18-
name="PyMySQL",
13+
name="Tornado-MySQL",
1914
version=version,
2015
url='https://github.com/PyMySQL/Tornado-MySQL',
21-
author='yutaka.matsubara',
22-
author_email='[email protected]',
23-
maintainer='INADA Naoki',
24-
maintainer_email='[email protected]',
16+
author='INADA Naoki',
17+
author_email='[email protected]',
2518
description='Pure-Python MySQL Driver for Tornado',
2619
install_requires=['tornado>=4.0'],
2720
long_description=readme,
@@ -39,5 +32,5 @@
3932
'Intended Audience :: Developers',
4033
'License :: OSI Approved :: MIT License',
4134
'Topic :: Database',
42-
]
35+
],
4336
)

Diff for: tornado_mysql/connections.py

+25-21
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def close(self):
576576
@gen.coroutine
577577
def close_async(self):
578578
send_data = struct.pack('<i', 1) + int2byte(COM_QUIT)
579-
yield stream.write(send_data)
579+
yield self._stream.write(send_data)
580580
self.close()
581581

582582
@property
@@ -632,6 +632,7 @@ def rollback(self):
632632
yield self._execute_command(COM_QUERY, "ROLLBACK")
633633
yield self._read_ok_packet()
634634

635+
@gen.coroutine
635636
def select_db(self, db):
636637
'''Set current db'''
637638
yield self._execute_command(COM_INIT_DB, db)
@@ -673,7 +674,7 @@ def query(self, sql, unbuffered=False):
673674
@gen.coroutine
674675
def next_result(self):
675676
yield self._read_query_result()
676-
return self._affected_rows
677+
raise gen.Return(self._affected_rows)
677678

678679
def affected_rows(self):
679680
return self._affected_rows
@@ -687,7 +688,7 @@ def kill(self, thread_id):
687688
@gen.coroutine
688689
def ping(self, reconnect=True):
689690
"""Check if the server is alive"""
690-
if self.socket is None:
691+
if self._stream is None:
691692
if reconnect:
692693
yield self.connect()
693694
reconnect = False
@@ -762,25 +763,28 @@ def _read_packet(self, packet_type=MysqlPacket):
762763
and return a MysqlPacket type that represents the results.
763764
"""
764765
buff = b''
765-
while True:
766-
packet_header = yield self._stream.read_bytes(4)
767-
if DEBUG: dump_packet(packet_header)
768-
packet_length_bin = packet_header[:3]
769-
770-
#TODO: check sequence id
771-
# packet_number
772-
byte2int(packet_header[3])
773-
774-
bin_length = packet_length_bin + b'\0' # pad little-endian number
775-
bytes_to_read = struct.unpack('<I', bin_length)[0]
776-
recv_data = yield self._stream.read_bytes(bytes_to_read)
777-
if DEBUG: dump_packet(recv_data)
778-
buff += recv_data
779-
if bytes_to_read < MAX_PACKET_LEN:
780-
break
766+
try:
767+
while True:
768+
packet_header = yield self._stream.read_bytes(4)
769+
if DEBUG: dump_packet(packet_header)
770+
packet_length_bin = packet_header[:3]
771+
772+
#TODO: check sequence id
773+
# packet_number
774+
byte2int(packet_header[3])
775+
776+
bin_length = packet_length_bin + b'\0' # pad little-endian number
777+
bytes_to_read = struct.unpack('<I', bin_length)[0]
778+
recv_data = yield self._stream.read_bytes(bytes_to_read)
779+
if DEBUG: dump_packet(recv_data)
780+
buff += recv_data
781+
if bytes_to_read < MAX_PACKET_LEN:
782+
break
783+
except iostream.StreamClosedError as e:
784+
raise OperationalError(2006, "MySQL server has gone away (%s)" % (e,))
781785
packet = packet_type(buff, self.encoding)
782786
packet.check_error()
783-
return packet
787+
raise gen.Return(packet)
784788

785789
def _write_bytes(self, data):
786790
return self._stream.write(data)
@@ -1001,7 +1005,7 @@ def read(self):
10011005
if first_packet.is_ok_packet():
10021006
self._read_ok_packet(first_packet)
10031007
else:
1004-
self._read_result_packet(first_packet)
1008+
yield self._read_result_packet(first_packet)
10051009
finally:
10061010
self.connection = None
10071011

Diff for: tornado_mysql/cursors.py

+21-27
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def ensure_bytes(x):
134134

135135
yield self._query(query)
136136
self._executed = query
137+
raise gen.Return(self.rowcount)
137138

138139
@gen.coroutine
139140
def executemany(self, query, args):
@@ -159,6 +160,7 @@ def executemany(self, query, args):
159160
yield self.execute(query, arg)
160161
rows += self.rowcount
161162
self.rowcount = rows
163+
raise gen.Return(self.rowcount)
162164

163165
@gen.coroutine
164166
def _do_execute_many(self, prefix, values, args, max_stmt_length, encoding):
@@ -179,14 +181,13 @@ def _do_execute_many(self, prefix, values, args, max_stmt_length, encoding):
179181
v = v.encode(encoding)
180182
if len(sql) + len(v) + 1 > max_stmt_length:
181183
print(sql)
182-
yield self.execute(sql)
184+
yield self.execute(bytes(sql))
183185
rows += self.rowcount
184186
sql = bytearray(prefix)
185187
else:
186188
sql += b','
187189
sql += v
188-
print(sql)
189-
yield self.execute(sql)
190+
yield self.execute(bytes(sql))
190191
rows += self.rowcount
191192
self.rowcount = rows
192193

@@ -313,9 +314,8 @@ class DictCursorMixin(object):
313314
# You can override this to use OrderedDict or other dict-like types.
314315
dict_type = dict
315316

316-
@gen.coroutine
317317
def _do_get_result(self):
318-
yield super(DictCursorMixin, self)._do_get_result()
318+
super(DictCursorMixin, self)._do_get_result()
319319
fields = []
320320
if self.description:
321321
for f in self._result.fields:
@@ -378,7 +378,7 @@ def _query(self, q):
378378
self._last_executed = q
379379
yield conn.query(q, unbuffered=True)
380380
yield self._do_get_result()
381-
return self.rowcount
381+
raise gen.Return(self.rowcount)
382382

383383
@gen.coroutine
384384
def read_next(self):
@@ -393,45 +393,39 @@ def fetchone(self):
393393
self._check_executed()
394394
row = yield self.read_next()
395395
if row is None:
396-
return None
396+
raise gen.Return()
397397
self.rownumber += 1
398-
return row
398+
raise gen.Return(row)
399399

400+
@gen.coroutine
400401
def fetchall(self):
401402
"""
402403
Fetch all, as per MySQLdb. Pretty useless for large queries, as
403-
it is buffered. See fetchall_unbuffered(), if you want an unbuffered
404-
generator version of this method.
405-
406-
"""
407-
return list(self.fetchall_unbuffered())
408-
409-
def fetchall_unbuffered(self):
410-
"""
411-
Fetch all, implemented as a generator, which isn't to standard,
412-
however, it doesn't make sense to return everything in a list, as that
413-
would use ridiculous memory for large result sets.
404+
it is buffered.
414405
"""
415-
return iter(self.fetchone, None)
416-
417-
def __iter__(self):
418-
return self.fetchall_unbuffered()
406+
rows = []
407+
while True:
408+
row = yield self.fetchone()
409+
if row is None:
410+
break
411+
rows.append(row)
412+
raise gen.Return(rows)
419413

414+
@gen.coroutine
420415
def fetchmany(self, size=None):
421-
""" Fetch many """
422-
416+
"""Fetch many"""
423417
self._check_executed()
424418
if size is None:
425419
size = self.arraysize
426420

427421
rows = []
428422
for i in range_type(size):
429-
row = self.read_next()
423+
row = yield self.read_next()
430424
if row is None:
431425
break
432426
rows.append(row)
433427
self.rownumber += 1
434-
return rows
428+
raise gen.Return(rows)
435429

436430
def scroll(self, value, mode='relative'):
437431
self._check_executed()

Diff for: tornado_mysql/tests/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#from tornado_mysql.tests.test_issues import *
1+
from tornado_mysql.tests.test_issues import *
22
from tornado_mysql.tests.test_basic import *
33
#from tornado_mysql.tests.test_nextset import *
44
#from tornado_mysql.tests.test_DictCursor import *

Diff for: tornado_mysql/tests/test_basic.py

-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def prepare():
270270
PRIMARY KEY (id)
271271
)
272272
""")
273-
print("created bulkinsert")
274273
self.io_loop.run_sync(prepare)
275274

276275
@gen.coroutine

Diff for: tornado_mysql/tests/test_connection.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,26 @@
77

88

99
class TestConnection(base.PyMySQLTestCase):
10-
@gen.test
10+
@gen_test
1111
def test_utf8mb4(self):
1212
"""This test requires MySQL >= 5.5"""
1313
arg = self.databases[0].copy()
1414
arg['charset'] = 'utf8mb4'
1515
conn = yield tornado_mysql.connect(**arg)
1616

17-
@gen.test
17+
@gen_test
1818
def test_largedata(self):
1919
"""Large query and response (>=16MB)"""
2020
cur = self.connections[0].cursor()
21-
cur.execute("SELECT @@max_allowed_packet")
21+
yield cur.execute("SELECT @@max_allowed_packet")
2222
if cur.fetchone()[0] < 16*1024*1024 + 10:
2323
print("Set max_allowed_packet to bigger than 17MB")
24-
return
25-
t = 'a' * (16*1024*1024)
26-
yield cur.execute("SELECT '" + t + "'")
27-
assert cur.fetchone()[0] == t
24+
else:
25+
t = 'a' * (16*1024*1024)
26+
yield cur.execute("SELECT '" + t + "'")
27+
assert cur.fetchone()[0] == t
2828

29-
@gen.test
29+
@gen_test
3030
def test_escape_string(self):
3131
con = self.connections[0]
3232
cur = con.cursor()
@@ -35,7 +35,7 @@ def test_escape_string(self):
3535
yield cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
3636
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
3737

38-
@gen.test
38+
@gen_test
3939
def test_autocommit(self):
4040
con = self.connections[0]
4141
self.assertFalse(con.get_autocommit())
@@ -49,7 +49,7 @@ def test_autocommit(self):
4949
yield cur.execute("SELECT @@AUTOCOMMIT")
5050
self.assertEqual(cur.fetchone()[0], 0)
5151

52-
@gen.test
52+
@gen_test
5353
def test_select_db(self):
5454
con = self.connections[0]
5555
current_db = self.databases[0]['db']
@@ -59,11 +59,11 @@ def test_select_db(self):
5959
yield cur.execute('SELECT database()')
6060
self.assertEqual(cur.fetchone()[0], current_db)
6161

62-
con.select_db(other_db)
62+
yield con.select_db(other_db)
6363
yield cur.execute('SELECT database()')
6464
self.assertEqual(cur.fetchone()[0], other_db)
6565

66-
@gen.test
66+
@gen_test
6767
def test_connection_gone_away(self):
6868
"""
6969
http://dev.mysql.com/doc/refman/5.0/en/gone-away.html

0 commit comments

Comments
 (0)