// Copyright (c) 2004-2008 MySQL AB, 2008-2009 Sun Microsystems, Inc.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 2 as published by
// the Free Software Foundation
//
// There are special exceptions to the terms and conditions of the GPL
// as it is applied to this software. View the full text of the
// exception in file EXCEPTIONS in the directory of this software
// distribution.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

using System;
using System.Collections;
using System.Diagnostics;
using System.IO;
using MySql.Data.Common;
using MySql.Data.Types;
using System.Security.Cryptography.X509Certificates;
using MySql.Data.MySqlClient.Properties;
using System.Text;
#if !CF
using System.Net.Security;
using System.Security.Authentication;
using System.Globalization;
#endif

namespace MySql.Data.MySqlClient
{
    /// <summary>
    /// <see cref="IDriver"/> implementation that talks to the server directly over
    /// a stream: it performs the handshake and authentication (optionally upgrading
    /// to SSL and switching to a compressed stream), sends commands, and parses
    /// OK/EOF packets, column definitions and row data. Requires a 4.1.1+ server.
    /// </summary>
    internal class NativeDriver : IDriver
    {
        private DBVersion version;
        private int threadId;
        // Scramble sent by the server in the welcome packet (both parts are
        // concatenated); passed to the password-hashing helpers in Crypt.
        protected String encryptionSeed;
        protected ServerStatusFlags serverStatus;
        protected MySqlStream stream;
        protected Stream baseStream;
        // Null bitmap of the current binary (prepared-statement) row;
        // null while reading text-protocol rows.
        private BitArray nullMap;
        private MySqlPacket packet;
        private ClientFlags connectionFlags;
        private Driver owner;
        // Warning count for the current command; reset whenever a command is sent.
        private int warnings;

        public NativeDriver(Driver owner)
        {
            this.owner = owner;
            threadId = -1;
        }

        public ClientFlags Flags
        {
            get { return connectionFlags; }
        }

        public int ThreadId
        {
            get { return threadId; }
        }

        public DBVersion Version
        {
            get { return version; }
        }

        public ServerStatusFlags ServerStatus
        {
            get { return serverStatus; }
        }

        public int WarningCount
        {
            get { return warnings; }
        }

        public MySqlPacket Packet
        {
            get { return packet; }
        }

        private MySqlConnectionStringBuilder Settings
        {
            get { return owner.Settings; }
        }

        private Encoding Encoding
        {
            get { return owner.Encoding; }
        }

        /// <summary>
        /// Closes the owning driver when <paramref name="ex"/> is flagged as fatal.
        /// Callers rethrow the exception afterwards.
        /// </summary>
        private void HandleException(MySqlException ex)
        {
            if (ex.IsFatal)
                owner.Close();
        }

        /// <summary>
        /// Parses an OK packet, updating <see cref="ServerStatus"/> when the packet
        /// carries status information. The affected-row count, insert id, warning
        /// count and message are read but discarded.
        /// </summary>
        /// <param name="read">
        /// True to read the next packet from the stream first; false to parse the
        /// packet that is already held in <c>packet</c>.
        /// </param>
        /// <exception cref="MySqlException">
        /// The first byte is not 0 ("Out of sync with server"), or reading failed.
        /// Exceptions flagged as fatal also close the owning driver.
        /// </exception>
        private void ReadOk(bool read)
        {
            try
            {
                if (read)
                    packet = stream.ReadPacket();
                byte marker = (byte)packet.ReadByte();
                if (marker != 0)
                    throw new MySqlException("Out of sync with server", true, null);

                packet.ReadFieldLength(); /* affected rows */
                packet.ReadFieldLength(); /* last insert id */
                if (packet.HasMoreData)
                {
                    serverStatus = (ServerStatusFlags)packet.ReadInteger(2);
                    packet.ReadInteger(2);  /* warning count */
                    if (packet.HasMoreData)
                    {
                        packet.ReadLenString();  /* message */
                    }
                }
            }
            catch (MySqlException ex)
            {
                HandleException(ex);
                throw;
            }
        }

        /// <summary>
        /// Sets the current database for this connection by sending an INIT_DB
        /// command and waiting for the OK reply.
        /// </summary>
        /// <param name="dbName">Name of the database to switch to.</param>
        public void SetDatabase(string dbName)
        {
            byte[] dbNameBytes = Encoding.GetBytes(dbName);

            packet.Clear();
            packet.WriteByte((byte)DBCmd.INIT_DB);
            packet.Write(dbNameBytes);
            ExecutePacket(packet);

            ReadOk(true);
        }

        /// <summary>
        /// Copies the owner's maximum packet size and encoding onto the stream.
        /// </summary>
        public void Configure()
        {
            stream.MaxPacketSize = (ulong)owner.MaxPacketSize;
            stream.Encoding = Encoding;
        }

        /// <summary>
        /// Connects to the server, parses the welcome (handshake) packet, negotiates
        /// connection flags (upgrading to SSL when both sides allow it), authenticates,
        /// and switches to a compressed stream when compression was negotiated.
        /// </summary>
        /// <exception cref="MySqlException">
        /// The host cannot be reached, SSL is required but the server does not offer
        /// it, or authentication fails.
        /// </exception>
        /// <exception cref="NotSupportedException">The server is older than 4.1.1.</exception>
        public void Open()
        {
            // connect to one of our specified hosts
            try
            {
#if !CF
                if (Settings.ConnectionProtocol == MySqlConnectionProtocol.SharedMemory)
                {
                    SharedMemoryStream str = new SharedMemoryStream(Settings.SharedMemoryName);
                    str.Open(Settings.ConnectionTimeout);
                    baseStream = str;
                }
                else
                {
#endif
                    string pipeName = Settings.PipeName;
                    if (Settings.ConnectionProtocol != MySqlConnectionProtocol.NamedPipe)
                        pipeName = null;
                    StreamCreator sc = new StreamCreator(Settings.Server, Settings.Port,
                        pipeName, Settings.Keepalive);
                    baseStream = sc.GetStream(Settings.ConnectionTimeout);
#if !CF
                }
#endif
            }
            catch (Exception ex)
            {
                throw new MySqlException(Resources.UnableToConnectToHost,
                    (int)MySqlErrorCode.UnableToConnectToHost, ex);
            }

            if (baseStream == null)
                throw new MySqlException(Resources.UnableToConnectToHost,
                    (int)MySqlErrorCode.UnableToConnectToHost);

            stream = new MySqlStream(baseStream, Encoding, false);
            stream.ResetTimeout((int)Settings.ConnectionTimeout * 1000);

            // read off the welcome packet and parse out its values
            packet = stream.ReadPacket();
            packet.ReadByte(); // protocol version; not used by this driver
            string versionString = packet.ReadString();
            version = DBVersion.Parse(versionString);
            if (!version.isAtLeast(4, 1, 1))
                throw new NotSupportedException(Resources.ServerTooOld);
            threadId = packet.ReadInteger(4);
            encryptionSeed = packet.ReadString();

            // 2^24 - 1: advertised to the server as our max packet size below and
            // later used as the stream's MaxBlockSize.
            int maxSinglePacket = (256 * 256 * 256) - 1;

            // read in Server capabilities if they are provided
            ClientFlags serverCaps = 0;
            if (packet.HasMoreData)
                serverCaps = (ClientFlags)packet.ReadInteger(2);

            /* New protocol with 16 bytes to describe server characteristics */
            owner.ConnectionCharSetIndex = (int)packet.ReadByte();

            serverStatus = (ServerStatusFlags)packet.ReadInteger(2);
            packet.Position += 13; // skip 13 bytes not used by this driver
            string seedPart2 = packet.ReadString();
            encryptionSeed += seedPart2;

            // based on our settings, set our connection flags
            SetConnectionFlags(serverCaps);

            packet.Clear();
            packet.WriteInteger((int)connectionFlags, 4);

#if !CF
            if ((serverCaps & ClientFlags.SSL) == 0)
            {
                if ((Settings.SslMode != MySqlSslMode.None)
                    && (Settings.SslMode != MySqlSslMode.Preferred))
                {
                    // Client requires SSL connections.
                    string message = String.Format(Resources.NoServerSSLSupport,
                        Settings.Server);
                    throw new MySqlException(message);
                }
            }
            else if (Settings.SslMode != MySqlSslMode.None)
            {
                // Send the flags first, switch the transport to SSL, then rebuild
                // the handshake response on the encrypted stream.
                stream.SendPacket(packet);
                StartSSL();
                packet.Clear();
                packet.WriteInteger((int)connectionFlags, 4);
            }
#endif

            packet.WriteInteger(maxSinglePacket, 4);
            packet.WriteByte(8);          // character-set byte; 8 is hard-coded here
            packet.Write(new byte[23]);   // 23 zero bytes of padding

            Authenticate();

            // if we are using compression, then we use our CompressedStream class
            // to hide the ugliness of managing the compression
            if ((connectionFlags & ClientFlags.COMPRESS) != 0)
                stream = new MySqlStream(baseStream, Encoding, true);

            // give our stream the server version we are connected to.
            // We may have some fields that are read differently based
            // on the version of the server we are connected to.
            packet.Version = version;
            stream.MaxBlockSize = maxSinglePacket;
        }

#if !CF

        #region SSL

        /// <summary>
        /// Retrieves client SSL certificates. Depending on the connection string
        /// settings, the certificate comes from a file or from a certificate store.
        /// </summary>
        /// <returns>
        /// The file-based certificate if one is configured; otherwise an empty
        /// collection when no store location is set; otherwise every certificate in
        /// the store's "My" collection, or only those matching the configured thumbprint.
        /// </returns>
        /// <exception cref="MySqlException">
        /// A thumbprint was configured but no matching certificate was found.
        /// </exception>
        private X509CertificateCollection GetClientCertificates()
        {
            X509CertificateCollection certs = new X509CertificateCollection();

            // Check for file-based certificate
            if (Settings.CertificateFile != null)
            {
                X509Certificate2 clientCert = new X509Certificate2(Settings.CertificateFile,
                    Settings.CertificatePassword);
                certs.Add(clientCert);
                return certs;
            }

            if (Settings.CertificateStoreLocation == MySqlCertificateStoreLocation.None)
                return certs;

            StoreLocation location =
                (Settings.CertificateStoreLocation == MySqlCertificateStoreLocation.CurrentUser) ?
                StoreLocation.CurrentUser : StoreLocation.LocalMachine;

            // Check for store-based certificate
            X509Store store = new X509Store(StoreName.My, location);
            store.Open(OpenFlags.ReadOnly | OpenFlags.OpenExistingOnly);
            try
            {
                if (Settings.CertificateThumbprint == null)
                {
                    // Return all certificates from the store.
                    certs.AddRange(store.Certificates);
                    return certs;
                }

                // Find certificate with given thumbprint
                certs.AddRange(store.Certificates.Find(X509FindType.FindByThumbprint,
                    Settings.CertificateThumbprint, true));
            }
            finally
            {
                // Release the store handle; previously the store was never closed.
                store.Close();
            }

            if (certs.Count == 0)
            {
                throw new MySqlException("Certificate with Thumbprint " +
                   Settings.CertificateThumbprint + " not found");
            }
            return certs;
        }

        /// <summary>
        /// Wraps the base stream in an <see cref="SslStream"/>, authenticates as a
        /// client using the configured client certificates, and replaces both
        /// <c>baseStream</c> and <c>stream</c> with the encrypted stream.
        /// </summary>
        private void StartSSL()
        {
            RemoteCertificateValidationCallback sslValidateCallback =
                new RemoteCertificateValidationCallback(ServerCheckValidation);
            SslStream ss = new SslStream(baseStream, true, sslValidateCallback, null);
            X509CertificateCollection certs = GetClientCertificates();
            ss.AuthenticateAsClient(Settings.Server, certs, SslProtocols.Default, false);
            baseStream = ss;
            stream = new MySqlStream(ss, Encoding, false);
            // The packet sent before the SSL upgrade used sequence number 1.
            stream.SequenceByte = 2;
        }

        /// <summary>
        /// Server certificate validation callback. Errors are tolerated according to
        /// <see cref="MySqlSslMode"/>: Preferred and Required accept any certificate,
        /// VerifyCA accepts a host-name mismatch only, and other modes require a
        /// certificate with no errors.
        /// </summary>
        private bool ServerCheckValidation(object sender, X509Certificate certificate,
            X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            if (sslPolicyErrors == SslPolicyErrors.None)
                return true;

            if (Settings.SslMode == MySqlSslMode.Preferred ||
                Settings.SslMode == MySqlSslMode.Required)
            {
                //Tolerate all certificate errors.
                return true;
            }

            if (Settings.SslMode == MySqlSslMode.VerifyCA &&
                sslPolicyErrors == SslPolicyErrors.RemoteCertificateNameMismatch)
            {
                // Tolerate name mismatch in certificate, if full validation is not requested.
                return true;
            }

            return false;
        }

        #endregion

#endif

        #region Authentication

        /// <summary>
        /// Computes the connection flags to request, based on the server's advertised
        /// capabilities and the user's connection-string options, and stores them in
        /// <c>connectionFlags</c>.
        /// </summary>
        /// <param name="serverCaps">Capability flags read from the welcome packet.</param>
        private void SetConnectionFlags(ClientFlags serverCaps)
        {
            // allow load data local infile
            ClientFlags flags = ClientFlags.LOCAL_FILES;

            if (!Settings.UseAffectedRows)
                flags |= ClientFlags.FOUND_ROWS;

            flags |= ClientFlags.PROTOCOL_41;
            // Need this to get server status values
            flags |= ClientFlags.TRANSACTIONS;

            // user allows/disallows batch statements
            if (Settings.AllowBatch)
                flags |= ClientFlags.MULTI_STATEMENTS;

            // We always allow multiple result sets
            flags |= ClientFlags.MULTI_RESULTS;

            // if the server allows it, tell it that we want long column info
            if ((serverCaps & ClientFlags.LONG_FLAG) != 0)
                flags |= ClientFlags.LONG_FLAG;

            // if the server supports it and it was requested, then turn on compression
            if ((serverCaps & ClientFlags.COMPRESS) != 0 && Settings.UseCompression)
                flags |= ClientFlags.COMPRESS;

            flags |= ClientFlags.LONG_PASSWORD; // for long passwords

            // did the user request an interactive session?
            if (Settings.InteractiveSession)
                flags |= ClientFlags.INTERACTIVE;

            // if the server allows it and a database was specified, then indicate
            // that we will connect with a database name
            if ((serverCaps & ClientFlags.CONNECT_WITH_DB) != 0 &&
                Settings.Database != null && Settings.Database.Length > 0)
                flags |= ClientFlags.CONNECT_WITH_DB;

            // if the server is requesting a secure connection, then we oblige
            if ((serverCaps & ClientFlags.SECURE_CONNECTION) != 0)
                flags |= ClientFlags.SECURE_CONNECTION;

            // if the server is capable of SSL and the user is requesting SSL
            if ((serverCaps & ClientFlags.SSL) != 0 && Settings.SslMode != MySqlSslMode.None)
                flags |= ClientFlags.SSL;

            // if the server supports output parameters, then we do too
            //if ((serverCaps & ClientFlags.PS_MULTI_RESULTS) != 0)
            flags |= ClientFlags.PS_MULTI_RESULTS;

            connectionFlags = flags;
        }

        /// <summary>
        /// Performs authentication against a 4.1.1+ server. The user name must
        /// already have been written to <c>packet</c>. If the server did not accept
        /// SECURE_CONNECTION, the old-style path is used instead. If the server answers
        /// with an EOF-style packet, the password is re-sent using the old
        /// (pre-4.1) scramble.
        /// </summary>
        private void AuthenticateNew()
        {
            if ((connectionFlags & ClientFlags.SECURE_CONNECTION) == 0)
            {
                // AuthenticateOld sends the complete auth packet and reads the reply.
                // Previously execution fell through here and sent a second,
                // 4.1-style auth packet on the same handshake.
                AuthenticateOld();
                return;
            }

            packet.Write(Crypt.Get411Password(Settings.Password, encryptionSeed));
            if ((connectionFlags & ClientFlags.CONNECT_WITH_DB) != 0 && Settings.Database != null)
                packet.WriteString(Settings.Database);

            stream.SendPacket(packet);

            // this result means the server wants us to send the password using
            // old encryption
            packet = stream.ReadPacket();
            if (packet.IsLastPacket)
            {
                packet.Clear();
                packet.WriteString(Crypt.EncryptPassword(
                                       Settings.Password, encryptionSeed.Substring(0, 8), true));
                stream.SendPacket(packet);
                ReadOk(true);
            }
            else
                ReadOk(false);
        }

        /// <summary>
        /// Completes the auth packet using the old (pre-4.1) password scramble,
        /// sends it, and waits for the OK reply.
        /// </summary>
        private void AuthenticateOld()
        {
            packet.WriteString(Crypt.EncryptPassword(
                                   Settings.Password, encryptionSeed, true));
            if ((connectionFlags & ClientFlags.CONNECT_WITH_DB) != 0 && Settings.Database != null)
                packet.WriteString(Settings.Database);

            stream.SendPacket(packet);
            ReadOk(true);
        }

        /// <summary>
        /// Appends the user id to <c>packet</c> (which already holds the handshake
        /// or CHANGE_USER prefix) and runs the authentication exchange.
        /// </summary>
        public void Authenticate()
        {
            // write the user id to the auth packet
            packet.WriteString(Settings.UserID);
            AuthenticateNew();
        }

        #endregion

        /// <summary>
        /// Re-authenticates on the existing connection by sending a CHANGE_USER
        /// command with the configured credentials.
        /// </summary>
        public void Reset()
        {
            warnings = 0;
            stream.SequenceByte = 0;
            packet.Clear();
            packet.WriteByte((byte)DBCmd.CHANGE_USER);
            Authenticate();
        }

        /// <summary>
        /// Sends a query to the server. Byte 4 of <paramref name="queryPacket"/>'s
        /// buffer is overwritten with the QUERY command. The response is not read
        /// here; <see cref="GetResult"/> reads it later.
        /// </summary>
        public void SendQuery(MySqlPacket queryPacket)
        {
            warnings = 0;
            queryPacket.Buffer[4] = (byte)DBCmd.QUERY;
            ExecutePacket(queryPacket);
            // the server will respond in one of several ways with the first byte indicating
            // the type of response.
            // 0 == ok packet.  This indicates non-select queries
            // 0xff == error packet.  Presumably surfaced as an exception when the packet
            //         is read from the stream (MySqlStream.ReadPacket) -- confirm.
            // > 0 = number of columns in select query
            // We don't actually read the result here since a single query can generate
            // multiple resultsets and we don't want to duplicate code.  See GetResult.
            // Instead we set our internal server status flag to indicate that we have a query waiting.
            // This flag will be maintained by GetResult.
            serverStatus |= ServerStatusFlags.AnotherQuery;
        }

        /// <summary>
        /// Closes the connection. When <paramref name="isOpen"/> is true, a QUIT
        /// command is sent first. All exceptions are swallowed because this is a
        /// best-effort shutdown.
        /// </summary>
        public void Close(bool isOpen)
        {
            try
            {
                if (isOpen)
                {
                    try
                    {
                        packet.Clear();
                        packet.WriteByte((byte)DBCmd.QUIT);
                        ExecutePacket(packet);
                    }
                    catch (Exception)
                    {
                        // Eat exception here. We should try to close
                        // the stream anyway.
                    }
                }

                if (stream != null)
                    stream.Close();
                stream = null;
            }
            catch (Exception)
            {
                // we are just going to eat any exceptions
                // generated here
            }
        }

        /// <summary>
        /// Sends a PING command.
        /// </summary>
        /// <returns>True if an OK packet came back; false if any exception occurred.</returns>
        public bool Ping()
        {
            try
            {
                packet.Clear();
                packet.WriteByte((byte)DBCmd.PING);
                ExecutePacket(packet);
                ReadOk(true);
                return true;
            }
            catch (Exception)
            {
                return false;
            }
        }

        /// <summary>
        /// Reads the first packet of the next result.
        /// </summary>
        /// <param name="affectedRow">Set from the OK packet when the field count is 0.</param>
        /// <param name="insertedId">Set from the OK packet when the field count is 0.</param>
        /// <returns>
        /// The number of columns in the result set, or 0 for an OK packet. A field
        /// count of -1 is handled internally as a LOAD DATA LOCAL INFILE request: the
        /// named file is sent and the following result is read and returned instead.
        /// </returns>
        public int GetResult(ref int affectedRow, ref int insertedId)
        {
            try
            {
                packet = stream.ReadPacket();
            }
            catch (TimeoutException)
            {
                // Do not reset serverStatus, allow to reenter, e.g when
                // ResultSet is closed.
                throw;
            }
            catch (Exception)
            {
                serverStatus = 0;
                throw;
            }

            int fieldCount = (int)packet.ReadFieldLength();
            if (-1 == fieldCount)
            {
                string filename = packet.ReadString();
                SendFileToServer(filename);

                return GetResult(ref affectedRow, ref insertedId);
            }
            else if (fieldCount == 0)
            {
                // the code to read last packet will set these server status vars
                // again if necessary.
                serverStatus &= ~(ServerStatusFlags.AnotherQuery |
                                  ServerStatusFlags.MoreResults);
                affectedRow = (int)packet.ReadFieldLength();
                insertedId = (int)packet.ReadFieldLength();

                serverStatus = (ServerStatusFlags)packet.ReadInteger(2);
                warnings += packet.ReadInteger(2);
                if (packet.HasMoreData)
                {
                    packet.ReadLenString(); //TODO: server message
                }
            }
            return fieldCount;
        }

        /// <summary>
        /// Sends the specified file to the server in chunks of up to 8192 bytes,
        /// followed by a zero-length packet. This supports LOAD DATA LOCAL INFILE.
        /// </summary>
        /// <param name="filename">Path of the local file requested by the server.</param>
        /// <exception cref="MySqlException">
        /// Any failure while reading or sending the file, with the original
        /// exception as the inner exception.
        /// </exception>
        private void SendFileToServer(string filename)
        {
            // 8192 data bytes plus 4 leading bytes left free in front of the data
            // (presumably filled with the packet header by SendEntirePacketDirectly).
            byte[] buffer = new byte[8196];

            long len = 0;
            try
            {
                using (FileStream fs = new FileStream(filename, FileMode.Open,
                    FileAccess.Read))
                {
                    len = fs.Length;
                    while (len > 0)
                    {
                        int count = fs.Read(buffer, 4, (int)(len > 8192 ? 8192 : len));
                        if (count == 0)
                            break; // file ended early (e.g. truncated); avoid looping forever
                        stream.SendEntirePacketDirectly(buffer, count);
                        len -= count;
                    }
                    stream.SendEntirePacketDirectly(buffer, 0);
                }
            }
            catch (Exception ex)
            {
                throw new MySqlException("Error during LOAD DATA LOCAL INFILE", ex);
            }
        }

        /// <summary>
        /// Reads the null bitmap at the start of a binary-protocol row into
        /// <c>nullMap</c>. The bitmap is indexed at column + 2 (see
        /// <see cref="ReadColumnValue"/>), hence (fieldCount + 2 + 7) / 8 bytes.
        /// </summary>
        private void ReadNullMap(int fieldCount)
        {
            // if we are binary, then we need to load in our null bitmap
            nullMap = null;
            byte[] nullMapBytes = new byte[(fieldCount + 9) / 8];
            packet.ReadByte(); // skip the leading byte before the bitmap
            packet.Read(nullMapBytes, 0, nullMapBytes.Length);
            nullMap = new BitArray(nullMapBytes);
        }

        /// <summary>
        /// Reads the value of column <paramref name="index"/> of the current row into
        /// <paramref name="valObject"/>. For binary rows, NULL-ness comes from the null
        /// bitmap and the length is passed as -1. For text rows, a length prefix is
        /// read and -1 denotes NULL.
        /// </summary>
        public IMySqlValue ReadColumnValue(int index, MySqlField field, IMySqlValue valObject)
        {
            long length = -1;
            bool isNull;

            if (nullMap != null)
                isNull = nullMap[index + 2];
            else
            {
                length = packet.ReadFieldLength();
                isNull = length == -1;
            }

            packet.Encoding = field.Encoding;
            packet.Version = version;
            return valObject.ReadValue(packet, length, isNull);
        }

        /// <summary>
        /// Skips over the current column value. Text rows are skipped by their length
        /// prefix (a NULL value is skipped entirely). Binary rows are skipped by
        /// delegating to <see cref="IMySqlValue.SkipValue"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): with a null bitmap present this method does not consult it,
        /// so a NULL binary column would still be skipped via SkipValue. Confirm that
        /// callers only skip non-null binary values.
        /// </remarks>
        public void SkipColumnValue(IMySqlValue valObject)
        {
            int length = -1;
            if (nullMap == null)
            {
                length = packet.ReadFieldLength();
                if (length == -1) return;
            }
            if (length > -1)
                packet.Position += length;
            else
                valObject.SkipValue(packet);
        }

        /// <summary>
        /// Reads one column-definition packet per entry of <paramref name="columns"/>,
        /// then the terminating EOF packet.
        /// </summary>
        public void GetColumnsData(MySqlField[] columns)
        {
            for (int i = 0; i < columns.Length; i++)
                GetColumnData(columns[i]);
            ReadEOF();
        }

        /// <summary>
        /// Reads a single column-definition packet into <paramref name="field"/>.
        /// For DECIMAL/NEWDECIMAL columns, precision is derived from the column length
        /// and scale (plus one when the column is unsigned).
        /// </summary>
        private void GetColumnData(MySqlField field)
        {
            stream.Encoding = Encoding;
            packet = stream.ReadPacket();
            field.Encoding = Encoding;
            field.CatalogName = packet.ReadLenString();
            field.DatabaseName = packet.ReadLenString();
            field.TableName = packet.ReadLenString();
            field.RealTableName = packet.ReadLenString();
            field.ColumnName = packet.ReadLenString();
            field.OriginalColumnName = packet.ReadLenString();
            packet.ReadByte(); // one byte not used by this driver
            field.CharacterSetIndex = packet.ReadInteger(2);
            field.ColumnLength = packet.ReadInteger(4);
            MySqlDbType type = (MySqlDbType)packet.ReadByte();
            ColumnFlags colFlags;
            if ((connectionFlags & ClientFlags.LONG_FLAG) != 0)
                colFlags = (ColumnFlags)packet.ReadInteger(2);
            else
                colFlags = (ColumnFlags)packet.ReadByte();
            field.Scale = (byte)packet.ReadByte();

            if (packet.HasMoreData)
            {
                packet.ReadInteger(2); // reserved
            }

            if (type == MySqlDbType.Decimal || type == MySqlDbType.NewDecimal)
            {
                field.Precision = (byte)(field.ColumnLength - (int)field.Scale);
                if ((colFlags & ColumnFlags.UNSIGNED) != 0)
                    field.Precision++;
            }

            field.SetTypeAndFlags(type, colFlags);
        }

        /// <summary>
        /// Starts a new command: resets the warning count and the sequence number,
        /// then sends <paramref name="packetToExecute"/>. A fatal
        /// <see cref="MySqlException"/> closes the owning driver before being rethrown.
        /// </summary>
        private void ExecutePacket(MySqlPacket packetToExecute)
        {
            try
            {
                warnings = 0;
                stream.SequenceByte = 0;
                stream.SendPacket(packetToExecute);
            }
            catch (MySqlException ex)
            {
                HandleException(ex);
                throw;
            }
        }

        /// <summary>
        /// Sends an EXECUTE command for a prepared statement. Byte 4 of the packet
        /// buffer is overwritten with the command, and a pending result is flagged
        /// for <see cref="GetResult"/>.
        /// </summary>
        public void ExecuteStatement(MySqlPacket packetToExecute)
        {
            warnings = 0;
            packetToExecute.Buffer[4] = (byte)DBCmd.EXECUTE;
            ExecutePacket(packetToExecute);
            serverStatus |= ServerStatusFlags.AnotherQuery;
        }

        /// <summary>
        /// Verifies that the current packet is an end-of-data (EOF) packet and, when
        /// present, accumulates its warning count and stores its server status.
        /// </summary>
        /// <exception cref="MySqlException">The current packet is not an EOF packet.</exception>
        private void CheckEOF()
        {
            if (!packet.IsLastPacket)
                throw new MySqlException("Expected end of data packet");

            packet.ReadByte(); // read off the 254

            if (packet.HasMoreData)
            {
                warnings += packet.ReadInteger(2);
                serverStatus = (ServerStatusFlags)packet.ReadInteger(2);
            }
        }

        /// <summary>
        /// Reads the next packet and verifies that it is an EOF packet.
        /// </summary>
        private void ReadEOF()
        {
            packet = stream.ReadPacket();
            CheckEOF();
        }

        /// <summary>
        /// Prepares <paramref name="sql"/> on the server.
        /// </summary>
        /// <param name="sql">Statement text to prepare.</param>
        /// <param name="parameters">
        /// Receives the parameter definitions when the statement has parameters;
        /// left unchanged otherwise. Each parameter's encoding is reset to the
        /// connection encoding.
        /// </param>
        /// <returns>The server-assigned statement id.</returns>
        /// <exception cref="MySqlException">The reply does not start with a 0 marker byte.</exception>
        public int PrepareStatement(string sql, ref MySqlField[] parameters)
        {
            // Size the buffer for up to 4 encoded bytes per character after a 5-byte
            // prefix. Byte 4 holds the command; bytes 0-3 are presumably reserved
            // for the packet header.
            packet.Length = sql.Length * 4 + 5;
            byte[] buffer = packet.Buffer;
            int len = Encoding.GetBytes(sql, 0, sql.Length, packet.Buffer, 5);
            packet.Position = len + 5;
            buffer[4] = (byte)DBCmd.PREPARE;
            ExecutePacket(packet);

            packet = stream.ReadPacket();

            int marker = packet.ReadByte();
            if (marker != 0)
                throw new MySqlException("Expected prepared statement marker");

            int statementId = packet.ReadInteger(4);
            int numCols = packet.ReadInteger(2);
            int numParams = packet.ReadInteger(2);
            //TODO: find out what these 3 trailing bytes are needed for
            packet.ReadInteger(3);
            if (numParams > 0)
            {
                parameters = owner.GetColumns(numParams);
                // we set the encoding for each parameter back to our connection encoding
                // since we can't trust what is coming back from the server
                for (int i = 0; i < parameters.Length; i++)
                    parameters[i].Encoding = Encoding;
            }

            if (numCols > 0)
            {
                // Result-column definitions are read and discarded here.
                while (numCols-- > 0)
                {
                    packet = stream.ReadPacket();
                    //TODO: handle streaming packets
                }

                ReadEOF();
            }

            return statementId;
        }

        /// <summary>
        /// Called by the data reader to check whether there is another row. Reads the
        /// next data packet from the stream. When <paramref name="statementId"/> &gt; 0,
        /// the row is in binary (prepared-statement) format and its null bitmap is read
        /// first. Server-side cursor fetching is not performed.
        /// </summary>
        /// <returns>False when the end-of-data packet is reached; otherwise true.</returns>
        public bool FetchDataRow(int statementId, int columns)
        {
            packet = stream.ReadPacket();
            if (packet.IsLastPacket)
            {
                CheckEOF();
                return false;
            }
            nullMap = null;
            if (statementId > 0)
                ReadNullMap(columns);

            return true;
        }

        /// <summary>
        /// Sends a CLOSE_STMT command for <paramref name="statementId"/>. No reply is read.
        /// </summary>
        public void CloseStatement(int statementId)
        {
            packet.Clear();
            packet.WriteByte((byte)DBCmd.CLOSE_STMT);
            packet.WriteInteger((long)statementId, 4);
            stream.SequenceByte = 0;
            stream.SendPacket(packet);
        }

        /// <summary>
        /// Execution timeout, in milliseconds. When the accumulated time for network IO exceeds this value,
        /// a TimeoutException is thrown. This timeout needs to be reset for every new command.
        /// </summary>
        /// <param name="timeout">New timeout in milliseconds; ignored when there is no open stream.</param>
        public void ResetTimeout(int timeout)
        {
            if (stream != null)
                stream.ResetTimeout(timeout);
        }
    }
}