Kaynağa Gözat

Only use Socket.Poll, Socket.Select and read lock when FEATURE_SOCKET_POLL is defined.
Added test for SSH server shutdown while we're reading the packet.

drieseng 9 yıl önce
ebeveyn
işleme
395bbe70d6

+ 8 - 0
src/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

@@ -47,14 +47,22 @@ namespace Renci.SshNet.Tests.Classes
         [TestCleanup]
         public void TearDown()
         {
+            if (ServerSocket != null)
+            {
+                ServerSocket.Dispose();
+                ServerSocket = null;
+            }
+
             if (ServerListener != null)
             {
                 ServerListener.Dispose();
+                ServerListener = null;
             }
 
             if (Session != null)
             {
                 Session.Dispose();
+                Session = null;
             }
         }
 

+ 166 - 0
src/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerShutsDownSendAfterSendingIncompletePacket.cs

@@ -0,0 +1,166 @@
+using System;
+using System.Diagnostics;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerShutsDownSendAfterSendingIncompletePacket : SessionTest_ConnectedBase
+    {
+        protected override void Act()
+        {
+            var incompletePacket = new byte[] {0x0a, 0x05, 0x05};
+            ServerSocket.Send(incompletePacket, 0, incompletePacket.Length, SocketFlags.None);
+
+            // give session some time to start reading packet
+            Thread.Sleep(100);
+
+            ServerSocket.Shutdown(SocketShutdown.Send);
+
+            // give session some time to process shut down of server socket
+            Thread.Sleep(100);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsRaisedOnce()
+        {
+            Assert.AreEqual(1, ErrorOccurredRegister.Count);
+
+            var errorOccurred = ErrorOccurredRegister[0];
+            Assert.IsNotNull(errorOccurred);
+
+            var exception = errorOccurred.Exception;
+            Assert.IsNotNull(exception);
+            Assert.AreEqual(typeof(SshConnectionException), exception.GetType());
+
+            var connectionException = (SshConnectionException) exception;
+            Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason);
+            Assert.IsNull(connectionException.InnerException);
+            Assert.AreEqual("An established connection was aborted by the server.", connectionException.Message);
+        }
+
+        [TestMethod]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void ReceiveOnServerSocketShouldReturnZero()
+        {
+            var buffer = new byte[1];
+
+            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+
+            Assert.AreEqual(0, actual);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldSucceed()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession) Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethod]
+        public void ISession_SendMessageShouldSucceed()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnTrue()
+        {
+            var session = (ISession) Session;
+
+            Assert.IsFalse(session.TrySendMessage(new IgnoreMessage()));
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingBadPacket()
+        {
+            var session = (ISession) Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
+            }
+        }
+    }
+}

+ 1 - 2
src/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerShutsDownSocket.cs

@@ -1,5 +1,4 @@
-using System;
-using System.Diagnostics;
+using System.Diagnostics;
 using System.Net.Sockets;
 using System.Threading;
 using Microsoft.VisualStudio.TestTools.UnitTesting;

+ 1 - 0
src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -258,6 +258,7 @@
     <Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessage.cs" />
     <Compile Include="Classes\SessionTest_Connected_ServerSendsBadPacket.cs" />
     <Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ServerShutsDownSendAfterSendingIncompletePacket.cs" />
     <Compile Include="Classes\SessionTest_Connected_ServerShutsDownSocket.cs" />
     <Compile Include="Classes\SessionTest_NotConnected.cs" />
     <Compile Include="Classes\SessionTest_SocketConnected_BadPacketAndDispose.cs" />

+ 28 - 8
src/Renci.SshNet/Session.cs

@@ -162,11 +162,13 @@ namespace Renci.SshNet
         /// </summary>
         private Socket _socket;
 
+#if FEATURE_SOCKET_POLL
         /// <summary>
         /// Holds an object that is used to ensure only a single thread can read from
         /// <see cref="_socket"/> at any given time.
         /// </summary>
         private readonly object _socketReadLock = new object();
+#endif // FEATURE_SOCKET_POLL
 
         /// <summary>
         /// Holds an object that is used to ensure only a single thread can write to
@@ -918,8 +920,13 @@ namespace Renci.SshNet
             byte[] data;
             uint packetLength;
 
+#if FEATURE_SOCKET_POLL
+            // avoid reading from socket while IsSocketConnected is attempting to determine whether the
+            // socket is still connected by invoking Socket.Poll(...) and subsequently verifying value of
+            // Socket.Available
             lock (_socketReadLock)
             {
+#endif // FEATURE_SOCKET_POLL
                 //  Read first block - which starts with the packet length
                 var firstBlock = SocketRead(blockSize);
 
@@ -938,12 +945,14 @@ namespace Renci.SshNet
                 packetLength = (uint) (firstBlock[0] << 24 | firstBlock[1] << 16 | firstBlock[2] << 8 | firstBlock[3]);
 
                 // Test packet minimum and maximum boundaries
-                if (packetLength < Math.Max((byte)16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
-                    throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), DisconnectReason.ProtocolError);
+                if (packetLength < Math.Max((byte) 16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
+                    throw new SshConnectionException(
+                        string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength),
+                        DisconnectReason.ProtocolError);
 
                 // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
                 // "packet length" field itself - which is 4 bytes - is not included in the length of the packet
-                var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
+                var bytesToRead = (int) (packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
 
                 // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
                 // to generate the hash.
@@ -964,7 +973,9 @@ namespace Renci.SshNet
                 {
                     SocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead);
                 }
+#if FEATURE_SOCKET_POLL
             }
+#endif // FEATURE_SOCKET_POLL
 
             if (_serverCipher != null)
             {
@@ -1856,14 +1867,19 @@ namespace Renci.SshNet
                             // interrupt any pending reads
                             _socket.Shutdown(SocketShutdown.Send);
 
-                            // since we've shut down the socket, there should not be
-                            // any reads in progress but we still take a read lock
-                            // to ensure IsSocketConnected continues to provide
+#if FEATURE_SOCKET_POLL
+                            // since we've shut down the socket, there should not be any reads in progress but
+                            // we still take a read lock to ensure IsSocketConnected continues to provide
                             // correct results
+                            //
+                            // only necessary if IsSocketConnected actually uses Socket.Poll.
                             lock (_socketReadLock)
                             {
+#endif // FEATURE_SOCKET_POLL
                                 SocketAbstraction.ClearReadBuffer(_socket);
+#if FEATURE_SOCKET_POLL
                             }
+#endif // FEATURE_SOCKET_POLL
                         }
 
                         _socket.Dispose();
@@ -1882,8 +1898,10 @@ namespace Renci.SshNet
             {
                 var readSockets = new List<Socket> {_socket};
 
-                while (_socket != null)
+                // remain in message loop until socket is shut down
+                while (true)
                 {
+#if FEATURE_SOCKET_POLL
                     Socket.Select(readSockets, null, null, -1);
 
                     if (readSockets.Count == 0)
@@ -1891,7 +1909,9 @@ namespace Renci.SshNet
 
                     // when the socket is disposed while a Select is executing, then the
                     // Select will be interrupted; the socket will not be removed from
-                    // readSocket
+                    // readSocket, and therefore we need to explicitly check if the
+                    // socket is still connected
+#endif // FEATURE_SOCKET_POLL
                     var socket = _socket;
                     if (socket == null || !socket.Connected)
                         break;