Explorar o código

Rename Read(int length) to SocketRead(int length)
Introduce dispose lock to resolve race condition in IsConnected/IsSocketConnected.
Rename _socketLock to _socketWriteLock
Eliminate extra allocations in ReceiveMessage, and combine two socket reads.
Use separate lock to eliminate race condition in IsSocketConnected between Poll and checking the Available property.
Modify SocketRead(int length, byte[] buffer) to also take offset.
Modify MessageListener to use Select instead of blocking Receive.

Fixes issue #80.

drieseng %!s(int64=9) %!d(string=hai) anos
pai
achega
f9ad89384f

+ 210 - 0
src/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs

@@ -0,0 +1,210 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Cryptography;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Compression;
+using Renci.SshNet.Messages;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Security;
+using Renci.SshNet.Security.Cryptography;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerAndClientDisconnectRace
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<IKeyExchange> _keyExchangeMock;
+        private Mock<IClientAuthentication> _clientAuthenticationMock;
+        private IPEndPoint _serverEndPoint;
+        private string _keyExchangeAlgorithm;
+        private DisconnectMessage _disconnectMessage;
+
+        protected Random Random { get; private set; }
+        protected byte[] SessionId { get; private set; }
+        protected ConnectionInfo ConnectionInfo { get; private set; }
+        protected IList<EventArgs> DisconnectedRegister { get; private set; }
+        protected IList<MessageEventArgs<DisconnectMessage>> DisconnectReceivedRegister { get; private set; }
+        protected IList<ExceptionEventArgs> ErrorOccurredRegister { get; private set; }
+        protected AsyncSocketListener ServerListener { get; private set; }
+        protected IList<byte[]> ServerBytesReceivedRegister { get; private set; }
+        protected Session Session { get; private set; }
+        protected Socket ServerSocket { get; private set; }
+
+        private void TearDown()
+        {
+            if (ServerListener != null)
+            {
+                ServerListener.Dispose();
+            }
+
+            if (Session != null)
+            {
+                Session.Dispose();
+            }
+        }
+
+        protected virtual void SetupData()
+        {
+            Random = new Random();
+
+            _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            ConnectionInfo = new ConnectionInfo(
+                _serverEndPoint.Address.ToString(),
+                _serverEndPoint.Port,
+                "user",
+                new PasswordAuthenticationMethod("user", "password"))
+            { Timeout = TimeSpan.FromSeconds(20) };
+            _keyExchangeAlgorithm = Random.Next().ToString(CultureInfo.InvariantCulture);
+            SessionId = new byte[10];
+            Random.NextBytes(SessionId);
+            DisconnectedRegister = new List<EventArgs>();
+            DisconnectReceivedRegister = new List<MessageEventArgs<DisconnectMessage>>();
+            ErrorOccurredRegister = new List<ExceptionEventArgs>();
+            ServerBytesReceivedRegister = new List<byte[]>();
+            _disconnectMessage = new DisconnectMessage(DisconnectReason.ServiceNotAvailable, "Not today!");
+
+            Session = new Session(ConnectionInfo, _serviceFactoryMock.Object);
+            Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args);
+            Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args);
+            Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args);
+            Session.KeyExchangeInitReceived += (sender, args) =>
+            {
+                var newKeysMessage = new NewKeysMessage();
+                var newKeys = newKeysMessage.GetPacket(8, null);
+                ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+            };
+
+            ServerListener = new AsyncSocketListener(_serverEndPoint);
+            ServerListener.Connected += socket =>
+            {
+                ServerSocket = socket;
+
+                socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                socket.Send(Encoding.ASCII.GetBytes("SSH-2.0-SshStub\r\n"));
+            };
+
+            var counter = 0;
+
+            ServerListener.BytesReceived += (received, socket) =>
+            {
+                ServerBytesReceivedRegister.Add(received);
+
+                switch (counter++)
+                {
+                    case 0:
+                        var keyExchangeInitMessage = new KeyExchangeInitMessage
+                        {
+                            CompressionAlgorithmsClientToServer = new string[0],
+                            CompressionAlgorithmsServerToClient = new string[0],
+                            EncryptionAlgorithmsClientToServer = new string[0],
+                            EncryptionAlgorithmsServerToClient = new string[0],
+                            KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm },
+                            LanguagesClientToServer = new string[0],
+                            LanguagesServerToClient = new string[0],
+                            MacAlgorithmsClientToServer = new string[0],
+                            MacAlgorithmsServerToClient = new string[0],
+                            ServerHostKeyAlgorithms = new string[0]
+                        };
+                        var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
+                        ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+                        break;
+                    case 1:
+                        var serviceAcceptMessage =ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build();
+                        ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
+                        break;
+                }
+            };
+        }
+
+        private void CreateMocks()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _keyExchangeMock = new Mock<IKeyExchange>(MockBehavior.Strict);
+            _clientAuthenticationMock = new Mock<IClientAuthentication>(MockBehavior.Strict);
+        }
+
+        private void SetupMocks()
+        {
+            _serviceFactoryMock.Setup(
+                p =>
+                    p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
+            _keyExchangeMock.Setup(p => p.Name).Returns(_keyExchangeAlgorithm);
+            _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>()));
+            _keyExchangeMock.Setup(p => p.ExchangeHash).Returns(SessionId);
+            _keyExchangeMock.Setup(p => p.CreateServerCipher()).Returns((Cipher)null);
+            _keyExchangeMock.Setup(p => p.CreateClientCipher()).Returns((Cipher)null);
+            _keyExchangeMock.Setup(p => p.CreateServerHash()).Returns((HashAlgorithm)null);
+            _keyExchangeMock.Setup(p => p.CreateClientHash()).Returns((HashAlgorithm)null);
+            _keyExchangeMock.Setup(p => p.CreateCompressor()).Returns((Compressor)null);
+            _keyExchangeMock.Setup(p => p.CreateDecompressor()).Returns((Compressor)null);
+            _keyExchangeMock.Setup(p => p.Dispose());
+            _serviceFactoryMock.Setup(p => p.CreateClientAuthentication()).Returns(_clientAuthenticationMock.Object);
+            _clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session));
+        }
+
+        protected virtual void Arrange()
+        {
+            CreateMocks();
+            SetupData();
+            SetupMocks();
+
+            ServerListener.Start();
+            Session.Connect();
+        }
+
+        [TestMethod]
+        public  void Act()
+        {
+            for (var i = 0; i < 50; i++)
+            {
+                Arrange();
+                try
+                {
+                    var disconnect = _disconnectMessage.GetPacket(8, null);
+                    ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
+                    Session.Disconnect();
+                }
+                finally
+                {
+                    TearDown();
+                }
+            }
+        }
+
+        private class ServiceAcceptMessageBuilder
+        {
+            private readonly ServiceName _serviceName;
+
+            private ServiceAcceptMessageBuilder(ServiceName serviceName)
+            {
+                _serviceName = serviceName;
+            }
+
+            public static ServiceAcceptMessageBuilder Create(ServiceName serviceName)
+            {
+                return new ServiceAcceptMessageBuilder(serviceName);
+            }
+
+            public byte[] Build()
+            {
+                var serviceName = _serviceName.ToArray();
+
+                var sshDataStream = new SshDataStream(4 + 1 + 1 + 4 + serviceName.Length);
+                sshDataStream.Write((uint)(sshDataStream.Capacity - 4)); // packet length
+                sshDataStream.WriteByte(0); // padding length
+                sshDataStream.WriteByte(ServiceAcceptMessage.MessageNumber);
+                sshDataStream.WriteBinary(serviceName);
+                return sshDataStream.ToArray();
+            }
+        }
+    }
+}

+ 1 - 3
src/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsDisconnectMessage.cs

@@ -8,7 +8,6 @@ using Renci.SshNet.Messages.Transport;
 
 namespace Renci.SshNet.Tests.Classes
 {
-    [TestClass]
     public class SessionTest_Connected_ServerSendsDisconnectMessage : SessionTest_ConnectedBase
     {
         private DisconnectMessage _disconnectMessage;
@@ -25,8 +24,7 @@ namespace Renci.SshNet.Tests.Classes
             var disconnect = _disconnectMessage.GetPacket(8, null);
             ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
 
-            // give session some time to process DisconnectMessage
-            Thread.Sleep(200);
+            Session.Disconnect();
         }
 
         [TestMethod]

+ 7 - 6
src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -254,6 +254,7 @@
     <Compile Include="Classes\SessionTest_Connected_ConnectionReset.cs" />
     <Compile Include="Classes\SessionTest_Connected_Disconnect.cs" />
     <Compile Include="Classes\SessionTest_Connected_GlobalRequestMessageAfterAuthenticationRace.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ServerAndClientDisconnectRace.cs" />
     <Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessage.cs" />
     <Compile Include="Classes\SessionTest_Connected_ServerSendsBadPacket.cs" />
     <Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs" />
@@ -535,12 +536,6 @@
     <Compile Include="Classes\Sftp\SftpSynchronizeDirectoriesAsyncResultTest.cs" />
     <Compile Include="Classes\Sftp\SftpUploadAsyncResultTest.cs" />
   </ItemGroup>
-  <ItemGroup>
-    <ProjectReference Include="..\Renci.SshNet\Renci.SshNet.csproj">
-      <Project>{2F5F8C90-0BD1-424F-997C-7BC6280919D1}</Project>
-      <Name>Renci.SshNet</Name>
-    </ProjectReference>
-  </ItemGroup>
   <ItemGroup>
     <EmbeddedResource Include="Properties\Resources.resx">
       <Generator>ResXFileCodeGenerator</Generator>
@@ -576,6 +571,12 @@
     <EmbeddedResource Include="Data\Key.SSH2.DSA.Encrypted.Des.CBC.12345.txt" />
     <EmbeddedResource Include="Data\Key.SSH2.DSA.txt" />
   </ItemGroup>
+  <ItemGroup>
+    <ProjectReference Include="..\Renci.SshNet\Renci.SshNet.csproj">
+      <Project>{2f5f8c90-0bd1-424f-997c-7bc6280919d1}</Project>
+      <Name>Renci.SshNet</Name>
+    </ProjectReference>
+  </ItemGroup>
   <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
   <!-- To modify your build process, add your task inside one of the targets below and uncomment it. 
        Other similar extension points exist, see Microsoft.Common.targets.

+ 16 - 30
src/Renci.SshNet/Session.NET.cs

@@ -6,13 +6,6 @@ namespace Renci.SshNet
 {
     public partial class Session
     {
-#if FEATURE_SOCKET_POLL
-        /// <summary>
-        /// Holds the lock object to ensure read access to the socket is synchronized.
-        /// </summary>
-        private readonly object _socketReadLock = new object();
-#endif // FEATURE_SOCKET_POLL
-
 #if FEATURE_SOCKET_POLL
         /// <summary>
         /// Gets a value indicating whether the socket is connected.
@@ -63,35 +56,28 @@ namespace Renci.SshNet
 #endif
         partial void IsSocketConnected(ref bool isConnected)
         {
-            isConnected = (_socket != null && _socket.Connected);
-#if FEATURE_SOCKET_POLL
-            if (isConnected)
+            DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checking socket", ToHex(SessionId), DateTime.Now.Ticks));
+
+            lock (_socketDisposeLock)
             {
-                // synchronize this to ensure thread B does not reset the wait handle before
-                // thread A was able to check whether "bytes read from socket" signal was
-                // actually received
-                lock (_socketReadLock)
+#if FEATURE_SOCKET_POLL
+                if (_socket == null || !_socket.Connected)
                 {
-                    DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checking socket", ToHex(SessionId), DateTime.Now.Ticks));
+                    isConnected = false;
+                    return;
+                }
 
-                    // reset waithandle, as we're only interested in reads that take
-                    // place between Poll and the Available check
-                    _bytesReadFromSocket.Reset();
-                    var connectionClosedOrDataAvailable = _socket.Poll(100, SelectMode.SelectRead);
+                lock (_socketReadLock)
+                {
+                    var connectionClosedOrDataAvailable = _socket.Poll(1, SelectMode.SelectRead);
                     isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
-                    if (!isConnected)
-                    {
-                        // the race condition is between the Socket.Poll call and
-                        // Socket.Available, but the event handler - where we signal that
-                        // bytes have been received from the socket - is sometimes invoked
-                        // shortly after
-                        isConnected = _bytesReadFromSocket.WaitOne(500);
-                    }
-
-                    DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checked socket", ToHex(SessionId), DateTime.Now.Ticks));
                 }
-            }
+#else
+                isConnected = _socket != null && _socket.Connected;
 #endif // FEATURE_SOCKET_POLL
+            }
+
+            DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checked socket", ToHex(SessionId), DateTime.Now.Ticks));
         }
     }
 }

+ 169 - 141
src/Renci.SshNet/Session.cs

@@ -86,16 +86,6 @@ namespace Renci.SshNet
         /// </summary>
         private SshMessageFactory _sshMessageFactory;
 
-        /// <summary>
-        /// Holds connection socket.
-        /// </summary>
-        private Socket _socket;
-
-        /// <summary>
-        /// Holds locker object for the socket
-        /// </summary>
-        private readonly object _socketLock = new object();
-
         /// <summary>
         /// Holds a <see cref="WaitHandle"/> that is signaled when the message listener loop has completed.
         /// </summary>
@@ -126,11 +116,6 @@ namespace Renci.SshNet
         /// </summary>
         private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(false);
 
-        /// <summary>
-        /// WaitHandle to signal that bytes have been read from the socket.
-        /// </summary>
-        private EventWaitHandle _bytesReadFromSocket = new ManualResetEvent(false);
-
         /// <summary>
         /// WaitHandle to signal that key exchange is in progress.
         /// </summary>
@@ -172,6 +157,37 @@ namespace Renci.SshNet
         /// </summary>
         private readonly IServiceFactory _serviceFactory;
 
+        /// <summary>
+        /// Holds connection socket.
+        /// </summary>
+        private Socket _socket;
+
+        /// <summary>
+        /// Holds an object that is used to ensure only a single thread can read from
+        /// <see cref="_socket"/> at any given time.
+        /// </summary>
+        private readonly object _socketReadLock = new object();
+
+        /// <summary>
+        /// Holds an object that is used to ensure only a single thread can write to
+        /// <see cref="_socket"/> at any given time.
+        /// </summary>
+        /// <remarks>
+        /// This is also used to ensure that <see cref="_outboundPacketSequence"/> is
+        /// incremented atomatically.
+        /// </remarks>
+        private readonly object _socketWriteLock = new object();
+
+        /// <summary>
+        /// Holds an object that is used to ensure only a single thread can dispose
+        /// <see cref="_socket"/> at any given time.
+        /// </summary>
+        /// <remarks>
+        /// This is also used to ensure that <see cref="_socket"/> will not be disposed
+        /// while performing a given operation or set of operations on <see cref="_socket"/>.
+        /// </remarks>
+        private readonly object _socketDisposeLock = new object();
+
         /// <summary>
         /// Gets the session semaphore that controls session channels.
         /// </summary>
@@ -638,8 +654,6 @@ namespace Renci.SshNet
                     RegisterMessage("SSH_MSG_CHANNEL_DATA");
                     RegisterMessage("SSH_MSG_CHANNEL_EOF");
                     RegisterMessage("SSH_MSG_CHANNEL_CLOSE");
-
-                    Monitor.Pulse(this);
                 }
             }
             finally
@@ -680,9 +694,12 @@ namespace Renci.SshNet
             // send disconnect message to the server if the connection is still open
             // and the disconnect message has not yet been sent
             //
-            // note that this should also cause the listener thread to be stopped as
+            // note that this should also cause the listener loop to be interrupted as
             // the server should respond by closing the socket
-            SendDisconnect(reason, message);
+            if (IsConnected)
+            {
+                SendDisconnect(reason, message);
+            }
 
             // disconnect socket, and dispose it
             SocketDisconnectAndDispose();
@@ -789,8 +806,9 @@ namespace Renci.SshNet
 
             var packetData = message.GetPacket(paddingMultiplier, _clientCompression);
 
-            //  Lock handling of _outboundPacketSequence since it must be sent sequently to server
-            lock (_socketLock)
+            // take a write lock to ensure the outbound packet sequence number is incremented
+            // atomically, and only after the packet has actually been sent
+            lock (_socketWriteLock)
             {
                 if (_socket == null || !_socket.Connected)
                     throw new SshConnectionException("Client not connected.");
@@ -832,9 +850,13 @@ namespace Renci.SshNet
                     SocketAbstraction.Send(_socket, data, 0, data.Length);
                 }
 
+                // increment the packet sequence number only after we're sure the packet has
+                // been sent; even though it's only used for the MAC, it need to be incremented
+                // for each package sent.
+                // 
+                // the server will use it to verify the data integrity, and as such the order in
+                // which messages are sent must follow the outbound packet sequence number
                 _outboundPacketSequence++;
-
-                Monitor.Pulse(_socketLock);
             }
         }
 
@@ -872,7 +894,9 @@ namespace Renci.SshNet
         /// <summary>
         /// Receives the message from the server.
         /// </summary>
-        /// <returns>Incoming SSH message.</returns>
+        /// <returns>
+        /// The incoming SSH message.
+        /// </returns>
         /// <exception cref="SshConnectionException"></exception>
         /// <remarks>
         /// We need no locking here since all messages are read by a single thread.
@@ -886,86 +910,89 @@ namespace Renci.SshNet
             // The length of the "padding length" field in bytes
             const int paddingLengthFieldLength = 1;
 
-            //  Determine the size of the first block, which is 8 or cipher block size (whichever is larger) bytes
-            var blockSize = _serverCipher == null ? (byte)8 : Math.Max((byte)8, _serverCipher.MinimumSize);
-
-            //  Read first block - which starts with the packet length
-            var firstBlock = Read(blockSize);
-
-#if DEBUG_GERT
-            DiagnosticAbstraction.Log(string.Format("[{0}] FirstBlock [{1}]: {2}", ToHex(SessionId), blockSize, ToHex(firstBlock)));
-#endif // DEBUG_GERT
-
-            if (_serverCipher != null)
-            {
-                firstBlock = _serverCipher.Decrypt(firstBlock);
-#if DEBUG_GERT
-                DiagnosticAbstraction.Log(string.Format("[{0}] FirstBlock decrypted [{1}]: {2}", ToHex(SessionId), firstBlock.Length, ToHex(firstBlock)));
-#endif // DEBUG_GERT
-            }
-
-            var packetLength = (uint)(firstBlock[0] << 24 | firstBlock[1] << 16 | firstBlock[2] << 8 | firstBlock[3]);
+            // Determine the size of the first block, which is 8 or cipher block size (whichever is larger) bytes
+            var blockSize = _serverCipher == null ? (byte) 8 : Math.Max((byte) 8, _serverCipher.MinimumSize);
 
-            //  Test packet minimum and maximum boundaries
-            if (packetLength < Math.Max((byte)16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
-                throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), DisconnectReason.ProtocolError);
+            var serverMacLength = _serverMac != null ? _serverMac.HashSize / 8 : 0;
 
-            //  Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
-            //  "packet length" field itself - which is 4 bytes - is not included in the length of the packet
-            var bytesToRead = (int) (packetLength - (blockSize - packetLengthFieldLength));
+            byte[] data;
+            uint packetLength;
 
-            //  Construct buffer for holding the payload and the inbound packet sequence as we need both in order
-            //  to generate the hash
-            var data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength];
-            _inboundPacketSequence.Write(data, 0);
-            Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, firstBlock.Length);
-
-            byte[] serverHash = null;
-            if (_serverMac != null)
-            {
-                serverHash = new byte[_serverMac.HashSize / 8];
-                bytesToRead += serverHash.Length;
-            }
-
-            if (bytesToRead > 0)
+            lock (_socketReadLock)
             {
-                var nextBlocks = Read(bytesToRead);
+                //  Read first block - which starts with the packet length
+                var firstBlock = SocketRead(blockSize);
 
 #if DEBUG_GERT
-                DiagnosticAbstraction.Log(string.Format("[{0}] NextBlocks [{1}]: {2}", ToHex(SessionId), bytesToRead, ToHex(nextBlocks)));
+                DiagnosticAbstraction.Log(string.Format("[{0}] FirstBlock [{1}]: {2}", ToHex(SessionId), blockSize, ToHex(firstBlock)));
 #endif // DEBUG_GERT
 
-                if (serverHash != null)
+                if (_serverCipher != null)
                 {
-                    Buffer.BlockCopy(nextBlocks, nextBlocks.Length - serverHash.Length, serverHash, 0, serverHash.Length);
-                    nextBlocks = nextBlocks.Take(nextBlocks.Length - serverHash.Length);
+                    firstBlock = _serverCipher.Decrypt(firstBlock);
 #if DEBUG_GERT
-                    DiagnosticAbstraction.Log(string.Format("[{0}] ServerHash [{1}]: {2}", ToHex(SessionId), serverHash.Length, ToHex(serverHash)));
+                    DiagnosticAbstraction.Log(string.Format("[{0}] FirstBlock decrypted [{1}]: {2}", ToHex(SessionId), firstBlock.Length, ToHex(firstBlock)));
 #endif // DEBUG_GERT
                 }
 
-                if (nextBlocks.Length > 0)
+                packetLength = (uint) (firstBlock[0] << 24 | firstBlock[1] << 16 | firstBlock[2] << 8 | firstBlock[3]);
+
+                // Test packet minimum and maximum boundaries
+                if (packetLength < Math.Max((byte)16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
+                    throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), DisconnectReason.ProtocolError);
+
+                // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
+                // "packet length" field itself - which is 4 bytes - is not included in the length of the packet
+                var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
+
+                // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
+                // to generate the hash.
+                // 
+                // The total length of the "data" buffer is an addition of:
+                // - inboundPacketSequenceLength (4 bytes)
+                // - packetLength
+                // - serverMacLength
+                // 
+                // We include the inbound packet sequence to allow us to have the the full SSH packet in a single
+                // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen
+                // to read the packet including server MAC in a single pass (except for the initial block).
+                data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength];
+                _inboundPacketSequence.Write(data, 0);
+                Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, firstBlock.Length);
+
+                if (bytesToRead > 0)
+                {
+                    SocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead);
+                }
+            }
+
+            if (_serverCipher != null)
+            {
+                var numberOfBytesToDecrypt = data.Length - (blockSize + inboundPacketSequenceLength + serverMacLength);
+                if (numberOfBytesToDecrypt > 0)
                 {
-                    if (_serverCipher != null)
-                    {
-                        nextBlocks = _serverCipher.Decrypt(nextBlocks);
 #if DEBUG_GERT
-                        DiagnosticAbstraction.Log(string.Format("[{0}] NextBlocks decrypted [{1}]: {2}", ToHex(SessionId), nextBlocks.Length, ToHex(nextBlocks)));
+                    DiagnosticAbstraction.Log(string.Format("[{0}] NextBlocks [{1}]: {2}", ToHex(SessionId), bytesToRead, ToHex(nextBlocks)));
 #endif // DEBUG_GERT
-                    }
 
-                    nextBlocks.CopyTo(data, blockSize + inboundPacketSequenceLength);
+                    var decryptedData = _serverCipher.Decrypt(data, blockSize + inboundPacketSequenceLength, numberOfBytesToDecrypt);
+                    Buffer.BlockCopy(decryptedData, 0, data, blockSize + inboundPacketSequenceLength, decryptedData.Length);
+
+#if DEBUG_GERT
+                    DiagnosticAbstraction.Log(string.Format("[{0}] NextBlocks decrypted [{1}]: {2}", ToHex(SessionId), decryptedData.Length, ToHex(decryptedData)));
+#endif // DEBUG_GERT
                 }
             }
 
             var paddingLength = data[inboundPacketSequenceLength + packetLengthFieldLength];
-            var messagePayloadLength = (int) (packetLength - paddingLength - paddingLengthFieldLength);
+            var messagePayloadLength = (int) packetLength - paddingLength - paddingLengthFieldLength;
             var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength;
 
-            //  Validate message against MAC
+            // validate message against MAC
             if (_serverMac != null)
             {
-                var clientHash = _serverMac.ComputeHash(data);
+                var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength);
+                var serverHash = data.Take(data.Length - serverMacLength, serverMacLength);
 
                 if (!serverHash.IsEqualTo(clientHash))
                 {
@@ -994,11 +1021,6 @@ namespace Renci.SshNet
 
         private void SendDisconnect(DisconnectReason reasonCode, string message)
         {
-            // only send a disconnect message if it wasn't already sent, and we're
-            // still connected
-            if (_isDisconnectMessageSent || !IsConnected)
-                return;
-
             var disconnectMessage = new DisconnectMessage(reasonCode, message);
 
             // send the disconnect message, but ignore the outcome
@@ -1635,23 +1657,7 @@ namespace Renci.SshNet
                 handlers(this, e);
         }
 
-        /// <summary>
-        /// Reads the specified length of bytes from the server.
-        /// </summary>
-        /// <param name="length">The length.</param>
-        /// <returns>
-        /// The bytes read from the server.
-        /// </returns>
-        private byte[] Read(int length)
-        {
-            var buffer = new byte[length];
-
-            SocketRead(length, buffer);
-
-            return buffer;
-        }
-
-#region Message loading functions
+        #region Message loading functions
 
         /// <summary>
         /// Registers SSH message with the session.
@@ -1756,25 +1762,40 @@ namespace Renci.SshNet
         /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
         /// </summary>
         /// <param name="length">The number of bytes to read.</param>
-        /// <param name="buffer">The buffer to read to.</param>
+        /// <returns>
+        /// The bytes read from the server.
+        /// </returns>
+        private byte[] SocketRead(int length)
+        {
+            var buffer = new byte[length];
+            SocketRead(buffer, 0, length);
+            return buffer;
+        }
+
+        /// <summary>
+        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
+        /// </summary>
+        /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
+        /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
+        /// <param name="length">The number of bytes to read.</param>
+        /// <returns>
+        /// The number of bytes read.
+        /// </returns>
         /// <exception cref="SshConnectionException">The socket is closed.</exception>
         /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
         /// <exception cref="SocketException">The read failed.</exception>
-        private void SocketRead(int length, byte[] buffer)
+        private int SocketRead(byte[] buffer, int offset, int length)
         {
-            if (SocketAbstraction.Read(_socket, buffer, 0, length, InfiniteTimeSpan) > 0)
+            var bytesRead = SocketAbstraction.Read(_socket, buffer, offset, length, InfiniteTimeSpan);
+            if (bytesRead == 0)
             {
-                // signal that bytes have been read from the socket
-                // this is used to improve accuracy of Session.IsSocketConnected
-                _bytesReadFromSocket.Set();
-                return;
+                // when we're in the disconnecting state (either triggered by client or server), then the
+                // SshConnectionException will interrupt the message listener loop (if not already interrupted)
+                // and the exception itself will be ignored (in RaiseError)
+                throw new SshConnectionException("An established connection was aborted by the server.",
+                    DisconnectReason.ConnectionLost);
             }
-
-            // when we're in the disconnecting state (either triggered by client or server), then the
-            // SshConnectionException will interrupt the message listener loop (if not already interrupted)
-            // and the exception itself will be ignored (in RaiseError)
-            throw new SshConnectionException("An established connection was aborted by the server.",
-                DisconnectReason.ConnectionLost);
+            return bytesRead;
         }
 
         /// <summary>
@@ -1820,26 +1841,33 @@ namespace Renci.SshNet
         }
 
         /// <summary>
-        /// Disconnects and disposes the socket.
+        /// Shuts down and disposes the socket.
         /// </summary>
         private void SocketDisconnectAndDispose()
         {
             if (_socket != null)
             {
-                lock (_socketLock)
+                lock (_socketDisposeLock)
                 {
                     if (_socket != null)
                     {
                         if (_socket.Connected)
                         {
+                            // interrupt any pending reads
                             _socket.Shutdown(SocketShutdown.Send);
-                            SocketAbstraction.ClearReadBuffer(_socket);
+
+                            // since we've shut down the socket, there should not be
+                            // any reads in progress but we still take a read lock
+                            // to ensure IsSocketConnected continues to provide
+                            // correct results
+                            lock (_socketReadLock)
+                            {
+                                SocketAbstraction.ClearReadBuffer(_socket);
+                            }
                         }
 
-                        DiagnosticAbstraction.Log(string.Format("[{0}] {1} Disposing socket", ToHex(SessionId), DateTime.Now.Ticks));
                         _socket.Dispose();
                         _socket = null;
-                        DiagnosticAbstraction.Log(string.Format("[{0}] {1} Disposed socket", ToHex(SessionId), DateTime.Now.Ticks));
                     }
                 }
             }
@@ -1852,8 +1880,22 @@ namespace Renci.SshNet
         {
             try
             {
-                while (_socket != null && _socket.Connected)
+                var readSockets = new List<Socket> {_socket};
+
+                while (_socket != null)
                 {
+                    Socket.Select(readSockets, null, null, -1);
+
+                    if (readSockets.Count == 0)
+                        break;
+
+                    // when the socket is disposed while a Select is executing, then the
+                    // Select will be interrupted; the socket will not be removed from
+                    // readSocket
+                    var socket = _socket;
+                    if (socket == null || !socket.Connected)
+                        break;
+
                     var message = ReceiveMessage();
                     HandleMessageCore(message);
                 }
@@ -1876,9 +1918,7 @@ namespace Renci.SshNet
         private byte SocketReadByte()
         {
             var buffer = new byte[1];
-
-            SocketRead(1, buffer);
-
+            SocketRead(buffer, 0, 1);
             return buffer[0];
         }
 
@@ -1931,13 +1971,8 @@ namespace Renci.SshNet
                     throw new ProxyException("SOCKS4: Not valid response.");
             }
 
-            var dummyBuffer = new byte[4];
-
-            //  Read 2 bytes to be ignored
-            SocketRead(2, dummyBuffer);
-
-            //  Read 4 bytes to be ignored
-            SocketRead(4, dummyBuffer);
+            var dummyBuffer = new byte[6]; // field 3 (2 bytes) and field 4 (4) should be ignored
+            SocketRead(dummyBuffer, 0, 6);
         }
 
         private void ConnectSocks5()
@@ -2080,10 +2115,10 @@ namespace Renci.SshNet
             switch (addressType)
             {
                 case 0x01:
-                    SocketRead(4, responseIp);
+                    SocketRead(responseIp, 0, 4);
                     break;
                 case 0x04:
-                    SocketRead(16, responseIp);
+                    SocketRead(responseIp, 0, 16);
                     break;
                 default:
                     throw new ProxyException(string.Format("Address type '{0}' is not supported.", addressType));
@@ -2092,7 +2127,7 @@ namespace Renci.SshNet
             var port = new byte[2];
 
             //  Read 2 bytes to be ignored
-            SocketRead(2, port);
+            SocketRead(port, 0, 2);
         }
 
         private void ConnectHttp()
@@ -2160,7 +2195,7 @@ namespace Renci.SshNet
                     if (contentLength > 0)
                     {
                         var contentBody = new byte[contentLength];
-                        SocketRead(contentLength, contentBody);
+                        SocketRead(contentBody, 0, contentLength);
                     }
                     break;
                 }
@@ -2301,13 +2336,6 @@ namespace Renci.SshNet
                     _keyExchange = null;
                 }
 
-                var bytesReadFromSocket = _bytesReadFromSocket;
-                if (bytesReadFromSocket != null)
-                {
-                    bytesReadFromSocket.Dispose();
-                    _bytesReadFromSocket = null;
-                }
-
                 var messageListenerCompleted = _messageListenerCompleted;
                 if (messageListenerCompleted != null)
                 {