Drizzled Public API Documentation

handshake.py
00001 #!/usr/bin/env python
00002 #
00003 # Drizzle Client & Protocol Library
00004 # 
00005 # Copyright (C) 2008 Eric Day (eday@oddments.org)
00006 # All rights reserved.
00007 #
00008 # Redistribution and use in source and binary forms, with or without
00009 # modification, are permitted provided that the following conditions are
00010 # met:
00011 #
00012 #     * Redistributions of source code must retain the above copyright
00013 # notice, this list of conditions and the following disclaimer.
00014 #
00015 #     * Redistributions in binary form must reproduce the above
00016 # copyright notice, this list of conditions and the following disclaimer
00017 # in the documentation and/or other materials provided with the
00018 # distribution.
00019 #
00020 #     * The names of its contributors may not be used to endorse or
00021 # promote products derived from this software without specific prior
00022 # written permission.
00023 #
00024 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00025 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00026 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
00027 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
00028 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00029 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
00030 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00031 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00032 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00033 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
00034 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00035 #
00036 
00037 '''
00038 MySQL Protocol Handshake Objects
00039 '''
00040 
00041 import struct
00042 import unittest
00043 import bitfield
00044 
class Capabilities(bitfield.BitField):
  '''Capability flags exchanged between client and server in the handshake.'''

  # NOTE(review): assumes bitfield.BitField maps each entry's list index to
  # its bit number and treats None as an unnamed bit -- confirm in bitfield.
  # The list below has 32 entries: 18 named, 12 unnamed, then 2 more named.
  _fields = ([
    'LONG_PASSWORD',            # index 0
    'FOUND_ROWS',               # index 1
    'LONG_FLAG',                # index 2
    'CONNECT_WITH_DB',          # index 3
    'NO_SCHEMA',                # index 4
    'COMPRESS',                 # index 5
    'ODBC',                     # index 6
    'LOCAL_FILES',              # index 7
    'IGNORE_SPACE',             # index 8
    'PROTOCOL_41',              # index 9
    'INTERACTIVE',              # index 10
    'SSL',                      # index 11
    'IGNORE_SIGPIPE',           # index 12
    'TRANSACTIONS',             # index 13
    'RESERVED',                 # index 14
    'SECURE_CONNECTION',        # index 15
    'MULTI_STATEMENTS',         # index 16
    'MULTI_RESULTS',            # index 17
  ] + [None] * 12 + [           # indices 18-29 have no name
    'SSL_VERIFY_SERVER_CERT',   # index 30
    'REMEMBER_OPTIONS',         # index 31
  ])
00080 
class Status(bitfield.BitField):
  '''Server status flags carried in the server handshake.'''

  # NOTE(review): assumes bitfield.BitField maps each entry's list index to
  # its bit number and treats None as an unnamed bit -- confirm in bitfield.
  # Bit 2 (0x0004) has no flag in this protocol, so it needs a None
  # placeholder; without it every flag from MORE_RESULTS_EXISTS onward would
  # be shifted down one bit (e.g. MORE_RESULTS_EXISTS must be 0x0008).
  _fields = [
    'IN_TRANS',                   # bit 0  (0x0001)
    'AUTOCOMMIT',                 # bit 1  (0x0002)
    None,                         # bit 2  (0x0004) unnamed
    'MORE_RESULTS_EXISTS',        # bit 3  (0x0008)
    'QUERY_NO_GOOD_INDEX_USED',   # bit 4  (0x0010)
    'QUERY_NO_INDEX_USED',        # bit 5  (0x0020)
    'CURSOR_EXISTS',              # bit 6  (0x0040)
    'LAST_ROW_SENT',              # bit 7  (0x0080)
    'DB_DROPPED',                 # bit 8  (0x0100)
    'NO_BACKSLASH_ESCAPES',       # bit 9  (0x0200)
    'QUERY_WAS_SLOW'              # bit 10 (0x0400)
  ]
00094 
class ServerHandshake(object):
  '''This class represents the initial handshake sent from server to client.

  Wire layout, in the order pack() writes it:
    1 byte    protocol_version
    server_version, terminated by a NUL byte
    4 bytes   thread_id (little-endian)
    8 bytes   first 8 bytes of scramble
    1 byte    null1
    2 bytes   capabilities (low byte first)
    1 byte    charset
    2 bytes   status (low byte first)
    13 bytes  unused
    12 bytes  remaining 12 bytes of scramble
    1 byte    null2
  '''

  # Fixed-size part that follows the NUL-terminated server version
  # (4 + 8 + 1 + 2 + 1 + 2 + 13 + 12 + 1 = 44 bytes, no padding with '<').
  _TAIL = struct.Struct('<I8BB2BB2B13B12BB')

  def __init__(self, packed=None, protocol_version=10, server_version='',
               thread_id=0, scramble=tuple([0] * 20), null1=0, capabilities=0,
               charset=0, status=0, unused=tuple([0] * 13), null2=0):
    '''Build a handshake from keyword values, or decode it from packed.

    packed: raw handshake byte string. When given, every other argument is
      ignored and all fields are decoded from it; bytes after null2 are
      ignored.
    scramble: 20 byte values (ints); unused: 13 byte values (ints).
    capabilities / status: integer bit masks, wrapped in Capabilities /
      Status.

    Raises ValueError if the server version has no NUL terminator and
    struct.error if packed is too short.
    '''
    if packed is None:
      self.protocol_version = protocol_version
      self.server_version = server_version
      self.thread_id = thread_id
      self.scramble = scramble
      self.null1 = null1
      self.capabilities = Capabilities(capabilities)
      self.charset = charset
      self.status = Status(status)
      self.unused = unused
      self.null2 = null2
    else:
      self.protocol_version = struct.unpack('B', packed[:1])[0]
      server_version_length = packed[1:].index('\x00')
      self.server_version = packed[1:1+server_version_length]
      # unpack_from reads exactly _TAIL.size bytes at the offset, so extra
      # data appended after null2 no longer makes parsing fail (a plain
      # struct.unpack on the remainder demanded an exact length).
      data = self._TAIL.unpack_from(packed, 2 + server_version_length)
      self.thread_id = data[0]
      # The scramble is split around the flag bytes: 8 bytes, then 12 later.
      self.scramble = data[1:9] + data[28:40]
      self.null1 = data[9]
      self.capabilities = Capabilities(data[10] | (data[11] << 8))
      self.charset = data[12]
      self.status = Status(data[13] | (data[14] << 8))
      self.unused = data[15:28]
      self.null2 = data[40]

  def pack(self):
    '''Serialize this handshake to a byte string (inverse of packed init).

    Only the low 16 bits of capabilities and status are written.
    '''
    data = struct.pack('B', self.protocol_version)
    data += self.server_version + '\x00'
    data += struct.pack('<I', self.thread_id)
    data += ''.join(map(chr, self.scramble[:8]))
    data += struct.pack('B2BB2B',
                       self.null1,
                       self.capabilities.value() & 0xFF,
                       (self.capabilities.value() >> 8) & 0xFF,
                       self.charset,
                       self.status.value() & 0xFF,
                       (self.status.value() >> 8) & 0xFF)
    data += ''.join(map(chr, self.unused))
    data += ''.join(map(chr, self.scramble[8:]))
    data += struct.pack('B', self.null2)
    return data

  def __str__(self):
    '''Return a multi-line, human-readable dump of every field.'''
    return '''ServerHandshake
  protocol_version = %s
  server_version = %s
  thread_id = %s
  scramble = %s
  null1 = %s
  capabilities = %s
  charset = %s
  status = %s
  unused = %s
  null2 = %s
''' % (self.protocol_version, self.server_version, self.thread_id,
       self.scramble, self.null1, self.capabilities, self.charset,
       self.status, self.unused, self.null2)
00158 
class TestServerHandshake(unittest.TestCase):
  '''Construction, decoding and pack/unpack round-trip tests for ServerHandshake.'''

  def testDefaultInit(self):
    hs = ServerHandshake()
    self.verifyDefault(hs)
    str(hs)  # exercise __str__; the text itself is not checked

  def testKeywordInit(self):
    fields = dict(protocol_version=11,
                  server_version='test',
                  thread_id=1234,
                  scramble=tuple([5] * 20),
                  null1=1,
                  capabilities=65279,
                  charset=253,
                  status=64508,
                  unused=tuple([6] * 13),
                  null2=2)
    hs = ServerHandshake(**fields)
    self.verifyCustom(hs)
    str(hs)

  def testUnpackInit(self):
    # 255 | 254 << 8 == 65279 (capabilities); 252 | 251 << 8 == 64508 (status)
    packet = ''.join([
      struct.pack('B', 11),
      'test\x00',
      struct.pack('<I', 1234),
      chr(5) * 8,
      struct.pack('B2BB2B', 1, 255, 254, 253, 252, 251),
      chr(6) * 13,
      chr(5) * 12,
      struct.pack('B', 2),
    ])
    self.verifyCustom(ServerHandshake(packet))

  def testPack(self):
    self.verifyDefault(ServerHandshake(ServerHandshake().pack()))

  def _checkFields(self, handshake, expected):
    '''Assert each (attribute, value) pair in order; bit fields compare via value().'''
    for name, want in expected:
      got = getattr(handshake, name)
      if name in ('capabilities', 'status'):
        got = got.value()
      self.assertEqual(got, want)

  def verifyDefault(self, handshake):
    self._checkFields(handshake, (
      ('protocol_version', 10),
      ('server_version', ''),
      ('thread_id', 0),
      ('scramble', tuple([0] * 20)),
      ('null1', 0),
      ('capabilities', 0),
      ('charset', 0),
      ('status', 0),
      ('unused', tuple([0] * 13)),
      ('null2', 0),
    ))

  def verifyCustom(self, handshake):
    self._checkFields(handshake, (
      ('protocol_version', 11),
      ('server_version', 'test'),
      ('thread_id', 1234),
      ('scramble', tuple([5] * 20)),
      ('null1', 1),
      ('capabilities', 65279),
      ('charset', 253),
      ('status', 64508),
      ('unused', tuple([6] * 13)),
      ('null2', 2),
    ))
00220 
class ClientHandshake(object):
  '''This class represents the client handshake sent back to the server.

  Wire layout, in the order pack() writes it:
    4 bytes   capabilities (little-endian)
    4 bytes   max_packet_size (little-endian)
    1 byte    charset
    23 bytes  unused
    user, terminated by a NUL byte
    1 byte    scramble_size, followed by scramble_size scramble bytes
    db, terminated by a NUL byte
  '''

  def __init__(self, packed=None, capabilities=0, max_packet_size=0, charset=0,
               unused=tuple([0] * 23), user='', scramble_size=0,
               scramble=None, db=''):
    '''Build a handshake from keyword values, or decode it from packed.

    packed: raw handshake byte string. When given, every other argument is
      ignored and all fields are decoded from it.
    scramble: tuple of byte values (ints), or None when scramble_size is 0.

    Raises struct.error if packed is shorter than 32 bytes and ValueError
    if the user name has no NUL terminator.
    '''
    if packed is None:
      self.capabilities = Capabilities(capabilities)
      self.max_packet_size = max_packet_size
      self.charset = charset
      self.unused = unused
      self.user = user
      self.scramble_size = scramble_size
      self.scramble = scramble
      self.db = db
    else:
      data = struct.unpack('<IIB23B', packed[:32])
      self.capabilities = Capabilities(data[0])
      self.max_packet_size = data[1]
      self.charset = data[2]
      self.unused = data[3:]
      packed = packed[32:]
      user_length = packed.index('\x00')
      self.user = packed[:user_length]
      packed = packed[1+user_length:]
      # The scramble is length-prefixed. Slice by that length instead of
      # assuming 20 bytes: with an empty scramble the old fixed offsets
      # swallowed the first 20 bytes of the database name.
      if packed:
        self.scramble_size = ord(packed[0])
      else:
        self.scramble_size = 0
      db_start = 1 + self.scramble_size
      if self.scramble_size == 0:
        self.scramble = None
      else:
        self.scramble = tuple(map(ord, packed[1:db_start]))
      db = packed[db_start:]
      # The database name is NUL-terminated when present; drop the NUL.
      if db[-1:] == '\x00':
        db = db[:-1]
      self.db = db

  def pack(self):
    '''Serialize this handshake to a byte string (inverse of packed init).

    The scramble bytes are written only when scramble_size is non-zero.
    '''
    data = struct.pack('<IIB',
                       self.capabilities.value(),
                       self.max_packet_size,
                       self.charset)
    data += ''.join(map(chr, self.unused))
    data += self.user + '\x00'
    data += chr(self.scramble_size)
    if self.scramble_size != 0:
      data += ''.join(map(chr, self.scramble))
    data += self.db + '\x00'
    return data

  def __str__(self):
    '''Return a multi-line, human-readable dump of every field.'''
    return '''ClientHandshake
  capabilities = %s
  max_packet_size = %s
  charset = %s
  unused = %s
  user = %s
  scramble_size = %s
  scramble = %s
  db = %s
''' % (self.capabilities, self.max_packet_size, self.charset, self.unused,
       self.user, self.scramble_size, self.scramble, self.db)
00281 
class TestClientHandshake(unittest.TestCase):
  '''Construction, decoding and pack/unpack round-trip tests for ClientHandshake.'''

  def testDefaultInit(self):
    hs = ClientHandshake()
    self.verifyDefault(hs)
    str(hs)  # exercise __str__; the text itself is not checked

  def testKeywordInit(self):
    fields = dict(capabilities=65279,
                  max_packet_size=64508,
                  charset=253,
                  unused=tuple([6] * 23),
                  user='user',
                  scramble_size=20,
                  scramble=tuple([5] * 20),
                  db='db')
    hs = ClientHandshake(**fields)
    self.verifyCustom(hs)
    str(hs)

  def testUnpackInit(self):
    packet = ''.join([
      struct.pack('<IIB', 65279, 64508, 253),
      chr(6) * 23,
      'user\x00',
      chr(20),
      chr(5) * 20,
      'db\x00',
    ])
    self.verifyCustom(ClientHandshake(packet))

  def testPack(self):
    self.verifyDefault(ClientHandshake(ClientHandshake().pack()))

  def _checkFields(self, handshake, expected):
    '''Assert each (attribute, value) pair in order; capabilities compares via value().'''
    for name, want in expected:
      got = getattr(handshake, name)
      if name == 'capabilities':
        got = got.value()
      self.assertEqual(got, want)

  def verifyDefault(self, handshake):
    self._checkFields(handshake, (
      ('capabilities', 0),
      ('max_packet_size', 0),
      ('charset', 0),
      ('unused', tuple([0] * 23)),
      ('user', ''),
      ('scramble_size', 0),
      ('scramble', None),
      ('db', ''),
    ))

  def verifyCustom(self, handshake):
    self._checkFields(handshake, (
      ('capabilities', 65279),
      ('max_packet_size', 64508),
      ('charset', 253),
      ('unused', tuple([6] * 23)),
      ('user', 'user'),
      ('scramble_size', 20),
      ('scramble', tuple([5] * 20)),
      ('db', 'db'),
    ))
00335 
if __name__ == '__main__':
  # Executing this file directly runs the unittest.TestCase classes it defines.
  unittest.main(module='__main__')