Skip to content
Snippets Groups Projects
Commit aa355a9e3829 authored by INADA Naoki's avatar INADA Naoki
Browse files

Merge pull request #300 from mathieulongtin/issue_288_executemany

fix issue 288: executemany now works with "insert ... on duplicate update"
Branches
No related tags found
No related merge requests found
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#: Regular expression for :meth:`Cursor.executemany`. #: Regular expression for :meth:`Cursor.executemany`.
#: executemany only suports simple bulk insert. #: executemany only suports simple bulk insert.
#: You can use it to load large dataset. #: You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(r"""INSERT\s.+\sVALUES\s+(\(\s*%s\s*(,\s*%s\s*)*\))\s*\Z""", RE_INSERT_VALUES = re.compile(r"""(INSERT\s.+\sVALUES\s+)(\(\s*%s\s*(?:,\s*%s\s*)*\))(\s*(?:ON DUPLICATE.*)?)\Z""",
re.IGNORECASE | re.DOTALL) re.IGNORECASE | re.DOTALL)
...@@ -145,5 +145,7 @@ ...@@ -145,5 +145,7 @@
m = RE_INSERT_VALUES.match(query) m = RE_INSERT_VALUES.match(query)
if m: if m:
q_values = m.group(1).rstrip() q_prefix = m.group(1)
q_values = m.group(2).rstrip()
q_postfix = m.group(3) or ''
assert q_values[0] == '(' and q_values[-1] == ')' assert q_values[0] == '(' and q_values[-1] == ')'
...@@ -149,9 +151,8 @@ ...@@ -149,9 +151,8 @@
assert q_values[0] == '(' and q_values[-1] == ')' assert q_values[0] == '(' and q_values[-1] == ')'
q_prefix = query[:m.start(1)] return self._do_execute_many(q_prefix, q_values, q_postfix, args,
return self._do_execute_many(q_prefix, q_values, args,
self.max_stmt_length, self.max_stmt_length,
self._get_db().encoding) self._get_db().encoding)
self.rowcount = sum(self.execute(query, arg) for arg in args) self.rowcount = sum(self.execute(query, arg) for arg in args)
return self.rowcount return self.rowcount
...@@ -152,11 +153,11 @@ ...@@ -152,11 +153,11 @@
self.max_stmt_length, self.max_stmt_length,
self._get_db().encoding) self._get_db().encoding)
self.rowcount = sum(self.execute(query, arg) for arg in args) self.rowcount = sum(self.execute(query, arg) for arg in args)
return self.rowcount return self.rowcount
def _do_execute_many(self, prefix, values, args, max_stmt_length, encoding): def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
conn = self._get_db() conn = self._get_db()
escape = self._escape_args escape = self._escape_args
if isinstance(prefix, text_type): if isinstance(prefix, text_type):
prefix = prefix.encode(encoding) prefix = prefix.encode(encoding)
...@@ -159,7 +160,9 @@ ...@@ -159,7 +160,9 @@
conn = self._get_db() conn = self._get_db()
escape = self._escape_args escape = self._escape_args
if isinstance(prefix, text_type): if isinstance(prefix, text_type):
prefix = prefix.encode(encoding) prefix = prefix.encode(encoding)
if isinstance(postfix, text_type):
postfix = postfix.encode(encoding)
sql = bytearray(prefix) sql = bytearray(prefix)
args = iter(args) args = iter(args)
v = values % escape(next(args), conn) v = values % escape(next(args), conn)
...@@ -171,9 +174,9 @@ ...@@ -171,9 +174,9 @@
v = values % escape(arg, conn) v = values % escape(arg, conn)
if isinstance(v, text_type): if isinstance(v, text_type):
v = v.encode(encoding) v = v.encode(encoding)
if len(sql) + len(v) + 1 > max_stmt_length: if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
rows += self.execute(sql) rows += self.execute(sql + postfix)
sql = bytearray(prefix) sql = bytearray(prefix)
else: else:
sql += b',' sql += b','
sql += v sql += v
...@@ -176,8 +179,8 @@ ...@@ -176,8 +179,8 @@
sql = bytearray(prefix) sql = bytearray(prefix)
else: else:
sql += b',' sql += b','
sql += v sql += v
rows += self.execute(sql) rows += self.execute(sql + postfix)
self.rowcount = rows self.rowcount = rows
return rows return rows
......
...@@ -296,7 +296,7 @@ ...@@ -296,7 +296,7 @@
%s , %s, %s , %s,
%s ) %s )
""", data) """, data)
self.assertEqual(cursor._last_executed, bytearray(b"""insert self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert
into bulkinsert (id, name, into bulkinsert (id, name,
age, height) age, height)
values (0, values (0,
...@@ -318,6 +318,33 @@ ...@@ -318,6 +318,33 @@
cursor.execute('commit') cursor.execute('commit')
self._verify_records(data) self._verify_records(data)
def test_issue_288(self):
"""executemany should work with "insert ... on update" """
conn = self.connections[0]
cursor = conn.cursor()
data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)]
cursor.executemany("""insert
into bulkinsert (id, name,
age, height)
values (%s,
%s , %s,
%s ) on duplicate key update
age = values(age)
""", data)
self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert
into bulkinsert (id, name,
age, height)
values (0,
'bob' , 21,
123 ),(1,
'jim' , 56,
45 ),(2,
'fred' , 100,
180 ) on duplicate key update
age = values(age)"""))
cursor.execute('commit')
self._verify_records(data)
def test_warnings(self): def test_warnings(self):
con = self.connections[0] con = self.connections[0]
cur = con.cursor() cur = con.cursor()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment