cleaned up charset and unicode functionality, aligned BIT type handling with MySQLdb, disambiguated length coded binary and strings, raised exception after connection is close()'d
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
MBLENGTH = {
|
||||
8:1,
|
||||
33:3,
|
||||
@@ -7,4 +5,170 @@ MBLENGTH = {
|
||||
91:2
|
||||
}
|
||||
|
||||
class Charset:
|
||||
def __init__(self, id, name, collation, is_default):
|
||||
self.id, self.name, self.collation = id, name, collation
|
||||
self.is_default = is_default == 'Yes'
|
||||
|
||||
class Charsets:
|
||||
def __init__(self):
|
||||
self._by_id = {}
|
||||
|
||||
def add(self, c):
|
||||
self._by_id[c.id] = c
|
||||
|
||||
def by_id(self, id):
|
||||
return self._by_id[id]
|
||||
|
||||
def by_name(self, name):
|
||||
for c in self._by_id.values():
|
||||
if c.name == name and c.is_default:
|
||||
return c
|
||||
|
||||
_charsets = Charsets()
|
||||
"""
|
||||
Generated with:
|
||||
|
||||
mysql -N -s -e "select id, character_set_name, collation_name, is_default
|
||||
from information_schema.collations order by id;" | python -c "import sys
|
||||
for l in sys.stdin.readlines():
|
||||
id, name, collation, is_default = l.split(chr(9))
|
||||
print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \
|
||||
% (id, name, collation, is_default.strip())
|
||||
"
|
||||
|
||||
"""
|
||||
_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes'))
|
||||
_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', ''))
|
||||
_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes'))
|
||||
_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', ''))
|
||||
_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes'))
|
||||
_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes'))
|
||||
_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes'))
|
||||
_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes'))
|
||||
_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes'))
|
||||
_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', ''))
|
||||
_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', ''))
|
||||
_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes'))
|
||||
_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes'))
|
||||
_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', ''))
|
||||
_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', ''))
|
||||
_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', ''))
|
||||
_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes'))
|
||||
_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', ''))
|
||||
_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes'))
|
||||
_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', ''))
|
||||
_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes'))
|
||||
_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', ''))
|
||||
_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', ''))
|
||||
_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', ''))
|
||||
_charsets.add(Charset(43, 'macce', 'macce_bin', ''))
|
||||
_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', ''))
|
||||
_charsets.add(Charset(47, 'latin1', 'latin1_bin', ''))
|
||||
_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', ''))
|
||||
_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', ''))
|
||||
_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', ''))
|
||||
_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', ''))
|
||||
_charsets.add(Charset(53, 'macroman', 'macroman_bin', ''))
|
||||
_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', ''))
|
||||
_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(63, 'binary', 'binary', 'Yes'))
|
||||
_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', ''))
|
||||
_charsets.add(Charset(65, 'ascii', 'ascii_bin', ''))
|
||||
_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', ''))
|
||||
_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', ''))
|
||||
_charsets.add(Charset(68, 'cp866', 'cp866_bin', ''))
|
||||
_charsets.add(Charset(69, 'dec8', 'dec8_bin', ''))
|
||||
_charsets.add(Charset(70, 'greek', 'greek_bin', ''))
|
||||
_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', ''))
|
||||
_charsets.add(Charset(72, 'hp8', 'hp8_bin', ''))
|
||||
_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', ''))
|
||||
_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', ''))
|
||||
_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', ''))
|
||||
_charsets.add(Charset(77, 'latin2', 'latin2_bin', ''))
|
||||
_charsets.add(Charset(78, 'latin5', 'latin5_bin', ''))
|
||||
_charsets.add(Charset(79, 'latin7', 'latin7_bin', ''))
|
||||
_charsets.add(Charset(80, 'cp850', 'cp850_bin', ''))
|
||||
_charsets.add(Charset(81, 'cp852', 'cp852_bin', ''))
|
||||
_charsets.add(Charset(82, 'swe7', 'swe7_bin', ''))
|
||||
_charsets.add(Charset(83, 'utf8', 'utf8_bin', ''))
|
||||
_charsets.add(Charset(84, 'big5', 'big5_bin', ''))
|
||||
_charsets.add(Charset(85, 'euckr', 'euckr_bin', ''))
|
||||
_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', ''))
|
||||
_charsets.add(Charset(87, 'gbk', 'gbk_bin', ''))
|
||||
_charsets.add(Charset(88, 'sjis', 'sjis_bin', ''))
|
||||
_charsets.add(Charset(89, 'tis620', 'tis620_bin', ''))
|
||||
_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', ''))
|
||||
_charsets.add(Charset(91, 'ujis', 'ujis_bin', ''))
|
||||
_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes'))
|
||||
_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', ''))
|
||||
_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', ''))
|
||||
_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes'))
|
||||
_charsets.add(Charset(96, 'cp932', 'cp932_bin', ''))
|
||||
_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes'))
|
||||
_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', ''))
|
||||
_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', ''))
|
||||
_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', ''))
|
||||
_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', ''))
|
||||
_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', ''))
|
||||
_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', ''))
|
||||
_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', ''))
|
||||
_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', ''))
|
||||
_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', ''))
|
||||
_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', ''))
|
||||
_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', ''))
|
||||
_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', ''))
|
||||
_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', ''))
|
||||
_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', ''))
|
||||
_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', ''))
|
||||
_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', ''))
|
||||
_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', ''))
|
||||
_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', ''))
|
||||
_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', ''))
|
||||
_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', ''))
|
||||
_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', ''))
|
||||
_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', ''))
|
||||
_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', ''))
|
||||
_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', ''))
|
||||
_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', ''))
|
||||
_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', ''))
|
||||
_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', ''))
|
||||
_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', ''))
|
||||
_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', ''))
|
||||
_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', ''))
|
||||
_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', ''))
|
||||
_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', ''))
|
||||
_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', ''))
|
||||
_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', ''))
|
||||
_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', ''))
|
||||
_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', ''))
|
||||
_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', ''))
|
||||
_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', ''))
|
||||
_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', ''))
|
||||
_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', ''))
|
||||
|
||||
def charset_by_name(name):
|
||||
return _charsets.by_name(name)
|
||||
|
||||
def charset_by_id(id):
|
||||
return _charsets.by_id(id)
|
||||
|
||||
|
@@ -1,8 +1,6 @@
|
||||
# Python implementation of the MySQL client-server protocol
|
||||
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
|
||||
|
||||
import re
|
||||
|
||||
try:
|
||||
import hashlib
|
||||
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
|
||||
@@ -22,13 +20,13 @@ try:
|
||||
except ImportError:
|
||||
import StringIO
|
||||
|
||||
from charset import MBLENGTH
|
||||
from charset import MBLENGTH, charset_by_name, charset_by_id
|
||||
from cursors import Cursor
|
||||
from constants import FIELD_TYPE
|
||||
from constants import FIELD_TYPE, FLAG
|
||||
from constants import SERVER_STATUS
|
||||
from constants.CLIENT import *
|
||||
from constants.COMMAND import *
|
||||
from converters import escape_item, encoders, decoders, field_decoders
|
||||
from converters import escape_item, encoders, decoders
|
||||
from err import raise_mysql_exception, Warning, Error, \
|
||||
InterfaceError, DataError, DatabaseError, OperationalError, \
|
||||
IntegrityError, InternalError, NotSupportedError, ProgrammingError
|
||||
@@ -64,7 +62,8 @@ def dump_packet(data):
|
||||
dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
|
||||
for d in dump_data:
|
||||
print ' '.join(map(lambda x:"%02X" % ord(x), d)) + \
|
||||
' ' * (16 - len(d)) + ' ' * 2 + ' '.join(map(lambda x:"%s" % is_ascii(x), d))
|
||||
' ' * (16 - len(d)) + ' ' * 2 + \
|
||||
' '.join(map(lambda x:"%s" % is_ascii(x), d))
|
||||
print "-" * 88
|
||||
print ""
|
||||
|
||||
@@ -84,7 +83,8 @@ def _my_crypt(message1, message2):
|
||||
length = len(message1)
|
||||
result = struct.pack('B', length)
|
||||
for i in xrange(length):
|
||||
x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0])
|
||||
x = (struct.unpack('B', message1[i:i+1])[0] ^ \
|
||||
struct.unpack('B', message2[i:i+1])[0])
|
||||
result += struct.pack('B', x)
|
||||
return result
|
||||
|
||||
@@ -161,9 +161,10 @@ def unpack_int64(n):
|
||||
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
|
||||
|
||||
def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
|
||||
raise
|
||||
err = errorclass, errorvalue
|
||||
|
||||
if DEBUG:
|
||||
raise
|
||||
|
||||
if cursor:
|
||||
cursor.messages.append(err)
|
||||
else:
|
||||
@@ -271,8 +272,8 @@ class MysqlPacket(object):
|
||||
"""
|
||||
return self.__data[position:(position+length)]
|
||||
|
||||
def read_coded_length(self):
|
||||
"""Read a 'Length Coded' number from the data buffer.
|
||||
def read_length_coded_binary(self):
|
||||
"""Read a 'Length Coded Binary' number from the data buffer.
|
||||
|
||||
Length coded numbers can be anywhere from 1 to 9 bytes depending
|
||||
on the value of the first byte.
|
||||
@@ -290,16 +291,17 @@ class MysqlPacket(object):
|
||||
# TODO: what was 'longlong'? confirm it wasn't used?
|
||||
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
|
||||
|
||||
def read_length_coded_binary(self):
|
||||
"""Read a 'Length Coded Binary' from the data buffer.
|
||||
def read_length_coded_string(self):
|
||||
"""Read a 'Length Coded String' from the data buffer.
|
||||
|
||||
A 'Length Coded Binary' consists first of a length coded
|
||||
A 'Length Coded String' consists first of a length coded
|
||||
(unsigned, positive) integer represented in 1-9 bytes followed by
|
||||
that many bytes of binary data. (For example "cat" would be "3cat".)
|
||||
"""
|
||||
length = self.read_coded_length()
|
||||
if length:
|
||||
return self.read(length)
|
||||
length = self.read_length_coded_binary()
|
||||
if length is None:
|
||||
return None
|
||||
return self.read(length)
|
||||
|
||||
def is_ok_packet(self):
|
||||
return ord(self.get_bytes(0)) == 0
|
||||
@@ -342,19 +344,17 @@ class FieldDescriptorPacket(MysqlPacket):
|
||||
|
||||
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
|
||||
"""
|
||||
self.catalog = self.read_length_coded_binary()
|
||||
self.db = self.read_length_coded_binary()
|
||||
self.table_name = self.read_length_coded_binary()
|
||||
self.org_table = self.read_length_coded_binary()
|
||||
self.name = self.read_length_coded_binary()
|
||||
self.org_name = self.read_length_coded_binary()
|
||||
self.catalog = self.read_length_coded_string()
|
||||
self.db = self.read_length_coded_string()
|
||||
self.table_name = self.read_length_coded_string()
|
||||
self.org_table = self.read_length_coded_string()
|
||||
self.name = self.read_length_coded_string()
|
||||
self.org_name = self.read_length_coded_string()
|
||||
self.advance(1) # non-null filler
|
||||
self.charsetnr = struct.unpack('<h', self.read(2))[0]
|
||||
self.length = struct.unpack('<i', self.read(4))[0]
|
||||
self.charsetnr = struct.unpack('<H', self.read(2))[0]
|
||||
self.length = struct.unpack('<I', self.read(4))[0]
|
||||
self.type_code = ord(self.read(1))
|
||||
flags = struct.unpack('<h', self.read(2))
|
||||
# TODO: what is going on here with this flag parsing???
|
||||
self.flags = int(("%02X" % flags)[1:], 16)
|
||||
self.flags = struct.unpack('<H', self.read(2))[0]
|
||||
self.scale = ord(self.read(1)) # "decimals"
|
||||
self.advance(2) # filler (always 0x00)
|
||||
|
||||
@@ -401,8 +401,8 @@ class Connection(object):
|
||||
|
||||
def __init__(self, host="localhost", user=None, passwd="",
|
||||
db=None, port=3306, unix_socket=None,
|
||||
charset=DEFAULT_CHARSET, sql_mode=None,
|
||||
read_default_file=None, conv=decoders, use_unicode=True,
|
||||
charset='', sql_mode=None,
|
||||
read_default_file=None, conv=decoders, use_unicode=False,
|
||||
client_flag=0, cursorclass=Cursor, init_command=None,
|
||||
connect_timeout=None, ssl=None, read_default_group=None,
|
||||
compress=None, named_pipe=None):
|
||||
@@ -457,7 +457,7 @@ class Connection(object):
|
||||
return cfg.get("client",key)
|
||||
except:
|
||||
return default
|
||||
|
||||
|
||||
user = _config("user",user)
|
||||
passwd = _config("password",passwd)
|
||||
host = _config("host", host)
|
||||
@@ -465,15 +465,22 @@ class Connection(object):
|
||||
unix_socket = _config("socket",unix_socket)
|
||||
port = _config("port", port)
|
||||
charset = _config("default-character-set", charset)
|
||||
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = passwd
|
||||
self.db = db
|
||||
self.unix_socket = unix_socket
|
||||
self.use_unicode = use_unicode
|
||||
self.charset = DEFAULT_CHARSET
|
||||
if charset:
|
||||
self.charset = charset
|
||||
self.use_unicode = True
|
||||
else:
|
||||
self.charset = DEFAULT_CHARSET
|
||||
self.use_unicode = False
|
||||
|
||||
if use_unicode:
|
||||
self.use_unicode = use_unicode
|
||||
|
||||
client_flag |= CAPABILITIES
|
||||
client_flag |= MULTI_STATEMENTS
|
||||
@@ -483,20 +490,19 @@ class Connection(object):
|
||||
|
||||
self.cursorclass = cursorclass
|
||||
self.connect_timeout = connect_timeout
|
||||
|
||||
|
||||
self._connect()
|
||||
|
||||
self.set_charset_set(charset)
|
||||
|
||||
self.messages = []
|
||||
self.set_charset(charset)
|
||||
self.encoders = encoders
|
||||
self.decoders = conv
|
||||
self.field_decoders = field_decoders
|
||||
|
||||
self._affected_rows = 0
|
||||
self.host_info = "Not connected"
|
||||
|
||||
|
||||
self.autocommit(False)
|
||||
|
||||
|
||||
if sql_mode is not None:
|
||||
c = self.cursor()
|
||||
c.execute("SET sql_mode=%s", (sql_mode,))
|
||||
@@ -506,21 +512,20 @@ class Connection(object):
|
||||
if init_command is not None:
|
||||
c = self.cursor()
|
||||
c.execute(init_command)
|
||||
|
||||
|
||||
self.commit()
|
||||
|
||||
|
||||
|
||||
def close(self):
|
||||
''' Send the quit message and close the socket '''
|
||||
try:
|
||||
if self.socket:
|
||||
send_data = struct.pack('<i',1) + COM_QUIT
|
||||
sock = self.socket
|
||||
sock.send(send_data)
|
||||
sock.close()
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
|
||||
self.socket.send(send_data)
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
else:
|
||||
self.errorhandler(None, InterfaceError, "(0, '')")
|
||||
|
||||
def autocommit(self, value):
|
||||
''' Set whether or not to commit after every execute() '''
|
||||
try:
|
||||
@@ -560,7 +565,7 @@ class Connection(object):
|
||||
def cursor(self):
|
||||
''' Create a new cursor to execute queries with '''
|
||||
return self.cursorclass(self)
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
''' Context manager that returns a Cursor '''
|
||||
return self.cursor()
|
||||
@@ -577,7 +582,7 @@ class Connection(object):
|
||||
self._execute_command(COM_QUERY, sql)
|
||||
self._affected_rows = self._read_query_result()
|
||||
return self._affected_rows
|
||||
|
||||
|
||||
def next_result(self):
|
||||
self._affected_rows = self._read_query_result()
|
||||
return self._affected_rows
|
||||
@@ -595,7 +600,7 @@ class Connection(object):
|
||||
return
|
||||
pkt = self.read_packet()
|
||||
return pkt.is_ok_packet()
|
||||
|
||||
|
||||
def ping(self, reconnect=True):
|
||||
''' Check if the server is alive '''
|
||||
try:
|
||||
@@ -612,14 +617,13 @@ class Connection(object):
|
||||
pkt = self.read_packet()
|
||||
return pkt.is_ok_packet()
|
||||
|
||||
def set_charset_set(self, charset):
|
||||
def set_charset(self, charset):
|
||||
try:
|
||||
sock = self.socket
|
||||
if charset and self.charset != charset:
|
||||
if charset:
|
||||
self._execute_command(COM_QUERY, "SET NAMES %s" %
|
||||
self.escape(charset))
|
||||
self.read_packet()
|
||||
self.charset = charset
|
||||
self.charset = charset
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
@@ -647,7 +651,7 @@ class Connection(object):
|
||||
self._request_authentication()
|
||||
except socket.error, e:
|
||||
raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0]))
|
||||
|
||||
|
||||
def read_packet(self, packet_type=MysqlPacket):
|
||||
"""Read an entire "mysql packet" in its entirety from the network
|
||||
and return a MysqlPacket type that represents the results."""
|
||||
@@ -673,7 +677,9 @@ class Connection(object):
|
||||
pckt_no = 0
|
||||
while len(buf) >= MAX_PACKET_LENGTH:
|
||||
header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+chr(pckt_no)
|
||||
self.socket.send(header+buf[:MAX_PACKET_LENGTH])
|
||||
send_data = header + buf[:MAX_PACKET_LENGTH]
|
||||
self.socket.send(send_data)
|
||||
if DEBUG: dump_packet(send_data)
|
||||
buf = buf[MAX_PACKET_LENGTH:]
|
||||
pckt_no += 1
|
||||
header = struct.pack('<i', len(buf))[:-1]+chr(pckt_no)
|
||||
@@ -683,13 +689,12 @@ class Connection(object):
|
||||
#sock = self.socket
|
||||
#sock.send(send_data)
|
||||
|
||||
if DEBUG: dump_packet(send_data)
|
||||
#
|
||||
|
||||
def _execute_command(self, command, sql):
|
||||
self._send_command(command, sql)
|
||||
|
||||
def _request_authentication(self):
|
||||
sock = self.socket
|
||||
self._send_authentication()
|
||||
|
||||
def _send_authentication(self):
|
||||
@@ -700,9 +705,12 @@ class Connection(object):
|
||||
|
||||
if self.user is None:
|
||||
raise ValueError, "Did not specify a username"
|
||||
|
||||
data_init = (struct.pack('<i', self.client_flag)) \
|
||||
+ "\0\0\0\x01" + '\x08' + '\0'*23
|
||||
|
||||
charset_id = charset_by_name(self.charset).id
|
||||
self.user = self.user.encode(self.charset)
|
||||
|
||||
data_init = struct.pack('<i', self.client_flag) + "\0\0\0\x01" + \
|
||||
chr(charset_id) + '\0'*23
|
||||
|
||||
next_packet = 1
|
||||
|
||||
@@ -722,13 +730,14 @@ class Connection(object):
|
||||
data = data_init + self.user+"\0" + _scramble(self.password, self.salt)
|
||||
|
||||
if self.db:
|
||||
data += self.db.encode(self.charset) + "\0"
|
||||
self.db = self.db.encode(self.charset)
|
||||
data += self.db + "\0"
|
||||
|
||||
data = pack_int24(len(data)) + chr(next_packet) + data
|
||||
next_packet += 2
|
||||
|
||||
|
||||
if DEBUG: dump_packet(data)
|
||||
|
||||
|
||||
sock.send(data)
|
||||
|
||||
auth_packet = MysqlPacket(sock)
|
||||
@@ -743,13 +752,13 @@ class Connection(object):
|
||||
#raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table."
|
||||
data = _scramble_323(self.password, self.salt) + "\0"
|
||||
data = pack_int24(len(data)) + chr(next_packet) + data
|
||||
|
||||
|
||||
sock.send(data)
|
||||
auth_packet = MysqlPacket(sock)
|
||||
auth_packet.check_error()
|
||||
if DEBUG: auth_packet.dump()
|
||||
|
||||
|
||||
|
||||
|
||||
# _mysql support
|
||||
def thread_id(self):
|
||||
return self.server_thread_id[0]
|
||||
@@ -759,10 +768,10 @@ class Connection(object):
|
||||
|
||||
def get_host_info(self):
|
||||
return self.host_info
|
||||
|
||||
|
||||
def get_proto_info(self):
|
||||
return self.protocol_version
|
||||
|
||||
|
||||
def _get_server_information(self):
|
||||
sock = self.socket
|
||||
i = 0
|
||||
@@ -773,27 +782,28 @@ class Connection(object):
|
||||
#packet_len = ord(data[i:i+1])
|
||||
#i += 4
|
||||
self.protocol_version = ord(data[i:i+1])
|
||||
|
||||
|
||||
i += 1
|
||||
server_end = data.find("\0", i)
|
||||
self.server_version = data[i:server_end]
|
||||
|
||||
|
||||
i = server_end + 1
|
||||
self.server_thread_id = struct.unpack('<h', data[i:i+2])
|
||||
|
||||
i += 4
|
||||
self.salt = data[i:i+8]
|
||||
|
||||
|
||||
i += 9
|
||||
if len(data) >= i + 1:
|
||||
i += 1
|
||||
|
||||
|
||||
self.server_capabilities = struct.unpack('<h', data[i:i+2])[0]
|
||||
|
||||
i += 1
|
||||
self.server_language = ord(data[i:i+1])
|
||||
|
||||
i += 16
|
||||
self.server_charset = charset_by_id(self.server_language).name
|
||||
|
||||
i += 16
|
||||
if len(data) >= i+12-1:
|
||||
rest_salt = data[i:i+12]
|
||||
self.salt += rest_salt
|
||||
@@ -840,8 +850,8 @@ class MySQLResult(object):
|
||||
|
||||
def _read_ok_packet(self):
|
||||
self.first_packet.advance(1) # field_count (always '0')
|
||||
self.affected_rows = self.first_packet.read_coded_length()
|
||||
self.insert_id = self.first_packet.read_coded_length()
|
||||
self.affected_rows = self.first_packet.read_length_coded_binary()
|
||||
self.insert_id = self.first_packet.read_length_coded_binary()
|
||||
self.server_status = struct.unpack('<H', self.first_packet.read(2))[0]
|
||||
self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0]
|
||||
self.message = self.first_packet.read_all()
|
||||
@@ -871,14 +881,7 @@ class MySQLResult(object):
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
data = packet.read_length_coded_binary()
|
||||
converted = None
|
||||
if data != None:
|
||||
converted = converter(data)
|
||||
else:
|
||||
converter = self.connection.field_decoders[field.type_code]
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
data = packet.read_length_coded_binary()
|
||||
data = packet.read_length_coded_string()
|
||||
converted = None
|
||||
if data != None:
|
||||
converted = converter(self.connection, field, data)
|
||||
|
@@ -29,4 +29,4 @@ STRING = 254
|
||||
GEOMETRY = 255
|
||||
|
||||
CHAR = TINY
|
||||
INTERVAL = ENUM
|
||||
INTERVAL = ENUM
|
||||
|
@@ -1,11 +1,9 @@
|
||||
import re
|
||||
import datetime
|
||||
import time
|
||||
import array
|
||||
import struct
|
||||
|
||||
from times import Date, Time, TimeDelta, Timestamp
|
||||
from constants import FIELD_TYPE
|
||||
from constants import FIELD_TYPE, FLAG
|
||||
from charset import charset_by_id
|
||||
|
||||
try:
|
||||
set
|
||||
@@ -20,8 +18,16 @@ ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
||||
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
||||
|
||||
def escape_item(val, charset):
|
||||
if type(val) in [tuple, list, set]:
|
||||
return escape_sequence(val, charset)
|
||||
if type(val) is dict:
|
||||
return escape_dict(val, charset)
|
||||
encoder = encoders[type(val)]
|
||||
return encoder(val, charset)
|
||||
val = encoder(val)
|
||||
if type(val) is str:
|
||||
return val
|
||||
val = val.encode(charset)
|
||||
return val
|
||||
|
||||
def escape_dict(val, charset):
|
||||
n = {}
|
||||
@@ -37,99 +43,96 @@ def escape_sequence(val, charset):
|
||||
n.append(quoted)
|
||||
return tuple(n)
|
||||
|
||||
def escape_bool(value, charset):
|
||||
return str(int(value)).encode(charset)
|
||||
def escape_set(val, charset):
|
||||
val = map(lambda x: escape_item(x, charset), val)
|
||||
return ','.join(val)
|
||||
|
||||
def escape_object(value, charset):
|
||||
return str(value).encode(charset)
|
||||
def escape_bool(value):
|
||||
return str(int(value))
|
||||
|
||||
def escape_object(value):
|
||||
return str(value)
|
||||
|
||||
escape_int = escape_long = escape_object
|
||||
|
||||
def escape_float(value, charset):
|
||||
return ('%.15g' % value).encode(charset)
|
||||
def escape_float(value):
|
||||
return ('%.15g' % value)
|
||||
|
||||
def escape_string(value, charset):
|
||||
r = ("'%s'" % ESCAPE_REGEX.sub(
|
||||
lambda match: ESCAPE_MAP.get(match.group(0)), value))
|
||||
# TODO: make sure that encodings are handled correctly here.
|
||||
# Since we may be dealing with binary data, the encoding
|
||||
# routine below is commented out.
|
||||
#if not charset is None:
|
||||
# r = r.encode(charset)
|
||||
return r
|
||||
|
||||
def escape_unicode(value, charset):
|
||||
# pass None as the charset because we already encode it
|
||||
return escape_string(value.encode(charset), None)
|
||||
def escape_string(value):
|
||||
return ("'%s'" % ESCAPE_REGEX.sub(
|
||||
lambda match: ESCAPE_MAP.get(match.group(0)), value))
|
||||
|
||||
def escape_None(value, charset):
|
||||
return 'NULL'.encode(charset)
|
||||
def escape_unicode(value):
|
||||
return escape_string(value)
|
||||
|
||||
def escape_timedelta(obj, charset):
|
||||
def escape_None(value):
|
||||
return 'NULL'
|
||||
|
||||
def escape_timedelta(obj):
|
||||
seconds = int(obj.seconds) % 60
|
||||
minutes = int(obj.seconds // 60) % 60
|
||||
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
|
||||
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset)
|
||||
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds))
|
||||
|
||||
def escape_time(obj, charset):
|
||||
def escape_time(obj):
|
||||
s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
|
||||
int(obj.second))
|
||||
if obj.microsecond:
|
||||
s += ".%f" % obj.microsecond
|
||||
|
||||
return escape_string(s, charset)
|
||||
return escape_string(s)
|
||||
|
||||
def escape_datetime(obj, charset):
|
||||
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset)
|
||||
def escape_datetime(obj):
|
||||
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
def escape_date(obj, charset):
|
||||
return escape_string(obj.strftime("%Y-%m-%d"), charset)
|
||||
def escape_date(obj):
|
||||
return escape_string(obj.strftime("%Y-%m-%d"))
|
||||
|
||||
def escape_struct_time(obj, charset):
|
||||
return escape_datetime(datetime.datetime(*obj[:6]), charset)
|
||||
def escape_struct_time(obj):
|
||||
return escape_datetime(datetime.datetime(*obj[:6]))
|
||||
|
||||
def convert_datetime(obj):
|
||||
def convert_datetime(connection, field, obj):
|
||||
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
|
||||
|
||||
|
||||
>>> datetime_or_None('2007-02-25 23:06:20')
|
||||
datetime.datetime(2007, 2, 25, 23, 6, 20)
|
||||
>>> datetime_or_None('2007-02-25T23:06:20')
|
||||
datetime.datetime(2007, 2, 25, 23, 6, 20)
|
||||
|
||||
|
||||
Illegal values are returned as None:
|
||||
|
||||
|
||||
>>> datetime_or_None('2007-02-31T23:06:20') is None
|
||||
True
|
||||
>>> datetime_or_None('0000-00-00 00:00:00') is None
|
||||
True
|
||||
|
||||
|
||||
"""
|
||||
if ' ' in obj:
|
||||
sep = ' '
|
||||
elif 'T' in obj:
|
||||
sep = 'T'
|
||||
else:
|
||||
return convert_date(obj)
|
||||
return convert_date(connection, field, obj)
|
||||
|
||||
try:
|
||||
ymd, hms = obj.split(sep, 1)
|
||||
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
|
||||
except ValueError:
|
||||
return convert_date(obj)
|
||||
return convert_date(connection, field, obj)
|
||||
|
||||
def convert_timedelta(obj):
|
||||
def convert_timedelta(connection, field, obj):
|
||||
"""Returns a TIME column as a timedelta object:
|
||||
|
||||
>>> timedelta_or_None('25:06:17')
|
||||
datetime.timedelta(1, 3977)
|
||||
>>> timedelta_or_None('-25:06:17')
|
||||
datetime.timedelta(-2, 83177)
|
||||
|
||||
|
||||
Illegal values are returned as None:
|
||||
|
||||
|
||||
>>> timedelta_or_None('random crap') is None
|
||||
True
|
||||
|
||||
|
||||
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
|
||||
can accept values as (+|-)DD HH:MM:SS. The latter format will not
|
||||
be parsed correctly by this function.
|
||||
@@ -147,23 +150,23 @@ def convert_timedelta(obj):
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def convert_time(obj):
|
||||
def convert_time(connection, field, obj):
|
||||
"""Returns a TIME column as a time object:
|
||||
|
||||
>>> time_or_None('15:06:17')
|
||||
datetime.time(15, 6, 17)
|
||||
|
||||
|
||||
Illegal values are returned as None:
|
||||
|
||||
|
||||
>>> time_or_None('-25:06:17') is None
|
||||
True
|
||||
>>> time_or_None('random crap') is None
|
||||
True
|
||||
|
||||
|
||||
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
|
||||
can accept values as (+|-)DD HH:MM:SS. The latter format will not
|
||||
be parsed correctly by this function.
|
||||
|
||||
|
||||
Also note that MySQL's TIME column corresponds more closely to
|
||||
Python's timedelta and not time. However if you want TIME columns
|
||||
to be treated as time-of-day and not a time offset, then you can
|
||||
@@ -172,53 +175,54 @@ def convert_time(obj):
|
||||
from math import modf
|
||||
try:
|
||||
hour, minute, second = obj.split(':')
|
||||
return datetime.time(hour=int(hour), minute=int(minute), second=int(second),
|
||||
microsecond=int(modf(float(second))[0]*1000000))
|
||||
return datetime.time(hour=int(hour), minute=int(minute),
|
||||
second=int(second),
|
||||
microsecond=int(modf(float(second))[0]*1000000))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def convert_date(obj):
|
||||
def convert_date(connection, field, obj):
|
||||
"""Returns a DATE column as a date object:
|
||||
|
||||
>>> date_or_None('2007-02-26')
|
||||
datetime.date(2007, 2, 26)
|
||||
|
||||
|
||||
Illegal values are returned as None:
|
||||
|
||||
|
||||
>>> date_or_None('2007-02-31') is None
|
||||
True
|
||||
>>> date_or_None('0000-00-00') is None
|
||||
True
|
||||
|
||||
|
||||
"""
|
||||
try:
|
||||
return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def convert_mysql_timestamp(timestamp):
|
||||
def convert_mysql_timestamp(connection, field, timestamp):
|
||||
"""Convert a MySQL TIMESTAMP to a Timestamp object.
|
||||
|
||||
MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:
|
||||
|
||||
|
||||
>>> mysql_timestamp_converter('2007-02-25 22:32:17')
|
||||
datetime.datetime(2007, 2, 25, 22, 32, 17)
|
||||
|
||||
|
||||
MySQL < 4.1 uses a big string of numbers:
|
||||
|
||||
|
||||
>>> mysql_timestamp_converter('20070225223217')
|
||||
datetime.datetime(2007, 2, 25, 22, 32, 17)
|
||||
|
||||
|
||||
Illegal values are returned as None:
|
||||
|
||||
|
||||
>>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
|
||||
True
|
||||
>>> mysql_timestamp_converter('00000000000000') is None
|
||||
True
|
||||
|
||||
|
||||
"""
|
||||
if timestamp[4] == '-':
|
||||
return convert_datetime(timestamp)
|
||||
return convert_datetime(connection, field, timestamp)
|
||||
timestamp += "0"*(14-len(timestamp)) # padding
|
||||
year, month, day, hour, minute, second = \
|
||||
int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
|
||||
@@ -229,13 +233,38 @@ def convert_mysql_timestamp(timestamp):
|
||||
return None
|
||||
|
||||
def convert_set(s):
|
||||
# TODO: this may not be correct
|
||||
return set(s.split(","))
|
||||
|
||||
def convert_bit(b):
|
||||
b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
|
||||
return struct.unpack(">Q", b)[0]
|
||||
|
||||
def convert_bit(connection, field, b):
|
||||
#b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
|
||||
#return struct.unpack(">Q", b)[0]
|
||||
#
|
||||
# the snippet above is right, but MySQLdb doesn't process bits,
|
||||
# so we shouldn't either
|
||||
return b
|
||||
|
||||
def convert_characters(connection, field, data):
|
||||
if field.flags & FLAG.SET:
|
||||
return convert_set(data)
|
||||
if field.flags & FLAG.BINARY:
|
||||
return data
|
||||
field_charset = charset_by_id(field.charsetnr).name
|
||||
if connection.use_unicode:
|
||||
data = data.decode(field_charset)
|
||||
elif connection.charset != field_charset:
|
||||
data = data.decode(field_charset)
|
||||
data = data.encode(connection.charset)
|
||||
return data
|
||||
|
||||
def convert_int(connection, field, data):
|
||||
return int(data)
|
||||
|
||||
def convert_long(connection, field, data):
|
||||
return long(data)
|
||||
|
||||
def convert_float(connection, field, data):
|
||||
return float(data)
|
||||
|
||||
encoders = {
|
||||
bool: escape_bool,
|
||||
int: escape_int,
|
||||
@@ -257,21 +286,28 @@ encoders = {
|
||||
|
||||
decoders = {
|
||||
FIELD_TYPE.BIT: convert_bit,
|
||||
FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.SHORT: int,
|
||||
FIELD_TYPE.LONG: long,
|
||||
FIELD_TYPE.FLOAT: float,
|
||||
FIELD_TYPE.DOUBLE: float,
|
||||
FIELD_TYPE.DECIMAL: float,
|
||||
FIELD_TYPE.NEWDECIMAL: float,
|
||||
FIELD_TYPE.LONGLONG: long,
|
||||
FIELD_TYPE.INT24: int,
|
||||
FIELD_TYPE.YEAR: int,
|
||||
FIELD_TYPE.TINY: convert_int,
|
||||
FIELD_TYPE.SHORT: convert_int,
|
||||
FIELD_TYPE.LONG: convert_long,
|
||||
FIELD_TYPE.FLOAT: convert_float,
|
||||
FIELD_TYPE.DOUBLE: convert_float,
|
||||
FIELD_TYPE.DECIMAL: convert_float,
|
||||
FIELD_TYPE.NEWDECIMAL: convert_float,
|
||||
FIELD_TYPE.LONGLONG: convert_long,
|
||||
FIELD_TYPE.INT24: convert_int,
|
||||
FIELD_TYPE.YEAR: convert_int,
|
||||
FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
|
||||
FIELD_TYPE.DATETIME: convert_datetime,
|
||||
FIELD_TYPE.TIME: convert_timedelta,
|
||||
FIELD_TYPE.DATE: convert_date,
|
||||
FIELD_TYPE.SET: convert_set,
|
||||
FIELD_TYPE.BLOB: convert_characters,
|
||||
FIELD_TYPE.TINY_BLOB: convert_characters,
|
||||
FIELD_TYPE.MEDIUM_BLOB: convert_characters,
|
||||
FIELD_TYPE.LONG_BLOB: convert_characters,
|
||||
FIELD_TYPE.STRING: convert_characters,
|
||||
FIELD_TYPE.VAR_STRING: convert_characters,
|
||||
FIELD_TYPE.VARCHAR: convert_characters,
|
||||
#FIELD_TYPE.BLOB: str,
|
||||
#FIELD_TYPE.STRING: str,
|
||||
#FIELD_TYPE.VAR_STRING: str,
|
||||
@@ -279,28 +315,13 @@ decoders = {
|
||||
}
|
||||
conversions = decoders # for MySQLdb compatibility
|
||||
|
||||
def decode_characters(connection, field, data):
|
||||
if field.charsetnr == 63 or not connection.use_unicode:
|
||||
# binary data, leave it alone
|
||||
return data
|
||||
return data.decode(connection.charset)
|
||||
|
||||
# These take a field instance rather than just the data.
|
||||
field_decoders = {
|
||||
FIELD_TYPE.BLOB: decode_characters,
|
||||
FIELD_TYPE.TINY_BLOB: decode_characters,
|
||||
FIELD_TYPE.MEDIUM_BLOB: decode_characters,
|
||||
FIELD_TYPE.LONG_BLOB: decode_characters,
|
||||
FIELD_TYPE.STRING: decode_characters,
|
||||
FIELD_TYPE.VAR_STRING: decode_characters,
|
||||
FIELD_TYPE.VARCHAR: decode_characters,
|
||||
}
|
||||
|
||||
try:
|
||||
# python version > 2.3
|
||||
from decimal import Decimal
|
||||
decoders[FIELD_TYPE.DECIMAL] = Decimal
|
||||
decoders[FIELD_TYPE.NEWDECIMAL] = Decimal
|
||||
def convert_decimal(connection, field, data):
|
||||
return Decimal(data)
|
||||
decoders[FIELD_TYPE.DECIMAL] = convert_decimal
|
||||
decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal
|
||||
|
||||
def escape_decimal(obj, charset):
|
||||
return unicode(obj).encode(charset)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import struct
|
||||
import re
|
||||
|
||||
@@ -52,23 +53,23 @@ class Cursor(object):
|
||||
if not self.connection:
|
||||
self.errorhandler(self, ProgrammingError, "cursor closed")
|
||||
return self.connection
|
||||
|
||||
|
||||
def _check_executed(self):
|
||||
if not self._executed:
|
||||
self.errorhandler(self, ProgrammingError, "execute() first")
|
||||
|
||||
|
||||
def setinputsizes(self, *args):
|
||||
"""Does nothing, required by DB API."""
|
||||
|
||||
|
||||
def setoutputsizes(self, *args):
|
||||
"""Does nothing, required by DB API."""
|
||||
|
||||
|
||||
def nextset(self):
|
||||
''' Get the next query set '''
|
||||
if self._executed:
|
||||
self.fetchall()
|
||||
del self.messages[:]
|
||||
|
||||
|
||||
if not self._has_next:
|
||||
return None
|
||||
connection = self._get_db()
|
||||
@@ -79,11 +80,11 @@ class Cursor(object):
|
||||
def execute(self, query, args=None):
|
||||
''' Execute a query '''
|
||||
from sys import exc_info
|
||||
|
||||
|
||||
conn = self._get_db()
|
||||
charset = conn.charset
|
||||
del self.messages[:]
|
||||
|
||||
|
||||
# this ordering is good because conn.escape() returns
|
||||
# an encoded string.
|
||||
if isinstance(query, unicode):
|
||||
@@ -91,7 +92,7 @@ class Cursor(object):
|
||||
|
||||
if args is not None:
|
||||
query = query % conn.escape(args)
|
||||
|
||||
|
||||
result = 0
|
||||
try:
|
||||
result = self._query(query)
|
||||
@@ -103,7 +104,7 @@ class Cursor(object):
|
||||
|
||||
self._executed = query
|
||||
return result
|
||||
|
||||
|
||||
def executemany(self, query, args):
|
||||
''' Run several data against one query '''
|
||||
del self.messages[:]
|
||||
@@ -113,30 +114,66 @@ class Cursor(object):
|
||||
charset = conn.charset
|
||||
if isinstance(query, unicode):
|
||||
query = query.encode(charset)
|
||||
|
||||
|
||||
self.rowcount = sum([ self.execute(query, arg) for arg in args ])
|
||||
return self.rowcount
|
||||
|
||||
|
||||
|
||||
|
||||
def callproc(self, procname, args=()):
|
||||
''' Call a stored procedure. Take care to ensure that procname is
|
||||
properly escaped. '''
|
||||
if not isinstance(args, tuple):
|
||||
args = (args,)
|
||||
"""Execute stored procedure procname with args
|
||||
|
||||
argstr = ("%s," * len(args))[:-1]
|
||||
procname -- string, name of procedure to execute on server
|
||||
|
||||
return self.execute("CALL `%s`(%s)" % (procname, argstr), args)
|
||||
args -- Sequence of parameters to use with procedure
|
||||
|
||||
Returns the original args.
|
||||
|
||||
Compatibility warning: PEP-249 specifies that any modified
|
||||
parameters must be returned. This is currently impossible
|
||||
as they are only available by storing them in a server
|
||||
variable and then retrieved by a query. Since stored
|
||||
procedures return zero or more result sets, there is no
|
||||
reliable way to get at OUT or INOUT parameters via callproc.
|
||||
The server variables are named @_procname_n, where procname
|
||||
is the parameter above and n is the position of the parameter
|
||||
(from zero). Once all result sets generated by the procedure
|
||||
have been fetched, you can issue a SELECT @_procname_0, ...
|
||||
query using .execute() to get any OUT or INOUT values.
|
||||
|
||||
Compatibility warning: The act of calling a stored procedure
|
||||
itself creates an empty result set. This appears after any
|
||||
result sets generated by the procedure. This is non-standard
|
||||
behavior with respect to the DB-API. Be sure to use nextset()
|
||||
to advance through all result sets; otherwise you may get
|
||||
disconnected.
|
||||
"""
|
||||
conn = self._get_db()
|
||||
for index, arg in enumerate(args):
|
||||
q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
|
||||
if isinstance(q, unicode):
|
||||
q = q.encode(conn.charset)
|
||||
self._query(q)
|
||||
self.nextset()
|
||||
|
||||
q = "CALL %s(%s)" % (procname,
|
||||
','.join(['@_%s_%d' % (procname, i)
|
||||
for i in range(len(args))]))
|
||||
if isinstance(q, unicode):
|
||||
q = q.encode(conn.charset)
|
||||
self._query(q)
|
||||
self._executed = q
|
||||
|
||||
return args
|
||||
|
||||
def fetchone(self):
|
||||
''' Fetch the next row '''
|
||||
self._check_executed()
|
||||
self._check_executed()
|
||||
if self._rows is None or self.rownumber >= len(self._rows):
|
||||
return None
|
||||
result = self._rows[self.rownumber]
|
||||
self.rownumber += 1
|
||||
return result
|
||||
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
''' Fetch several rows '''
|
||||
self._check_executed()
|
||||
@@ -158,15 +195,15 @@ class Cursor(object):
|
||||
result = self._rows
|
||||
self.rownumber = len(self._rows)
|
||||
return result
|
||||
|
||||
|
||||
def scroll(self, value, mode='relative'):
|
||||
|
||||
self._check_executed()
|
||||
if mode == 'relative':
|
||||
r = self.rownumber + value
|
||||
elif mode == 'absolute':
|
||||
r = value
|
||||
else:
|
||||
self.errorhandler(self, ProgrammingError,
|
||||
self.errorhandler(self, ProgrammingError,
|
||||
"unknown scroll mode %s" % mode)
|
||||
|
||||
if r < 0 or r >= len(self._rows):
|
||||
@@ -179,23 +216,23 @@ class Cursor(object):
|
||||
conn.query(q)
|
||||
self._do_get_result()
|
||||
return self.rowcount
|
||||
|
||||
|
||||
def _do_get_result(self):
|
||||
conn = self._get_db()
|
||||
self.rowcount = conn._result.affected_rows
|
||||
|
||||
|
||||
self.rownumber = 0
|
||||
self.description = conn._result.description
|
||||
self.lastrowid = conn._result.insert_id
|
||||
self._rows = conn._result.rows
|
||||
self._has_next = conn._result.has_next
|
||||
conn._result = None
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
self._check_executed()
|
||||
result = self.rownumber and self._rows[self.rownumber:] or self._rows
|
||||
return iter(result)
|
||||
|
||||
|
||||
Warning = Warning
|
||||
Error = Error
|
||||
InterfaceError = InterfaceError
|
||||
|
@@ -3,7 +3,8 @@ import unittest
|
||||
|
||||
class PyMySQLTestCase(unittest.TestCase):
|
||||
databases = [
|
||||
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql"},
|
||||
{"host":"localhost","user":"root",
|
||||
"passwd":"","db":"test_pymysql", "use_unicode": True},
|
||||
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
|
||||
|
||||
def setUp(self):
|
||||
|
@@ -15,7 +15,8 @@ class TestConversion(base.PyMySQLTestCase):
|
||||
c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v)
|
||||
c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
|
||||
r = c.fetchone()
|
||||
self.assertEqual(v[:8], r[:8])
|
||||
self.assertEqual("\x01", r[0])
|
||||
self.assertEqual(v[1:8], r[1:8])
|
||||
# mysql throws away microseconds so we need to check datetimes
|
||||
# specially. additionally times are turned into timedeltas.
|
||||
self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8])
|
||||
|
@@ -104,8 +104,8 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
|
||||
self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
|
||||
self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
|
||||
|
||||
self.assertEqual('1', pymysql.converters.escape_object(1, "utf8"))
|
||||
self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8"))
|
||||
self.assertEqual('1', pymysql.converters.escape_object(1))
|
||||
self.assertEqual('1', pymysql.converters.escape_object(1L))
|
||||
|
||||
def test_issue_15(self):
|
||||
""" query should be expanded before perform character encoding """
|
||||
|
Reference in New Issue
Block a user