Clean up charset and unicode functionality, align BIT type handling with MySQLdb, disambiguate length-coded binary from length-coded strings, and raise an exception if the connection is used after being close()'d

This commit is contained in:
Pete Hunt
2010-12-10 05:32:06 +00:00
parent bfa2cf574d
commit d51400138c
8 changed files with 449 additions and 222 deletions

View File

@@ -1,5 +1,3 @@
MBLENGTH = {
8:1,
33:3,
@@ -7,4 +5,170 @@ MBLENGTH = {
91:2
}
class Charset(object):
    """One (character set, collation) pair known to the MySQL server.

    Attributes:
        id: numeric collation id.
        name: character set name, e.g. 'utf8'.
        collation: collation name, e.g. 'utf8_general_ci'.
        is_default: True only when the ``is_default`` argument is the
            string 'Yes'; any other value (such as '') gives False.
    """

    def __init__(self, id, name, collation, is_default):
        self.id, self.name, self.collation = id, name, collation
        # The generated table passes the raw 'Yes' / '' column text, so the
        # flag is normalised to a real boolean here.
        self.is_default = is_default == 'Yes'

    def __repr__(self):
        return "Charset(id=%r, name=%r, collation=%r, is_default=%r)" % (
            self.id, self.name, self.collation, self.is_default)
class Charsets(object):
    """Registry of charset entries, searchable by id or by charset name.

    Entries are any objects exposing ``id``, ``name`` and ``is_default``.
    """

    def __init__(self):
        self._by_id = {}
        # Lower-cased charset name -> its default entry, so by_name() is a
        # dictionary lookup instead of a scan over every registered entry.
        self._default_by_name = {}

    def add(self, c):
        """Register entry ``c`` under its id (replacing any entry with that id)."""
        self._by_id[c.id] = c
        if c.is_default:
            self._default_by_name[c.name.lower()] = c

    def by_id(self, id):
        """Return the entry registered under ``id``.

        Raises:
            KeyError: if no entry with that id has been added.
        """
        return self._by_id[id]

    def by_name(self, name):
        """Return the default entry for charset ``name``, or None if unknown.

        The comparison ignores case, so 'UTF8' and 'utf8' find the same
        entry. Only entries whose ``is_default`` is true are considered.
        """
        key = name.lower() if hasattr(name, 'lower') else name
        return self._default_by_name.get(key)
# Module-level registry that charset_by_name() and charset_by_id() query.
_charsets = Charsets()
# The bare string below is not attached to any definition; it only records
# the shell pipeline that produced the add() calls that follow it.
"""
Generated with:
mysql -N -s -e "select id, character_set_name, collation_name, is_default
from information_schema.collations order by id;" | python -c "import sys
for l in sys.stdin.readlines():
id, name, collation, is_default = l.split(chr(9))
print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \
% (id, name, collation, is_default.strip())
"
"""
# Arguments are (id, charset name, collation name, is_default). Ids are not
# contiguous. Only entries marked 'Yes' can be returned by Charsets.by_name().
_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes'))
_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', ''))
_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes'))
_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes'))
_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', ''))
_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes'))
_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes'))
_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes'))
_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes'))
_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes'))
_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes'))
_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes'))
_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes'))
_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', ''))
_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', ''))
_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes'))
_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes'))
_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes'))
_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', ''))
_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', ''))
_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes'))
_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', ''))
_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes'))
_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes'))
_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes'))
_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', ''))
_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes'))
_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', ''))
_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes'))
_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', ''))
_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes'))
_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes'))
_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', ''))
_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes'))
_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes'))
_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes'))
_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes'))
_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes'))
_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes'))
_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes'))
_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', ''))
_charsets.add(Charset(43, 'macce', 'macce_bin', ''))
_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', ''))
_charsets.add(Charset(47, 'latin1', 'latin1_bin', ''))
_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', ''))
_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', ''))
_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', ''))
_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes'))
_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', ''))
_charsets.add(Charset(53, 'macroman', 'macroman_bin', ''))
_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes'))
_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', ''))
_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes'))
_charsets.add(Charset(63, 'binary', 'binary', 'Yes'))
_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', ''))
_charsets.add(Charset(65, 'ascii', 'ascii_bin', ''))
_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', ''))
_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', ''))
_charsets.add(Charset(68, 'cp866', 'cp866_bin', ''))
_charsets.add(Charset(69, 'dec8', 'dec8_bin', ''))
_charsets.add(Charset(70, 'greek', 'greek_bin', ''))
_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', ''))
_charsets.add(Charset(72, 'hp8', 'hp8_bin', ''))
_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', ''))
_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', ''))
_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', ''))
_charsets.add(Charset(77, 'latin2', 'latin2_bin', ''))
_charsets.add(Charset(78, 'latin5', 'latin5_bin', ''))
_charsets.add(Charset(79, 'latin7', 'latin7_bin', ''))
_charsets.add(Charset(80, 'cp850', 'cp850_bin', ''))
_charsets.add(Charset(81, 'cp852', 'cp852_bin', ''))
_charsets.add(Charset(82, 'swe7', 'swe7_bin', ''))
_charsets.add(Charset(83, 'utf8', 'utf8_bin', ''))
_charsets.add(Charset(84, 'big5', 'big5_bin', ''))
_charsets.add(Charset(85, 'euckr', 'euckr_bin', ''))
_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', ''))
_charsets.add(Charset(87, 'gbk', 'gbk_bin', ''))
_charsets.add(Charset(88, 'sjis', 'sjis_bin', ''))
_charsets.add(Charset(89, 'tis620', 'tis620_bin', ''))
_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', ''))
_charsets.add(Charset(91, 'ujis', 'ujis_bin', ''))
_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes'))
_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', ''))
_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', ''))
_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes'))
_charsets.add(Charset(96, 'cp932', 'cp932_bin', ''))
_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes'))
_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', ''))
_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', ''))
_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', ''))
_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', ''))
_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', ''))
_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', ''))
_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', ''))
_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', ''))
_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', ''))
_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', ''))
_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', ''))
_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', ''))
_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', ''))
_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', ''))
_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', ''))
_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', ''))
_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', ''))
_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', ''))
_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', ''))
_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', ''))
_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', ''))
_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', ''))
_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', ''))
_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', ''))
_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', ''))
_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', ''))
_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', ''))
_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', ''))
_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', ''))
_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', ''))
_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', ''))
_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', ''))
_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', ''))
_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', ''))
_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', ''))
_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', ''))
_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', ''))
_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', ''))
_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', ''))
_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', ''))
def charset_by_name(name):
    """Look up a charset in the module-level registry by charset name.

    Delegates to ``_charsets.by_name`` and returns whatever it returns.
    """
    found = _charsets.by_name(name)
    return found
def charset_by_id(id):
    """Look up a charset in the module-level registry by numeric id.

    Delegates to ``_charsets.by_id`` and returns whatever it returns.
    """
    found = _charsets.by_id(id)
    return found

View File

@@ -1,8 +1,6 @@
# Python implementation of the MySQL client-server protocol
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
import re
try:
import hashlib
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
@@ -22,13 +20,13 @@ try:
except ImportError:
import StringIO
from charset import MBLENGTH
from charset import MBLENGTH, charset_by_name, charset_by_id
from cursors import Cursor
from constants import FIELD_TYPE
from constants import FIELD_TYPE, FLAG
from constants import SERVER_STATUS
from constants.CLIENT import *
from constants.COMMAND import *
from converters import escape_item, encoders, decoders, field_decoders
from converters import escape_item, encoders, decoders
from err import raise_mysql_exception, Warning, Error, \
InterfaceError, DataError, DatabaseError, OperationalError, \
IntegrityError, InternalError, NotSupportedError, ProgrammingError
@@ -64,7 +62,8 @@ def dump_packet(data):
dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
for d in dump_data:
print ' '.join(map(lambda x:"%02X" % ord(x), d)) + \
' ' * (16 - len(d)) + ' ' * 2 + ' '.join(map(lambda x:"%s" % is_ascii(x), d))
' ' * (16 - len(d)) + ' ' * 2 + \
' '.join(map(lambda x:"%s" % is_ascii(x), d))
print "-" * 88
print ""
@@ -84,7 +83,8 @@ def _my_crypt(message1, message2):
length = len(message1)
result = struct.pack('B', length)
for i in xrange(length):
x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0])
x = (struct.unpack('B', message1[i:i+1])[0] ^ \
struct.unpack('B', message2[i:i+1])[0])
result += struct.pack('B', x)
return result
@@ -161,9 +161,10 @@ def unpack_int64(n):
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
raise
err = errorclass, errorvalue
if DEBUG:
raise
if cursor:
cursor.messages.append(err)
else:
@@ -271,8 +272,8 @@ class MysqlPacket(object):
"""
return self.__data[position:(position+length)]
def read_coded_length(self):
"""Read a 'Length Coded' number from the data buffer.
def read_length_coded_binary(self):
"""Read a 'Length Coded Binary' number from the data buffer.
Length coded numbers can be anywhere from 1 to 9 bytes depending
on the value of the first byte.
@@ -290,16 +291,17 @@ class MysqlPacket(object):
# TODO: what was 'longlong'? confirm it wasn't used?
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
def read_length_coded_binary(self):
"""Read a 'Length Coded Binary' from the data buffer.
def read_length_coded_string(self):
"""Read a 'Length Coded String' from the data buffer.
A 'Length Coded Binary' consists first of a length coded
A 'Length Coded String' consists first of a length coded
(unsigned, positive) integer represented in 1-9 bytes followed by
that many bytes of binary data. (For example "cat" would be "3cat".)
"""
length = self.read_coded_length()
if length:
return self.read(length)
length = self.read_length_coded_binary()
if length is None:
return None
return self.read(length)
def is_ok_packet(self):
return ord(self.get_bytes(0)) == 0
@@ -342,19 +344,17 @@ class FieldDescriptorPacket(MysqlPacket):
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
"""
self.catalog = self.read_length_coded_binary()
self.db = self.read_length_coded_binary()
self.table_name = self.read_length_coded_binary()
self.org_table = self.read_length_coded_binary()
self.name = self.read_length_coded_binary()
self.org_name = self.read_length_coded_binary()
self.catalog = self.read_length_coded_string()
self.db = self.read_length_coded_string()
self.table_name = self.read_length_coded_string()
self.org_table = self.read_length_coded_string()
self.name = self.read_length_coded_string()
self.org_name = self.read_length_coded_string()
self.advance(1) # non-null filler
self.charsetnr = struct.unpack('<h', self.read(2))[0]
self.length = struct.unpack('<i', self.read(4))[0]
self.charsetnr = struct.unpack('<H', self.read(2))[0]
self.length = struct.unpack('<I', self.read(4))[0]
self.type_code = ord(self.read(1))
flags = struct.unpack('<h', self.read(2))
# TODO: what is going on here with this flag parsing???
self.flags = int(("%02X" % flags)[1:], 16)
self.flags = struct.unpack('<H', self.read(2))[0]
self.scale = ord(self.read(1)) # "decimals"
self.advance(2) # filler (always 0x00)
@@ -401,8 +401,8 @@ class Connection(object):
def __init__(self, host="localhost", user=None, passwd="",
db=None, port=3306, unix_socket=None,
charset=DEFAULT_CHARSET, sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=True,
charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=False,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None):
@@ -457,7 +457,7 @@ class Connection(object):
return cfg.get("client",key)
except:
return default
user = _config("user",user)
passwd = _config("password",passwd)
host = _config("host", host)
@@ -465,15 +465,22 @@ class Connection(object):
unix_socket = _config("socket",unix_socket)
port = _config("port", port)
charset = _config("default-character-set", charset)
self.host = host
self.port = port
self.user = user
self.password = passwd
self.db = db
self.unix_socket = unix_socket
self.use_unicode = use_unicode
self.charset = DEFAULT_CHARSET
if charset:
self.charset = charset
self.use_unicode = True
else:
self.charset = DEFAULT_CHARSET
self.use_unicode = False
if use_unicode:
self.use_unicode = use_unicode
client_flag |= CAPABILITIES
client_flag |= MULTI_STATEMENTS
@@ -483,20 +490,19 @@ class Connection(object):
self.cursorclass = cursorclass
self.connect_timeout = connect_timeout
self._connect()
self.set_charset_set(charset)
self.messages = []
self.set_charset(charset)
self.encoders = encoders
self.decoders = conv
self.field_decoders = field_decoders
self._affected_rows = 0
self.host_info = "Not connected"
self.autocommit(False)
if sql_mode is not None:
c = self.cursor()
c.execute("SET sql_mode=%s", (sql_mode,))
@@ -506,21 +512,20 @@ class Connection(object):
if init_command is not None:
c = self.cursor()
c.execute(init_command)
self.commit()
def close(self):
''' Send the quit message and close the socket '''
try:
if self.socket:
send_data = struct.pack('<i',1) + COM_QUIT
sock = self.socket
sock.send(send_data)
sock.close()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
self.socket.send(send_data)
self.socket.close()
self.socket = None
else:
self.errorhandler(None, InterfaceError, "(0, '')")
def autocommit(self, value):
''' Set whether or not to commit after every execute() '''
try:
@@ -560,7 +565,7 @@ class Connection(object):
def cursor(self):
''' Create a new cursor to execute queries with '''
return self.cursorclass(self)
def __enter__(self):
''' Context manager that returns a Cursor '''
return self.cursor()
@@ -577,7 +582,7 @@ class Connection(object):
self._execute_command(COM_QUERY, sql)
self._affected_rows = self._read_query_result()
return self._affected_rows
def next_result(self):
self._affected_rows = self._read_query_result()
return self._affected_rows
@@ -595,7 +600,7 @@ class Connection(object):
return
pkt = self.read_packet()
return pkt.is_ok_packet()
def ping(self, reconnect=True):
''' Check if the server is alive '''
try:
@@ -612,14 +617,13 @@ class Connection(object):
pkt = self.read_packet()
return pkt.is_ok_packet()
def set_charset_set(self, charset):
def set_charset(self, charset):
try:
sock = self.socket
if charset and self.charset != charset:
if charset:
self._execute_command(COM_QUERY, "SET NAMES %s" %
self.escape(charset))
self.read_packet()
self.charset = charset
self.charset = charset
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
@@ -647,7 +651,7 @@ class Connection(object):
self._request_authentication()
except socket.error, e:
raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0]))
def read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results."""
@@ -673,7 +677,9 @@ class Connection(object):
pckt_no = 0
while len(buf) >= MAX_PACKET_LENGTH:
header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+chr(pckt_no)
self.socket.send(header+buf[:MAX_PACKET_LENGTH])
send_data = header + buf[:MAX_PACKET_LENGTH]
self.socket.send(send_data)
if DEBUG: dump_packet(send_data)
buf = buf[MAX_PACKET_LENGTH:]
pckt_no += 1
header = struct.pack('<i', len(buf))[:-1]+chr(pckt_no)
@@ -683,13 +689,12 @@ class Connection(object):
#sock = self.socket
#sock.send(send_data)
if DEBUG: dump_packet(send_data)
#
def _execute_command(self, command, sql):
self._send_command(command, sql)
def _request_authentication(self):
sock = self.socket
self._send_authentication()
def _send_authentication(self):
@@ -700,9 +705,12 @@ class Connection(object):
if self.user is None:
raise ValueError, "Did not specify a username"
data_init = (struct.pack('<i', self.client_flag)) \
+ "\0\0\0\x01" + '\x08' + '\0'*23
charset_id = charset_by_name(self.charset).id
self.user = self.user.encode(self.charset)
data_init = struct.pack('<i', self.client_flag) + "\0\0\0\x01" + \
chr(charset_id) + '\0'*23
next_packet = 1
@@ -722,13 +730,14 @@ class Connection(object):
data = data_init + self.user+"\0" + _scramble(self.password, self.salt)
if self.db:
data += self.db.encode(self.charset) + "\0"
self.db = self.db.encode(self.charset)
data += self.db + "\0"
data = pack_int24(len(data)) + chr(next_packet) + data
next_packet += 2
if DEBUG: dump_packet(data)
sock.send(data)
auth_packet = MysqlPacket(sock)
@@ -743,13 +752,13 @@ class Connection(object):
#raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table."
data = _scramble_323(self.password, self.salt) + "\0"
data = pack_int24(len(data)) + chr(next_packet) + data
sock.send(data)
auth_packet = MysqlPacket(sock)
auth_packet.check_error()
if DEBUG: auth_packet.dump()
# _mysql support
def thread_id(self):
return self.server_thread_id[0]
@@ -759,10 +768,10 @@ class Connection(object):
def get_host_info(self):
return self.host_info
def get_proto_info(self):
return self.protocol_version
def _get_server_information(self):
sock = self.socket
i = 0
@@ -773,27 +782,28 @@ class Connection(object):
#packet_len = ord(data[i:i+1])
#i += 4
self.protocol_version = ord(data[i:i+1])
i += 1
server_end = data.find("\0", i)
self.server_version = data[i:server_end]
i = server_end + 1
self.server_thread_id = struct.unpack('<h', data[i:i+2])
i += 4
self.salt = data[i:i+8]
i += 9
if len(data) >= i + 1:
i += 1
self.server_capabilities = struct.unpack('<h', data[i:i+2])[0]
i += 1
self.server_language = ord(data[i:i+1])
i += 16
self.server_charset = charset_by_id(self.server_language).name
i += 16
if len(data) >= i+12-1:
rest_salt = data[i:i+12]
self.salt += rest_salt
@@ -840,8 +850,8 @@ class MySQLResult(object):
def _read_ok_packet(self):
self.first_packet.advance(1) # field_count (always '0')
self.affected_rows = self.first_packet.read_coded_length()
self.insert_id = self.first_packet.read_coded_length()
self.affected_rows = self.first_packet.read_length_coded_binary()
self.insert_id = self.first_packet.read_length_coded_binary()
self.server_status = struct.unpack('<H', self.first_packet.read(2))[0]
self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0]
self.message = self.first_packet.read_all()
@@ -871,14 +881,7 @@ class MySQLResult(object):
converter = self.connection.decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
converted = None
if data != None:
converted = converter(data)
else:
converter = self.connection.field_decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
data = packet.read_length_coded_string()
converted = None
if data != None:
converted = converter(self.connection, field, data)

View File

@@ -29,4 +29,4 @@ STRING = 254
GEOMETRY = 255
CHAR = TINY
INTERVAL = ENUM
INTERVAL = ENUM

View File

@@ -1,11 +1,9 @@
import re
import datetime
import time
import array
import struct
from times import Date, Time, TimeDelta, Timestamp
from constants import FIELD_TYPE
from constants import FIELD_TYPE, FLAG
from charset import charset_by_id
try:
set
@@ -20,8 +18,16 @@ ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val, charset):
if type(val) in [tuple, list, set]:
return escape_sequence(val, charset)
if type(val) is dict:
return escape_dict(val, charset)
encoder = encoders[type(val)]
return encoder(val, charset)
val = encoder(val)
if type(val) is str:
return val
val = val.encode(charset)
return val
def escape_dict(val, charset):
n = {}
@@ -37,99 +43,96 @@ def escape_sequence(val, charset):
n.append(quoted)
return tuple(n)
def escape_bool(value, charset):
return str(int(value)).encode(charset)
def escape_set(val, charset):
val = map(lambda x: escape_item(x, charset), val)
return ','.join(val)
def escape_object(value, charset):
return str(value).encode(charset)
def escape_bool(value):
return str(int(value))
def escape_object(value):
return str(value)
escape_int = escape_long = escape_object
def escape_float(value, charset):
return ('%.15g' % value).encode(charset)
def escape_float(value):
return ('%.15g' % value)
def escape_string(value, charset):
r = ("'%s'" % ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value))
# TODO: make sure that encodings are handled correctly here.
# Since we may be dealing with binary data, the encoding
# routine below is commented out.
#if not charset is None:
# r = r.encode(charset)
return r
def escape_unicode(value, charset):
# pass None as the charset because we already encode it
return escape_string(value.encode(charset), None)
def escape_string(value):
return ("'%s'" % ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value))
def escape_None(value, charset):
return 'NULL'.encode(charset)
def escape_unicode(value):
return escape_string(value)
def escape_timedelta(obj, charset):
def escape_None(value):
return 'NULL'
def escape_timedelta(obj):
seconds = int(obj.seconds) % 60
minutes = int(obj.seconds // 60) % 60
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset)
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds))
def escape_time(obj, charset):
def escape_time(obj):
s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
int(obj.second))
if obj.microsecond:
s += ".%f" % obj.microsecond
return escape_string(s, charset)
return escape_string(s)
def escape_datetime(obj, charset):
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset)
def escape_datetime(obj):
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"))
def escape_date(obj, charset):
return escape_string(obj.strftime("%Y-%m-%d"), charset)
def escape_date(obj):
return escape_string(obj.strftime("%Y-%m-%d"))
def escape_struct_time(obj, charset):
return escape_datetime(datetime.datetime(*obj[:6]), charset)
def escape_struct_time(obj):
return escape_datetime(datetime.datetime(*obj[:6]))
def convert_datetime(obj):
def convert_datetime(connection, field, obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
>>> datetime_or_None('2007-02-25 23:06:20')
datetime.datetime(2007, 2, 25, 23, 6, 20)
>>> datetime_or_None('2007-02-25T23:06:20')
datetime.datetime(2007, 2, 25, 23, 6, 20)
Illegal values are returned as None:
>>> datetime_or_None('2007-02-31T23:06:20') is None
True
>>> datetime_or_None('0000-00-00 00:00:00') is None
True
"""
if ' ' in obj:
sep = ' '
elif 'T' in obj:
sep = 'T'
else:
return convert_date(obj)
return convert_date(connection, field, obj)
try:
ymd, hms = obj.split(sep, 1)
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
except ValueError:
return convert_date(obj)
return convert_date(connection, field, obj)
def convert_timedelta(obj):
def convert_timedelta(connection, field, obj):
"""Returns a TIME column as a timedelta object:
>>> timedelta_or_None('25:06:17')
datetime.timedelta(1, 3977)
>>> timedelta_or_None('-25:06:17')
datetime.timedelta(-2, 83177)
Illegal values are returned as None:
>>> timedelta_or_None('random crap') is None
True
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function.
@@ -147,23 +150,23 @@ def convert_timedelta(obj):
except ValueError:
return None
def convert_time(obj):
def convert_time(connection, field, obj):
"""Returns a TIME column as a time object:
>>> time_or_None('15:06:17')
datetime.time(15, 6, 17)
Illegal values are returned as None:
>>> time_or_None('-25:06:17') is None
True
>>> time_or_None('random crap') is None
True
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function.
Also note that MySQL's TIME column corresponds more closely to
Python's timedelta and not time. However if you want TIME columns
to be treated as time-of-day and not a time offset, then you can
@@ -172,53 +175,54 @@ def convert_time(obj):
from math import modf
try:
hour, minute, second = obj.split(':')
return datetime.time(hour=int(hour), minute=int(minute), second=int(second),
microsecond=int(modf(float(second))[0]*1000000))
return datetime.time(hour=int(hour), minute=int(minute),
second=int(second),
microsecond=int(modf(float(second))[0]*1000000))
except ValueError:
return None
def convert_date(obj):
def convert_date(connection, field, obj):
"""Returns a DATE column as a date object:
>>> date_or_None('2007-02-26')
datetime.date(2007, 2, 26)
Illegal values are returned as None:
>>> date_or_None('2007-02-31') is None
True
>>> date_or_None('0000-00-00') is None
True
"""
try:
return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
except ValueError:
return None
def convert_mysql_timestamp(timestamp):
def convert_mysql_timestamp(connection, field, timestamp):
"""Convert a MySQL TIMESTAMP to a Timestamp object.
MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:
>>> mysql_timestamp_converter('2007-02-25 22:32:17')
datetime.datetime(2007, 2, 25, 22, 32, 17)
MySQL < 4.1 uses a big string of numbers:
>>> mysql_timestamp_converter('20070225223217')
datetime.datetime(2007, 2, 25, 22, 32, 17)
Illegal values are returned as None:
>>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
True
>>> mysql_timestamp_converter('00000000000000') is None
True
"""
if timestamp[4] == '-':
return convert_datetime(timestamp)
return convert_datetime(connection, field, timestamp)
timestamp += "0"*(14-len(timestamp)) # padding
year, month, day, hour, minute, second = \
int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
@@ -229,13 +233,38 @@ def convert_mysql_timestamp(timestamp):
return None
def convert_set(s):
    """Convert a comma-separated SET column value into a Python set.

    Args:
        s: the raw column text, e.g. 'a,b,c'.

    Returns:
        A set of the non-empty comma-separated members. An empty string
        gives an empty set rather than a set containing ''.
    """
    # Dropping empty pieces mirrors MySQLdb's Str2Set converter, so a SET
    # column with no members selected does not come back as set(['']).
    return set([member for member in s.split(",") if member])
def convert_bit(b):
b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
return struct.unpack(">Q", b)[0]
def convert_bit(connection, field, b):
    """Return a BIT column value untouched.

    ``connection`` and ``field`` are accepted only so the signature matches
    the other converters; neither is read.
    """
    # MySQLdb does not process BIT values, so this converter does not
    # either. Zero-padding ``b`` to 8 bytes and unpacking it with
    # struct format ">Q" would yield the integer, but that step is
    # deliberately left out.
    unprocessed = b
    return unprocessed
def convert_characters(connection, field, data):
    """Convert a string/blob column value according to its flags and charset.

    SET-flagged fields go through convert_set(); BINARY-flagged fields are
    returned unchanged. Otherwise the value is decoded from the field's
    charset when ``connection.use_unicode`` is set, or re-encoded into
    ``connection.charset`` when that differs from the field's charset.
    """
    flags = field.flags
    if flags & FLAG.SET:
        return convert_set(data)
    if flags & FLAG.BINARY:
        return data

    column_charset = charset_by_id(field.charsetnr).name
    if connection.use_unicode:
        return data.decode(column_charset)
    if connection.charset == column_charset:
        # Already in the connection's encoding; nothing to transcode.
        return data
    return data.decode(column_charset).encode(connection.charset)
def convert_int(connection, field, data):
    """Parse ``data`` with int(); ``connection`` and ``field`` are unused."""
    parsed = int(data)
    return parsed
def convert_long(connection, field, data):
    """Parse ``data`` with long(); ``connection`` and ``field`` are unused."""
    parsed = long(data)
    return parsed
def convert_float(connection, field, data):
    """Parse ``data`` with float(); ``connection`` and ``field`` are unused."""
    parsed = float(data)
    return parsed
encoders = {
bool: escape_bool,
int: escape_int,
@@ -257,21 +286,28 @@ encoders = {
decoders = {
FIELD_TYPE.BIT: convert_bit,
FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: long,
FIELD_TYPE.FLOAT: float,
FIELD_TYPE.DOUBLE: float,
FIELD_TYPE.DECIMAL: float,
FIELD_TYPE.NEWDECIMAL: float,
FIELD_TYPE.LONGLONG: long,
FIELD_TYPE.INT24: int,
FIELD_TYPE.YEAR: int,
FIELD_TYPE.TINY: convert_int,
FIELD_TYPE.SHORT: convert_int,
FIELD_TYPE.LONG: convert_long,
FIELD_TYPE.FLOAT: convert_float,
FIELD_TYPE.DOUBLE: convert_float,
FIELD_TYPE.DECIMAL: convert_float,
FIELD_TYPE.NEWDECIMAL: convert_float,
FIELD_TYPE.LONGLONG: convert_long,
FIELD_TYPE.INT24: convert_int,
FIELD_TYPE.YEAR: convert_int,
FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
FIELD_TYPE.DATETIME: convert_datetime,
FIELD_TYPE.TIME: convert_timedelta,
FIELD_TYPE.DATE: convert_date,
FIELD_TYPE.SET: convert_set,
FIELD_TYPE.BLOB: convert_characters,
FIELD_TYPE.TINY_BLOB: convert_characters,
FIELD_TYPE.MEDIUM_BLOB: convert_characters,
FIELD_TYPE.LONG_BLOB: convert_characters,
FIELD_TYPE.STRING: convert_characters,
FIELD_TYPE.VAR_STRING: convert_characters,
FIELD_TYPE.VARCHAR: convert_characters,
#FIELD_TYPE.BLOB: str,
#FIELD_TYPE.STRING: str,
#FIELD_TYPE.VAR_STRING: str,
@@ -279,28 +315,13 @@ decoders = {
}
conversions = decoders # for MySQLdb compatibility
def decode_characters(connection, field, data):
if field.charsetnr == 63 or not connection.use_unicode:
# binary data, leave it alone
return data
return data.decode(connection.charset)
# These take a field instance rather than just the data.
field_decoders = {
FIELD_TYPE.BLOB: decode_characters,
FIELD_TYPE.TINY_BLOB: decode_characters,
FIELD_TYPE.MEDIUM_BLOB: decode_characters,
FIELD_TYPE.LONG_BLOB: decode_characters,
FIELD_TYPE.STRING: decode_characters,
FIELD_TYPE.VAR_STRING: decode_characters,
FIELD_TYPE.VARCHAR: decode_characters,
}
try:
# python version > 2.3
from decimal import Decimal
decoders[FIELD_TYPE.DECIMAL] = Decimal
decoders[FIELD_TYPE.NEWDECIMAL] = Decimal
def convert_decimal(connection, field, data):
return Decimal(data)
decoders[FIELD_TYPE.DECIMAL] = convert_decimal
decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal
def escape_decimal(obj, charset):
return unicode(obj).encode(charset)

View File

@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import struct
import re
@@ -52,23 +53,23 @@ class Cursor(object):
if not self.connection:
self.errorhandler(self, ProgrammingError, "cursor closed")
return self.connection
def _check_executed(self):
if not self._executed:
self.errorhandler(self, ProgrammingError, "execute() first")
def setinputsizes(self, *args):
"""Does nothing, required by DB API."""
def setoutputsizes(self, *args):
"""Does nothing, required by DB API."""
def nextset(self):
''' Get the next query set '''
if self._executed:
self.fetchall()
del self.messages[:]
if not self._has_next:
return None
connection = self._get_db()
@@ -79,11 +80,11 @@ class Cursor(object):
def execute(self, query, args=None):
''' Execute a query '''
from sys import exc_info
conn = self._get_db()
charset = conn.charset
del self.messages[:]
# this ordering is good because conn.escape() returns
# an encoded string.
if isinstance(query, unicode):
@@ -91,7 +92,7 @@ class Cursor(object):
if args is not None:
query = query % conn.escape(args)
result = 0
try:
result = self._query(query)
@@ -103,7 +104,7 @@ class Cursor(object):
self._executed = query
return result
def executemany(self, query, args):
''' Run several data against one query '''
del self.messages[:]
@@ -113,30 +114,66 @@ class Cursor(object):
charset = conn.charset
if isinstance(query, unicode):
query = query.encode(charset)
self.rowcount = sum([ self.execute(query, arg) for arg in args ])
return self.rowcount
def callproc(self, procname, args=()):
    """Execute stored procedure procname with args

    procname -- string, name of procedure to execute on server

    args -- Sequence of parameters to use with procedure

    Returns the original args.

    Note: procname is interpolated into the SQL text as-is; take care
    to ensure that it is properly escaped.

    Compatibility warning: PEP-249 specifies that any modified
    parameters must be returned. This is currently impossible
    as they are only available by storing them in a server
    variable and then retrieved by a query. Since stored
    procedures return zero or more result sets, there is no
    reliable way to get at OUT or INOUT parameters via callproc.
    The server variables are named @_procname_n, where procname
    is the parameter above and n is the position of the parameter
    (from zero). Once all result sets generated by the procedure
    have been fetched, you can issue a SELECT @_procname_0, ...
    query using .execute() to get any OUT or INOUT values.

    Compatibility warning: The act of calling a stored procedure
    itself creates an empty result set. This appears after any
    result sets generated by the procedure. This is non-standard
    behavior with respect to the DB-API. Be sure to use nextset()
    to advance through all result sets; otherwise you may get
    disconnected.
    """
    conn = self._get_db()

    # Store every argument in a server variable named @_<procname>_<n>.
    for index, arg in enumerate(args):
        q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
        if isinstance(q, unicode):
            q = q.encode(conn.charset)
        self._query(q)
        self.nextset()

    # Call the procedure, passing the server variables set above.
    q = "CALL %s(%s)" % (procname,
                         ','.join(['@_%s_%d' % (procname, i)
                                   for i in range(len(args))]))
    if isinstance(q, unicode):
        q = q.encode(conn.charset)
    self._query(q)
    self._executed = q
    return args
def fetchone(self):
    ''' Fetch the next row.

    Returns the row at the current position and advances rownumber by
    one, or returns None when there are no rows or every row has
    already been consumed. '''
    # One check is enough; it was previously invoked twice in a row.
    self._check_executed()
    if self._rows is None or self.rownumber >= len(self._rows):
        return None
    result = self._rows[self.rownumber]
    self.rownumber += 1
    return result
def fetchmany(self, size=None):
''' Fetch several rows '''
self._check_executed()
@@ -158,15 +195,15 @@ class Cursor(object):
result = self._rows
self.rownumber = len(self._rows)
return result
def scroll(self, value, mode='relative'):
self._check_executed()
if mode == 'relative':
r = self.rownumber + value
elif mode == 'absolute':
r = value
else:
self.errorhandler(self, ProgrammingError,
self.errorhandler(self, ProgrammingError,
"unknown scroll mode %s" % mode)
if r < 0 or r >= len(self._rows):
@@ -179,23 +216,23 @@ class Cursor(object):
conn.query(q)
self._do_get_result()
return self.rowcount
def _do_get_result(self):
conn = self._get_db()
self.rowcount = conn._result.affected_rows
self.rownumber = 0
self.description = conn._result.description
self.lastrowid = conn._result.insert_id
self._rows = conn._result.rows
self._has_next = conn._result.has_next
conn._result = None
def __iter__(self):
self._check_executed()
result = self.rownumber and self._rows[self.rownumber:] or self._rows
return iter(result)
Warning = Warning
Error = Error
InterfaceError = InterfaceError

View File

@@ -3,7 +3,8 @@ import unittest
class PyMySQLTestCase(unittest.TestCase):
databases = [
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql"},
{"host":"localhost","user":"root",
"passwd":"","db":"test_pymysql", "use_unicode": True},
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
def setUp(self):

View File

@@ -15,7 +15,8 @@ class TestConversion(base.PyMySQLTestCase):
c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v)
c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
r = c.fetchone()
self.assertEqual(v[:8], r[:8])
self.assertEqual("\x01", r[0])
self.assertEqual(v[1:8], r[1:8])
# mysql throws away microseconds so we need to check datetimes
# specially. additionally times are turned into timedeltas.
self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8])

View File

@@ -104,8 +104,8 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1))
self.assertEqual('1', pymysql.converters.escape_object(1L))
def test_issue_15(self):
""" query should be expanded before perform character encoding """