Skip to content
Snippets Groups Projects
Commit 65663fb36948 authored by INADA Naoki's avatar INADA Naoki
Browse files

Merge branch 'defer_unbuffered_cleanup' of https://github.com/zzzeek/PyMySQL...

Merge branch 'defer_unbuffered_cleanup' of https://github.com/zzzeek/PyMySQL into zzzeek-defer_unbuffered_cleanup

Conflicts:
	pymysql/tests/base.py
Branches
No related tags found
No related merge requests found
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
0.6.4 -Support "LOAD LOCAL INFILE". Thanks @wraziens 0.6.4 -Support "LOAD LOCAL INFILE". Thanks @wraziens
-Show MySQL warnings after execute query. -Show MySQL warnings after execute query.
-Fix MySQLError may be wrapped with OperationalError while connectiong. (#274) -Fix MySQLError may be wrapped with OperationalError while connectiong. (#274)
-SSCursor no longer attempts to expire un-collected rows within __del__,
delaying termination of an interrupted program; cleanup of uncollected
rows is left to the Connection on next execute, which emits a
warning at that time. (#287)
0.6.3 -Fixed multiple result sets with SSCursor. 0.6.3 -Fixed multiple result sets with SSCursor.
-Fixed connection timeout. -Fixed connection timeout.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
PYPY = hasattr(sys, 'pypy_translation_info') PYPY = hasattr(sys, 'pypy_translation_info')
JYTHON = sys.platform.startswith('java') JYTHON = sys.platform.startswith('java')
IRONPYTHON = sys.platform == 'cli' IRONPYTHON = sys.platform == 'cli'
CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
if PY2: if PY2:
range_type = xrange range_type = xrange
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import socket import socket
import struct import struct
import sys import sys
import warnings
try: try:
import ssl import ssl
...@@ -932,6 +933,7 @@ ...@@ -932,6 +933,7 @@
# If the last query was unbuffered, make sure it finishes before # If the last query was unbuffered, make sure it finishes before
# sending new commands # sending new commands
if self._result is not None and self._result.unbuffered_active: if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query() self._result._finish_unbuffered_query()
if isinstance(sql, text_type): if isinstance(sql, text_type):
......
...@@ -40,12 +40,6 @@ ...@@ -40,12 +40,6 @@
self._result = None self._result = None
self._rows = None self._rows = None
def __del__(self):
'''
When this gets GC'd close it.
'''
self.close()
def close(self): def close(self):
''' '''
Closing a cursor just exhausts all remaining data. Closing a cursor just exhausts all remaining data.
......
import gc
import os import os
import json import json
import pymysql import pymysql
import re import re
...@@ -1,8 +2,11 @@ ...@@ -1,8 +2,11 @@
import os import os
import json import json
import pymysql import pymysql
import re import re
from .._compat import CPYTHON
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest import unittest
...@@ -5,7 +9,8 @@ ...@@ -5,7 +9,8 @@
try: try:
import unittest2 as unittest import unittest2 as unittest
except ImportError: except ImportError:
import unittest import unittest
import warnings
class PyMySQLTestCase(unittest.TestCase): class PyMySQLTestCase(unittest.TestCase):
# You can specify your test environment creating a file named # You can specify your test environment creating a file named
...@@ -41,4 +46,5 @@ ...@@ -41,4 +46,5 @@
self.connections = [] self.connections = []
for params in self.databases: for params in self.databases:
self.connections.append(pymysql.connect(**params)) self.connections.append(pymysql.connect(**params))
self.addCleanup(self._teardown_connections)
...@@ -44,4 +50,4 @@ ...@@ -44,4 +50,4 @@
def tearDown(self): def _teardown_connections(self):
for connection in self.connections: for connection in self.connections:
connection.close() connection.close()
...@@ -46,2 +52,40 @@ ...@@ -46,2 +52,40 @@
for connection in self.connections: for connection in self.connections:
connection.close() connection.close()
def safe_create_table(self, connection, tablename, ddl, cleanup=False):
"""create a table.
Ensures any existing version of that table
is first dropped.
Also adds a cleanup rule to drop the table after the test
completes.
"""
cursor = connection.cursor()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cursor.execute("drop table if exists test")
cursor.execute("create table test (data varchar(10))")
cursor.close()
if cleanup:
self.addCleanup(self.drop_table, connection, tablename)
def drop_table(self, connection, tablename):
cursor = connection.cursor()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cursor.execute("drop table if exists %s" % tablename)
cursor.close()
def safe_gc_collect(self):
"""Ensure cycles are collected via gc.
Runs additional times on non-CPython platforms.
"""
gc.collect()
if not CPYTHON:
gc.collect()
...@@ -32,6 +32,9 @@ ...@@ -32,6 +32,9 @@
c.execute("drop table dictcursor") c.execute("drop table dictcursor")
super(TestDictCursor, self).tearDown() super(TestDictCursor, self).tearDown()
def _ensure_cursor_expired(self, cursor):
pass
def test_DictCursor(self): def test_DictCursor(self):
bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy() bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
#all assert test compare to the structure as would come out from MySQLdb #all assert test compare to the structure as would come out from MySQLdb
...@@ -45,6 +48,8 @@ ...@@ -45,6 +48,8 @@
c.execute("SELECT * from dictcursor where name='bob'") c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchone() r = c.fetchone()
self.assertEqual(bob, r, "fetchone via DictCursor failed") self.assertEqual(bob, r, "fetchone via DictCursor failed")
self._ensure_cursor_expired(c)
# same again, but via fetchall => tuple) # same again, but via fetchall => tuple)
c.execute("SELECT * from dictcursor where name='bob'") c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchall() r = c.fetchall()
...@@ -65,6 +70,7 @@ ...@@ -65,6 +70,7 @@
c.execute("SELECT * from dictcursor") c.execute("SELECT * from dictcursor")
r = c.fetchmany(2) r = c.fetchmany(2)
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor") self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
self._ensure_cursor_expired(c)
def test_custom_dict(self): def test_custom_dict(self):
class MyDict(dict): pass class MyDict(dict): pass
...@@ -81,6 +87,7 @@ ...@@ -81,6 +87,7 @@
cur.execute("SELECT * FROM dictcursor WHERE name='bob'") cur.execute("SELECT * FROM dictcursor WHERE name='bob'")
r = cur.fetchone() r = cur.fetchone()
self.assertEqual(bob, r, "fetchone() returns MyDictCursor") self.assertEqual(bob, r, "fetchone() returns MyDictCursor")
self._ensure_cursor_expired(cur)
cur.execute("SELECT * FROM dictcursor") cur.execute("SELECT * FROM dictcursor")
r = cur.fetchall() r = cur.fetchall()
...@@ -96,8 +103,9 @@ ...@@ -96,8 +103,9 @@
r = cur.fetchmany(2) r = cur.fetchmany(2)
self.assertEqual([bob, jim], r, self.assertEqual([bob, jim], r,
"list failed via MyDictCursor") "list failed via MyDictCursor")
self._ensure_cursor_expired(cur)
class TestSSDictCursor(TestDictCursor): class TestSSDictCursor(TestDictCursor):
cursor_type = pymysql.cursors.SSDictCursor cursor_type = pymysql.cursors.SSDictCursor
...@@ -99,8 +107,10 @@ ...@@ -99,8 +107,10 @@
class TestSSDictCursor(TestDictCursor): class TestSSDictCursor(TestDictCursor):
cursor_type = pymysql.cursors.SSDictCursor cursor_type = pymysql.cursors.SSDictCursor
def _ensure_cursor_expired(self, cursor):
list(cursor.fetchall_unbuffered())
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
import warnings
from pymysql.tests import base
import pymysql.cursors
class CursorTest(base.PyMySQLTestCase):
def setUp(self):
super(CursorTest, self).setUp()
conn = self.connections[0]
self.safe_create_table(
conn,
"test", "create table test (data varchar(10))",
cleanup=True)
cursor = conn.cursor()
cursor.execute(
"insert into test (data) values "
"('row1'), ('row2'), ('row3'), ('row4'), ('row5')")
cursor.close()
self.test_connection = pymysql.connect(**self.databases[0])
self.addCleanup(self.test_connection.close)
def test_cleanup_rows_unbuffered(self):
conn = self.test_connection
cursor = conn.cursor(pymysql.cursors.SSCursor)
cursor.execute("select * from test as t1, test as t2")
for counter, row in enumerate(cursor):
if counter > 10:
break
del cursor
self.safe_gc_collect()
c2 = conn.cursor()
with warnings.catch_warnings(record=True) as log:
warnings.filterwarnings("always")
c2.execute("select 1")
self.assertGreater(len(log), 0)
self.assertEqual(
"Previous unbuffered result was left incomplete",
str(log[-1].message))
self.assertEqual(
c2.fetchone(), (1,)
)
self.assertIsNone(c2.fetchone())
def test_cleanup_rows_buffered(self):
conn = self.test_connection
cursor = conn.cursor(pymysql.cursors.Cursor)
cursor.execute("select * from test as t1, test as t2")
for counter, row in enumerate(cursor):
if counter > 10:
break
del cursor
self.safe_gc_collect()
c2 = conn.cursor()
c2.execute("select 1")
self.assertEqual(
c2.fetchone(), (1,)
)
self.assertIsNone(c2.fetchone())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment