--[[ Copyright (C) 2010 David Eder. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
]]

--- Minimal MySQL client speaking the 4.1 client/server protocol over LuaSocket.
-- Defines the globals `mysql`, `sha1` and `xor`; they stay global so that
-- existing code using them keeps working.

local socket = require('socket') -- http://w3.impa.br/~diego/software/luasocket/home.html
require('lcrypt') -- http://eder.us/projects/lcrypt/index.php
-- NOTE: sha1/xor below use the global `lcrypt` table that loading lcrypt is
-- expected to define.

--- SHA-1 digest of `data`, as produced by lcrypt.
function sha1(data)
  return lcrypt.hashes.sha1:hash(data):done()
end

--- XOR of two strings via lcrypt.xor (assumed byte-wise on equal-length
-- inputs; TODO confirm against lcrypt docs).
function xor(a, b)
  return lcrypt.xor(a, b)
end

mysql = {
  -- Client capability flags
  CLIENT_LONG_PASSWORD = 0x0001, -- new more secure passwords
  CLIENT_FOUND_ROWS = 0x0002, -- Found instead of affected rows
  CLIENT_LONG_FLAG = 0x0004, -- Get all column flags
  CLIENT_CONNECT_WITH_DB = 0x0008, -- One can specify db on connect
  CLIENT_NO_SCHEMA = 0x0010, -- Don't allow database.table.column
  CLIENT_COMPRESS = 0x0020, -- Can use compression protocol
  CLIENT_ODBC = 0x0040, -- Odbc client
  CLIENT_LOCAL_FILES = 0x0080, -- Can use LOAD DATA LOCAL
  CLIENT_IGNORE_SPACE = 0x0100, -- Ignore spaces before '('
  CLIENT_PROTOCOL_41 = 0x0200, -- New 4.1 protocol
  CLIENT_INTERACTIVE = 0x0400, -- This is an interactive client
  CLIENT_SSL = 0x0800, -- Switch to SSL after handshake
  CLIENT_IGNORE_SIGPIPE = 0x1000, -- IGNORE sigpipes
  CLIENT_TRANSACTIONS = 0x2000, -- Client knows about transactions
  CLIENT_RESERVED = 0x4000, -- Old flag for 4.1 protocol
  CLIENT_SECURE_CONNECTION = 0x8000, -- New 4.1 authentication
  CLIENT_MULTI_STATEMENTS = 0x010000, -- Enable/disable multi-stmt support
  CLIENT_MULTI_RESULTS = 0x020000, -- Enable/disable multi-results

  -- Command bytes
  COM_SLEEP = 0x00, -- (none, this is an internal thread state)
  COM_QUIT = 0x01, -- mysql_close
  COM_INIT_DB = 0x02, -- mysql_select_db
  COM_QUERY = 0x03, -- mysql_real_query
  COM_FIELD_LIST = 0x04, -- mysql_list_fields
  COM_CREATE_DB = 0x05, -- mysql_create_db (deprecated)
  COM_DROP_DB = 0x06, -- mysql_drop_db (deprecated)
  COM_REFRESH = 0x07, -- mysql_refresh
  COM_SHUTDOWN = 0x08, -- mysql_shutdown
  COM_STATISTICS = 0x09, -- mysql_stat
  COM_PROCESS_INFO = 0x0a, -- mysql_list_processes
  COM_CONNECT = 0x0b, -- (none, this is an internal thread state)
  COM_PROCESS_KILL = 0x0c, -- mysql_kill
  COM_DEBUG = 0x0d, -- mysql_dump_debug_info
  COM_PING = 0x0e, -- mysql_ping
  COM_TIME = 0x0f, -- (none, this is an internal thread state)
  COM_DELAYED_INSERT = 0x10, -- (none, this is an internal thread state)
  COM_CHANGE_USER = 0x11, -- mysql_change_user
  COM_BINLOG_DUMP = 0x12, -- sent by the slave IO thread to request a binlog
  COM_TABLE_DUMP = 0x13, -- LOAD TABLE ... FROM MASTER (deprecated)
  COM_CONNECT_OUT = 0x14, -- (none, this is an internal thread state)
  COM_REGISTER_SLAVE = 0x15, -- sent by the slave to register with the master (optional)
  COM_STMT_PREPARE = 0x16, -- mysql_stmt_prepare
  COM_STMT_EXECUTE = 0x17, -- mysql_stmt_execute
  COM_STMT_SEND_LONG_DATA = 0x18, -- mysql_stmt_send_long_data
  COM_STMT_CLOSE = 0x19, -- mysql_stmt_close
  COM_STMT_RESET = 0x1a, -- mysql_stmt_reset
  COM_SET_OPTION = 0x1b, -- mysql_set_server_option
  COM_STMT_FETCH = 0x1c, -- mysql_stmt_fetch

  -- COM_REFRESH masks
  REFRESH_GRANT = 0x01,
  REFRESH_LOG = 0x02,
  REFRESH_TABLES = 0x04,
  REFRESH_HOSTS = 0x08,
  REFRESH_STATUS = 0x10,
  REFRESH_THREADS = 0x20,
  REFRESH_SLAVE = 0x40,
  REFRESH_MASTER = 0x80,

  -- COM_SHUTDOWN levels
  SHUTDOWN_DEFAULT = 0x00,
  SHUTDOWN_WAIT_CONNECTIONS = 0x01,
  SHUTDOWN_WAIT_TRANSACTIONS = 0x02,
  SHUTDOWN_WAIT_UPDATES = 0x08,
  SHUTDOWN_WAIT_ALL_BUFFERS = 0x10,
  SHUTDOWN_WAIT_CRITICAL_BUFFERS = 0x11,

  KILL_QUERY = 0xfe,
  KILL_CONNECTION = 0xff,

  -- COM_SET_OPTION values
  MYSQL_OPTION_MULTI_STATEMENTS_ON = 0x00,
  MYSQL_OPTION_MULTI_STATEMENTS_OFF = 0x01,

  -- First payload byte of generic reply packets
  PACKET_OK = 0x00,
  PACKET_ERROR = 0xff,
  PACKET_EOF = 0xfe,

  -- Column types
  FIELD_TYPE_DECIMAL = 0x00,
  FIELD_TYPE_TINY = 0x01,
  FIELD_TYPE_SHORT = 0x02,
  FIELD_TYPE_LONG = 0x03,
  FIELD_TYPE_FLOAT = 0x04,
  FIELD_TYPE_DOUBLE = 0x05,
  FIELD_TYPE_NULL = 0x06,
  FIELD_TYPE_TIMESTAMP = 0x07,
  FIELD_TYPE_LONGLONG = 0x08,
  FIELD_TYPE_INT24 = 0x09,
  FIELD_TYPE_DATE = 0x0a,
  FIELD_TYPE_TIME = 0x0b,
  FIELD_TYPE_DATETIME = 0x0c,
  FIELD_TYPE_YEAR = 0x0d,
  FIELD_TYPE_NEWDATE = 0x0e,
  FIELD_TYPE_VARCHAR = 0x0f, -- (new in MySQL 5.0)
  FIELD_TYPE_BIT = 0x10, -- (new in MySQL 5.0)
  FIELD_TYPE_NEWDECIMAL = 0xf6, -- (new in MYSQL 5.0)
  FIELD_TYPE_ENUM = 0xf7,
  FIELD_TYPE_SET = 0xf8,
  FIELD_TYPE_TINY_BLOB = 0xf9,
  FIELD_TYPE_MEDIUM_BLOB = 0xfa,
  FIELD_TYPE_LONG_BLOB = 0xfb,
  FIELD_TYPE_BLOB = 0xfc,
  FIELD_TYPE_VAR_STRING = 0xfd,
  FIELD_TYPE_STRING = 0xfe,
  FIELD_TYPE_GEOMETRY = 0xff,

  -- Column flags
  NOT_NULL_FLAG = 0x0001,
  PRI_KEY_FLAG = 0x0002,
  UNIQUE_KEY_FLAG = 0x0004,
  MULTIPLE_KEY_FLAG = 0x0008,
  BLOB_FLAG = 0x0010,
  UNSIGNED_FLAG = 0x0020,
  ZEROFILL_FLAG = 0x0040,
  BINARY_FLAG = 0x0080,
  ENUM_FLAG = 0x0100,
  AUTO_INCREMENT_FLAG = 0x0200,
  TIMESTAMP_FLAG = 0x0400,
  SET_FLAG = 0x0800,

  -- Per-character replacements used by escape_string (same set as
  -- mysql_real_escape_string: backslash, quotes, \n, \r, NUL and Ctrl-Z).
  escape_patterns = {
    ['\\'] = '\\\\',
    ["'"] = "\\'",
    ['"'] = '\\"',
    ['\n'] = '\\n',
    ['\r'] = '\\r',
    ['\0'] = '\\0',
    ['\026'] = '\\Z',
  },
}

--- Escape a value for use inside a quoted SQL string literal.
-- Non-string values (e.g. numbers) are converted with tostring first.
-- @param s value to escape
-- @return the escaped string
function mysql:escape_string(s)
  -- The parentheses keep only gsub's first result (the string, not the count).
  return (tostring(s):gsub('.', self.escape_patterns))
end

--- Decode a little-endian unsigned integer from a byte string.
-- @param v string holding the bytes
-- @param offset 1-based index of the first byte (default 1)
-- @param len number of bytes to decode (default: up to the end of `v`)
-- @return the decoded value and the index just past the decoded bytes
function mysql:parse_int(v, offset, len)
  offset = offset or 1
  len = len or (#v - offset + 1)
  local ret, m = 0, 1
  for i = offset, offset + len - 1 do
    local b = v:byte(i)
    if not b then break end -- truncated input: decode what is there
    ret = ret + m * b
    m = m * 256
  end
  return ret, offset + len
end

--- Encode a non-negative integer as `length` little-endian bytes.
-- @param v integer to encode
-- @param length number of output bytes
-- @return the encoded string
function mysql:encode_int(v, length)
  local bytes = {}
  for i = 1, length do
    bytes[i] = string.char(v % 256)
    v = math.floor(v / 256)
  end
  return table.concat(bytes)
end

--- Decode a length-encoded integer ("length coded binary").
-- Prefix < 251 is the value itself; 251 means SQL NULL; 252/253/254 are
-- followed by a 2/3/8-byte little-endian value.
-- @param data payload string
-- @param offset 1-based index of the prefix byte
-- @return the value (nil for NULL or end of data) and the next offset
function mysql:parse_lcb_int(data, offset)
  local first = data:byte(offset)
  if not first then return nil, offset end
  if first < 251 then return first, offset + 1 end
  if first == 251 then return nil, offset + 1 end
  if first == 252 then return self:parse_int(data, offset + 1, 2) end
  if first == 253 then return self:parse_int(data, offset + 1, 3) end
  return self:parse_int(data, offset + 1, 8)
end

--- Encode a string prefixed with its length-encoded length.
-- @param data string to encode
-- @return the encoded string
function mysql:encode_lcb(data)
  if #data <= 250 then return string.char(#data) .. data end
  if #data <= 0xffff then return string.char(252) .. self:encode_int(#data, 2) .. data end
  if #data <= 0xffffff then return string.char(253) .. self:encode_int(#data, 3) .. data end
  return string.char(254) .. self:encode_int(#data, 8) .. data
end

--- Decode a length-encoded string.
-- @param data payload string
-- @param offset 1-based index of the length prefix
-- @return the string (nil for SQL NULL) and the next offset; returns nothing
--   when `offset` is past the end of `data`
function mysql:parse_lcb(data, offset)
  if not data:byte(offset) then return end
  local len
  len, offset = self:parse_lcb_int(data, offset)
  if len == nil then
    return nil, offset -- SQL NULL: only the 0xfb prefix byte is consumed
  end
  return data:sub(offset, offset + len - 1), offset + len
end

--- Read one packet (3-byte length, 1-byte sequence number, payload).
-- Packets split across the 16 MB boundary are not reassembled.
-- Raises an error if the socket read fails.
-- @return table with `length`, `sequence_number` and `payload`
function mysql:read_packet()
  local header, err = self.socket:receive(4)
  if not header then
    error('mysql: failed to read packet header: ' .. tostring(err))
  end
  local packet = {}
  packet.length = self:parse_int(header, 1, 3)
  packet.sequence_number = header:byte(4)
  if packet.length > 0 then
    packet.payload, err = self.socket:receive(packet.length)
    if not packet.payload then
      error('mysql: failed to read packet payload: ' .. tostring(err))
    end
  else
    packet.payload = ''
  end
  return packet
end

--- Frame `payload` with its length and the current sequence number and send it.
-- Raises an error if the socket write fails.
function mysql:write_packet(payload)
  local frame = self:encode_int(#payload, 3) .. string.char(self.sequence_number) .. payload
  local sent, err = self.socket:send(frame)
  if not sent then
    error('mysql: failed to send packet: ' .. tostring(err))
  end
end

--- Decode an error packet payload (0xff, errno, optional '#'+sqlstate, message).
-- @param payload packet payload starting with 0xff
-- @return table with `errno`, `sqlstate` (when present) and `message`
function mysql:parse_error(payload)
  local err, offset = {}, nil
  err.errno, offset = self:parse_int(payload, 2, 2)
  -- The '#'-prefixed SQLSTATE is only present with the 4.1 protocol.
  if payload:sub(offset, offset) == '#' then
    err.sqlstate = payload:sub(offset + 1, offset + 5)
    offset = offset + 6
  end
  err.message = payload:sub(offset)
  return err
end

--- Read and decode the server's initial handshake packet.
-- @return table with protocol_version, server_version, thread_id,
--   scramble_buff (20-byte salt), server_capabilities, server_language and
--   server_status; or a table with `error` if the server sent an error packet
function mysql:read_handshake()
  local payload = self:read_packet().payload
  local handshake = {}
  handshake.protocol_version = payload:byte(1)
  if handshake.protocol_version == self.PACKET_ERROR then
    -- e.g. the server refuses this host before any handshake
    handshake.error = self:parse_error(payload)
    return handshake
  end

  -- NUL-terminated server version string starting at byte 2
  local nul = payload:find('\0', 2, true) or (#payload + 1)
  handshake.server_version = payload:sub(2, nul - 1)
  local offset = nul + 1

  handshake.thread_id, offset = self:parse_int(payload, offset, 4)
  local scramble_head = payload:sub(offset, offset + 7)
  offset = offset + 9 -- 8 scramble bytes + 1 filler byte
  handshake.server_capabilities, offset = self:parse_int(payload, offset, 2)
  handshake.server_language = payload:byte(offset)
  offset = offset + 1
  handshake.server_status, offset = self:parse_int(payload, offset, 2)
  offset = offset + 13 -- filler
  -- The rest of the scramble is 12 bytes followed by a NUL terminator; the
  -- terminator is not part of the 20-byte salt.
  handshake.scramble_buff = scramble_head .. payload:sub(offset, offset + 11)
  return handshake
end

--- Compute the 4.1 auth response:
-- SHA1(password) XOR SHA1(salt .. SHA1(SHA1(password))).
-- An absent or empty password yields an empty response.
-- @param password plain-text password (may be nil)
-- @param salt scramble from the handshake
-- @return the binary auth response
function mysql:encode_password(password, salt)
  if not password or password == '' then return '' end
  local stage1 = sha1(password)
  return xor(sha1(salt .. sha1(stage1)), stage1)
end

--- Send the client authentication packet answering the handshake.
-- Uses self.handshake (server_language, scramble_buff) and self.sequence_number.
-- @param user user name
-- @param password plain-text password (may be nil)
-- @param database optional default database
function mysql:send_auth(user, password, database)
  local client_flags = self.CLIENT_PROTOCOL_41 + self.CLIENT_SECURE_CONNECTION
    + self.CLIENT_MULTI_STATEMENTS + self.CLIENT_MULTI_RESULTS
  if database then
    -- Without this flag the server ignores the trailing database name.
    client_flags = client_flags + self.CLIENT_CONNECT_WITH_DB
  end
  local max_packet_size = 0xffffff
  local payload = table.concat({
    self:encode_int(client_flags, 4),
    self:encode_int(max_packet_size, 4),
    self:encode_int(self.handshake.server_language, 1),
    string.rep(string.char(0), 23), -- reserved
    user, string.char(0),
    self:encode_lcb(self:encode_password(password, self.handshake.scramble_buff)),
    database and (database .. string.char(0)) or '',
  })
  self:write_packet(payload)
end

--- Decode one column value from a text-protocol row.
-- Values are returned as raw strings; `field_type` is currently unused.
-- @return the value (nil for SQL NULL) and the next offset
function mysql:parse_field(data, offset, field_type)
  return self:parse_lcb(data, offset)
end

--- True if `payload` is an EOF packet (0xfe marker and shorter than 9 bytes,
-- which distinguishes it from a row starting with an 8-byte length prefix).
function mysql:is_eof_packet(payload)
  return payload:byte(1) == self.PACKET_EOF and #payload < 9
end

--- Decode a column-definition packet payload.
-- @param payload packet payload
-- @return table with catalog, db, table, org_table, name, org_name,
--   charsetnr, length, type, flags, decimals and (COM_FIELD_LIST only) default
function mysql:parse_column_definition(payload)
  local field, offset = {}, 1
  field.catalog, offset = self:parse_lcb(payload, offset)
  field.db, offset = self:parse_lcb(payload, offset)
  field.table, offset = self:parse_lcb(payload, offset)
  field.org_table, offset = self:parse_lcb(payload, offset)
  field.name, offset = self:parse_lcb(payload, offset)
  field.org_name, offset = self:parse_lcb(payload, offset)
  offset = offset + 1 -- filler byte before the fixed-length section
  field.charsetnr, offset = self:parse_int(payload, offset, 2)
  field.length, offset = self:parse_int(payload, offset, 4)
  field.type, offset = self:parse_int(payload, offset, 1)
  field.flags, offset = self:parse_int(payload, offset, 2)
  field.decimals, offset = self:parse_int(payload, offset, 1)
  offset = offset + 2 -- two filler bytes
  -- A default value follows only in COM_FIELD_LIST replies; otherwise the
  -- payload is exhausted and parse_lcb yields nil.
  field.default = self:parse_lcb(payload, offset)
  return field
end

--- Read the server's reply to a command.
-- Only the first result set is read; extra result sets (possible with
-- multi-statement queries) remain unread on the socket.
-- @return one of: { ok = {affected_rows, insert_id, server_status,
--   warning_count, message} }, { error = {errno, sqlstate, message} },
--   { eof = {warning_count, status_flags} }, or { result = {rows...} } where
--   each row maps column name to string value (nil for NULL); a result set
--   that ends in an error packet also carries `error`.
function mysql:read_reply()
  local payload = self:read_packet().payload
  local packet_type = payload:byte(1)
  if packet_type == nil then
    error('mysql: received an empty reply packet')
  end

  if packet_type == self.PACKET_OK then
    local ok, offset = {}, 2
    -- affected_rows and insert_id are length-encoded integers.
    ok.affected_rows, offset = self:parse_lcb_int(payload, offset)
    ok.insert_id, offset = self:parse_lcb_int(payload, offset)
    ok.server_status, offset = self:parse_int(payload, offset, 2)
    ok.warning_count, offset = self:parse_int(payload, offset, 2)
    ok.message = payload:sub(offset)
    return { ok = ok }
  end

  if packet_type == self.PACKET_ERROR then
    return { error = self:parse_error(payload) }
  end

  if packet_type == self.PACKET_EOF then
    local eof, offset = {}, 2
    eof.warning_count, offset = self:parse_int(payload, offset, 2)
    eof.status_flags = self:parse_int(payload, offset, 2)
    return { eof = eof }
  end

  -- Result set: the header packet holds the column count.
  local ret = { result = {} }
  local fields = {}
  local field_count = self:parse_lcb_int(payload, 1) or 0
  for i = 1, field_count do
    fields[i] = self:parse_column_definition(self:read_packet().payload)
  end
  self:read_packet() -- EOF packet ending the column definitions (not verified)

  local row = self:read_packet().payload
  while not self:is_eof_packet(row) do
    if row:byte(1) == self.PACKET_ERROR then
      ret.error = self:parse_error(row)
      break
    end
    local record, offset = {}, 1
    for i = 1, #fields do
      local value
      value, offset = self:parse_field(row, offset, fields[i].type)
      record[fields[i].name] = value
      if not offset then break end -- row shorter than the column list
    end
    ret.result[#ret.result + 1] = record
    row = self:read_packet().payload
  end
  return ret
end

--- Start a new command: reset the sequence number and send command + data.
function mysql:send_packet(command, data)
  self.sequence_number = 0
  self:write_packet(string.char(command) .. (data or ''))
end

--- Send a command and return its decoded reply (see read_reply).
function mysql:send_command(command, data)
  self:send_packet(command, data)
  return self:read_reply()
end

--- Send COM_QUIT and close the socket.
-- The server normally answers COM_QUIT by closing the connection, so no
-- reply is read; send errors are ignored because the socket is closed anyway.
-- @return true
function mysql:quit()
  pcall(self.send_packet, self, self.COM_QUIT)
  self.socket:close()
  return true
end

--- Change the default database (COM_INIT_DB).
function mysql:select_db(db)
  return self:send_command(self.COM_INIT_DB, db)
end

--- Run a query (COM_QUERY).
-- Each `{name}` placeholder in `query` is replaced with the escaped value of
-- vars[name] (or the literal name when vars[name] is unset).
-- @return the reply (see read_reply) and the final SQL text
function mysql:query(query, vars)
  if vars then
    query = string.gsub(query, '{(.-)}', function(name)
      return self:escape_string(vars[name] or name)
    end)
  end
  return self:send_command(self.COM_QUERY, query), query
end

--- List the columns of a table (COM_FIELD_LIST).
-- The reply is a series of column definitions ended by EOF, with no
-- column-count header, so it is read here rather than by read_reply.
-- @param table_name table to inspect
-- @param column optional column name wildcard
-- @return { fields = {column definitions...} } or { error = {...} }
function mysql:field_list(table_name, column)
  self:send_packet(self.COM_FIELD_LIST, table_name .. string.char(0) .. (column or ''))
  local fields = {}
  local payload = self:read_packet().payload
  while not self:is_eof_packet(payload) do
    if payload:byte(1) == self.PACKET_ERROR then
      return { error = self:parse_error(payload) }
    end
    fields[#fields + 1] = self:parse_column_definition(payload)
    payload = self:read_packet().payload
  end
  return { fields = fields }
end

--- Flush server caches/logs (COM_REFRESH) using a REFRESH_* mask.
function mysql:flush(mask)
  return self:send_command(self.COM_REFRESH, string.char(mask))
end

--- Ask the server to shut down (COM_SHUTDOWN).
-- @param mask optional SHUTDOWN_* level (default SHUTDOWN_DEFAULT)
function mysql:shutdown(mask)
  return self:send_command(self.COM_SHUTDOWN, string.char(mask or self.SHUTDOWN_DEFAULT))
end

--- Fetch the server statistics string (COM_STATISTICS).
-- The server replies with a bare human-readable string, not an OK packet.
-- @return { status = string } or { error = {...} }
function mysql:status()
  self:send_packet(self.COM_STATISTICS)
  local payload = self:read_packet().payload
  if payload:byte(1) == self.PACKET_ERROR then
    return { error = self:parse_error(payload) }
  end
  return { status = payload }
end

--- List server threads (COM_PROCESS_INFO).
function mysql:process_list()
  return self:send_command(self.COM_PROCESS_INFO)
end

--- Kill a server thread by id (COM_PROCESS_KILL, 4-byte id).
function mysql:kill(process_id)
  return self:send_command(self.COM_PROCESS_KILL, self:encode_int(process_id, 4))
end

--- Ask the server to dump debug info to its log (COM_DEBUG).
function mysql:debug()
  return self:send_command(self.COM_DEBUG)
end

--- Check that the connection is alive (COM_PING).
function mysql:ping()
  return self:send_command(self.COM_PING)
end

--- Set a server option (COM_SET_OPTION, 2-byte MYSQL_OPTION_* value).
function mysql:option(mask)
  return self:send_command(self.COM_SET_OPTION, self:encode_int(mask, 2))
end

--- Open and authenticate a connection.
-- @param host server host name or address
-- @param user user name
-- @param password plain-text password (may be nil)
-- @param database optional default database
-- @param port optional TCP port (default 3306)
-- @return a connection object, or nil plus a LuaSocket error string, a MySQL
--   error table, or the unexpected auth reply
function mysql:connect(host, user, password, database, port)
  local state = {}
  for k, v in pairs(self) do state[k] = v end
  -- Set after the copy so a stale value from `self` cannot override it; the
  -- auth packet follows the server's handshake, which has sequence 0.
  state.sequence_number = 1

  local sock, err = socket.tcp()
  if not sock then return nil, err end
  state.socket = sock
  local connected
  connected, err = sock:connect(host, port or 3306)
  if not connected then
    sock:close()
    return nil, err
  end

  state.handshake = state:read_handshake()
  if state.handshake.error then
    sock:close()
    return nil, state.handshake.error
  end

  state:send_auth(user, password, database)
  local auth_reply = state:read_reply()
  if not auth_reply.ok then
    sock:close()
    return nil, auth_reply.error or auth_reply
  end
  return state
end

--[[ Example:
db = mysql:connect('127.0.0.1', 'user', 'password', 'database')
query = "select * from Table where field='{value}'"
res = db:query(query, { value='some value' })
for i=1,#res.result do
  print('record #' .. i)
  for k,v in pairs(res.result[i]) do print('\t'..k..' = ' .. v) end
end
]]