Bläddra i källkod

Remove calls to Socket.Poll and use SocketShutdown.Both (#1706)

The message loop currently sits in a call to Poll until the socket has data to read or
it is closed. This is unnecessary - it can equally just sit in the call to Receive.

The call to Poll in Session.IsConnected is also unnecessary - we can instead just call
Socket.Connected. This only returns the connection state as of the last operation, but
we are always performing operations in the message loop (or else we are not connected),
so it should work equally well while being cheaper.

Lastly, when shutting down the socket, shut down both sides rather than just the sending
side (SocketShutdown.Both rather than SocketShutdown.Send) - at this point we do not care
about reading anything else. This makes it (more) certain that we will break out of the
Receive call in the message loop, as has been noted in #355 for whatever remaining issues
still exist there.
Rob Hague 3 veckor sedan
förälder
incheckning
c335ce2f20

+ 0 - 28
src/Renci.SshNet/Abstractions/SocketAbstraction.cs

@@ -12,34 +12,6 @@ namespace Renci.SshNet.Abstractions
 {
     internal static partial class SocketAbstraction
     {
-        public static bool CanRead(Socket socket)
-        {
-            if (socket.Connected)
-            {
-                return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
-            }
-
-            return false;
-        }
-
-        /// <summary>
-        /// Returns a value indicating whether the specified <see cref="Socket"/> can be used
-        /// to send data.
-        /// </summary>
-        /// <param name="socket">The <see cref="Socket"/> to check.</param>
-        /// <returns>
-        /// <see langword="true"/> if <paramref name="socket"/> can be written to; otherwise, <see langword="false"/>.
-        /// </returns>
-        public static bool CanWrite(Socket socket)
-        {
-            if (socket != null && socket.Connected)
-            {
-                return socket.Poll(-1, SelectMode.SelectWrite);
-            }
-
-            return false;
-        }
-
         public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
         {
             var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };

+ 0 - 11
src/Renci.SshNet/Common/Extensions.cs

@@ -10,7 +10,6 @@ using System.Numerics;
 using System.Runtime.CompilerServices;
 using System.Threading;
 
-using Renci.SshNet.Abstractions;
 using Renci.SshNet.Messages;
 
 namespace Renci.SshNet.Common
@@ -319,16 +318,6 @@ namespace Renci.SshNet.Common
             return concat;
         }
 
-        internal static bool CanRead(this Socket socket)
-        {
-            return SocketAbstraction.CanRead(socket);
-        }
-
-        internal static bool CanWrite(this Socket socket)
-        {
-            return SocketAbstraction.CanWrite(socket);
-        }
-
         internal static bool IsConnected(this Socket socket)
         {
             if (socket is null)

+ 92 - 235
src/Renci.SshNet/Session.cs

@@ -81,12 +81,6 @@ namespace Renci.SshNet
         private readonly ISocketFactory _socketFactory;
         private readonly ILogger _logger;
 
-        /// <summary>
-        /// Holds an object that is used to ensure only a single thread can read from
-        /// <see cref="_socket"/> at any given time.
-        /// </summary>
-        private readonly Lock _socketReadLock = new Lock();
-
         /// <summary>
         /// Holds an object that is used to ensure only a single thread can write to
         /// <see cref="_socket"/> at any given time.
@@ -105,7 +99,7 @@ namespace Renci.SshNet
         /// This is also used to ensure that <see cref="_socket"/> will not be disposed
         /// while performing a given operation or set of operations on <see cref="_socket"/>.
         /// </remarks>
-        private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1);
+        private readonly Lock _socketDisposeLock = new Lock();
 
         /// <summary>
         /// Holds an object that is used to ensure only a single thread can connect
@@ -279,17 +273,11 @@ namespace Renci.SshNet
         {
             get
             {
-                if (_disposed || _isDisconnectMessageSent || !_isAuthenticated)
-                {
-                    return false;
-                }
-
-                if (_messageListenerCompleted is null || _messageListenerCompleted.WaitOne(0))
-                {
-                    return false;
-                }
-
-                return IsSocketConnected();
+                return !_disposed &&
+                    !_isDisconnectMessageSent &&
+                    _isAuthenticated &&
+                    _messageListenerCompleted?.WaitOne(0) == false &&
+                    _socket.IsConnected();
             }
         }
 
@@ -1046,7 +1034,7 @@ namespace Renci.SshNet
         /// <exception cref="InvalidOperationException">The size of the packet exceeds the maximum size defined by the protocol.</exception>
         internal void SendMessage(Message message)
         {
-            if (!_socket.CanWrite())
+            if (!_socket.IsConnected())
             {
                 throw new SshConnectionException("Client not connected.");
             }
@@ -1161,9 +1149,7 @@ namespace Renci.SshNet
         /// </remarks>
         private void SendPacket(byte[] packet, int offset, int length)
         {
-            _socketDisposeLock.Wait();
-
-            try
+            lock (_socketDisposeLock)
             {
                 if (!_socket.IsConnected())
                 {
@@ -1172,10 +1158,6 @@ namespace Renci.SshNet
 
                 SocketAbstraction.Send(_socket, packet, offset, length);
             }
-            finally
-            {
-                _ = _socketDisposeLock.Release();
-            }
         }
 
         /// <summary>
@@ -1259,76 +1241,70 @@ namespace Renci.SshNet
             byte[] data;
             uint packetLength;
 
-            // avoid reading from socket while IsSocketConnected is attempting to determine whether the
-            // socket is still connected by invoking Socket.Poll(...) and subsequently verifying value of
-            // Socket.Available
-            lock (_socketReadLock)
+            // Read first block - which starts with the packet length
+            var firstBlock = new byte[blockSize];
+            if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0)
             {
-                // Read first block - which starts with the packet length
-                var firstBlock = new byte[blockSize];
-                if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0)
-                {
-                    // connection with SSH server was closed
-                    return null;
-                }
+                // connection with SSH server was closed
+                return null;
+            }
 
-                var plainFirstBlock = firstBlock;
+            var plainFirstBlock = firstBlock;
 
-                // First block is not encrypted in AES GCM mode.
-                if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
-                {
-                    _serverCipher.SetSequenceNumber(_inboundPacketSequence);
+            // First block is not encrypted in AES GCM mode.
+            if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
+            {
+                _serverCipher.SetSequenceNumber(_inboundPacketSequence);
 
-                    // First block is not encrypted in ETM mode.
-                    if (_serverMac == null || !_serverEtm)
-                    {
-                        plainFirstBlock = _serverCipher.Decrypt(firstBlock);
-                    }
+                // First block is not encrypted in ETM mode.
+                if (_serverMac == null || !_serverEtm)
+                {
+                    plainFirstBlock = _serverCipher.Decrypt(firstBlock);
                 }
+            }
 
-                packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock);
+            packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock);
 
-                // Test packet minimum and maximum boundaries
-                if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
-                {
-                    throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength),
-                                                     DisconnectReason.ProtocolError);
-                }
+            // Test packet minimum and maximum boundaries
+            if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
+            {
+                throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength),
+                                                 DisconnectReason.ProtocolError);
+            }
 
-                // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
-                // "packet length" field itself - which is 4 bytes - is not included in the length of the packet
-                var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
-
-                // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
-                // to generate the hash.
-                //
-                // The total length of the "data" buffer is an addition of:
-                // - inboundPacketSequenceLength (4 bytes)
-                // - packetLength
-                // - serverMacLength
-                //
-                // We include the inbound packet sequence to allow us to have the the full SSH packet in a single
-                // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen
-                // to read the packet including server MAC in a single pass (except for the initial block).
-                data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength];
-                BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence);
-
-                // Use raw packet length field to calculate the mac in AEAD mode.
-                if (_serverAead)
-                {
-                    Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize);
-                }
-                else
-                {
-                    Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize);
-                }
+            // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
+            // "packet length" field itself - which is 4 bytes - is not included in the length of the packet
+            var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
+
+            // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
+            // to generate the hash.
+            //
+            // The total length of the "data" buffer is an addition of:
+            // - inboundPacketSequenceLength (4 bytes)
+            // - packetLength
+            // - serverMacLength
+            //
+            // We include the inbound packet sequence to allow us to have the the full SSH packet in a single
+            // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen
+            // to read the packet including server MAC in a single pass (except for the initial block).
+            data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength];
+            BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence);
+
+            // Use raw packet length field to calculate the mac in AEAD mode.
+            if (_serverAead)
+            {
+                Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize);
+            }
+            else
+            {
+                Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize);
+            }
 
-                if (bytesToRead > 0)
+            if (bytesToRead > 0)
+            {
+                if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0)
                 {
-                    if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0)
-                    {
-                        return null;
-                    }
+                    return null;
                 }
             }
 
@@ -1888,84 +1864,6 @@ namespace Renci.SshNet
 #endif
         }
 
-        /// <summary>
-        /// Gets a value indicating whether the socket is connected.
-        /// </summary>
-        /// <returns>
-        /// <see langword="true"/> if the socket is connected; otherwise, <see langword="false"/>.
-        /// </returns>
-        /// <remarks>
-        /// <para>
-        /// As a first check we verify whether <see cref="Socket.Connected"/> is
-        /// <see langword="true"/>. However, this only returns the state of the socket as of
-        /// the last I/O operation.
-        /// </para>
-        /// <para>
-        /// Therefore we use the combination of <see cref="Socket.Poll(int, SelectMode)"/> with mode <see cref="SelectMode.SelectRead"/>
-        /// and <see cref="Socket.Available"/> to verify if the socket is still connected.
-        /// </para>
-        /// <para>
-        /// The MSDN doc mention the following on the return value of <see cref="Socket.Poll(int, SelectMode)"/>
-        /// with mode <see cref="SelectMode.SelectRead"/>:
-        /// <list type="bullet">
-        ///     <item>
-        ///         <description><see langword="true"/> if data is available for reading;</description>
-        ///     </item>
-        ///     <item>
-        ///         <description><see langword="true"/> if the connection has been closed, reset, or terminated; otherwise, returns <see langword="false"/>.</description>
-        ///     </item>
-        /// </list>
-        /// </para>
-        /// <para>
-        /// <c>Conclusion:</c> when the return value is <see langword="true"/> - but no data is available for reading - then
-        /// the socket is no longer connected.
-        /// </para>
-        /// <para>
-        /// When a <see cref="Socket"/> is used from multiple threads, there's a race condition
-        /// between the invocation of <see cref="Socket.Poll(int, SelectMode)"/> and the moment
-        /// when the value of <see cref="Socket.Available"/> is obtained. To workaround this issue
-        /// we synchronize reads from the <see cref="Socket"/>.
-        /// </para>
-        /// <para>
-        /// We assume the socket is still connected if the read lock cannot be acquired immediately.
-        /// In this case, we just return <see langword="true"/> without actually waiting to acquire
-        /// the lock. We don't want to wait for the read lock if another thread already has it because
-        /// there are cases where the other thread holding the lock can be waiting indefinitely for
-        /// a socket read operation to complete.
-        /// </para>
-        /// </remarks>
-        private bool IsSocketConnected()
-        {
-            _socketDisposeLock.Wait();
-
-            try
-            {
-                if (!_socket.IsConnected())
-                {
-                    return false;
-                }
-
-                if (!_socketReadLock.TryEnter())
-                {
-                    return true;
-                }
-
-                try
-                {
-                    var connectionClosedOrDataAvailable = _socket.Poll(0, SelectMode.SelectRead);
-                    return !(connectionClosedOrDataAvailable && _socket.Available == 0);
-                }
-                finally
-                {
-                    _socketReadLock.Exit();
-                }
-            }
-            finally
-            {
-                _ = _socketDisposeLock.Release();
-            }
-        }
-
         /// <summary>
         /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
         /// </summary>
@@ -1988,46 +1886,37 @@ namespace Renci.SshNet
         /// </summary>
         private void SocketDisconnectAndDispose()
         {
-            if (_socket != null)
+            lock (_socketDisposeLock)
             {
-                _socketDisposeLock.Wait();
+                if (_socket is null)
+                {
+                    return;
+                }
 
-                try
+                if (_socket.Connected)
                 {
-#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread.
-                    if (_socket != null)
-#pragma warning restore CA1508 // Avoid dead conditional code
+                    try
                     {
-                        if (_socket.Connected)
-                        {
-                            try
-                            {
-                                _logger.LogDebug("[{SessionId}] Shutting down socket.", SessionIdHex);
-
-                                // Interrupt any pending reads; should be done outside of socket read lock as we
-                                // actually want shutdown the socket to make sure blocking reads are interrupted.
-                                //
-                                // This may result in a SocketException (eg. An existing connection was forcibly
-                                // closed by the remote host) which we'll log and ignore as it means the socket
-                                // was already shut down.
-                                _socket.Shutdown(SocketShutdown.Send);
-                            }
-                            catch (SocketException ex)
-                            {
-                                _logger.LogInformation(ex, "Failure shutting down socket");
-                            }
-                        }
-
-                        _logger.LogDebug("[{SessionId}] Disposing socket.", SessionIdHex);
-                        _socket.Dispose();
-                        _logger.LogDebug("[{SessionId}] Disposed socket.", SessionIdHex);
-                        _socket = null;
+                        _logger.LogDebug("[{SessionId}] Shutting down socket.", SessionIdHex);
+
+                        // Interrupt any pending reads; should be done outside of socket read lock as we
+                        // actually want shutdown the socket to make sure blocking reads are interrupted.
+                        //
+                        // This may result in a SocketException (eg. An existing connection was forcibly
+                        // closed by the remote host) which we'll log and ignore as it means the socket
+                        // was already shut down.
+                        _socket.Shutdown(SocketShutdown.Both);
+                    }
+                    catch (SocketException ex)
+                    {
+                        _logger.LogInformation(ex, "Failure shutting down socket");
                     }
                 }
-                finally
-                {
-                    _ = _socketDisposeLock.Release();
-                }
+
+                _logger.LogDebug("[{SessionId}] Disposing socket.", SessionIdHex);
+                _socket.Dispose();
+                _logger.LogDebug("[{SessionId}] Disposed socket.", SessionIdHex);
+                _socket = null;
             }
         }
 
@@ -2048,25 +1937,6 @@ namespace Renci.SshNet
                         break;
                     }
 
-                    try
-                    {
-                        // Block until either data is available or the socket is closed
-                        var connectionClosedOrDataAvailable = socket.Poll(-1, SelectMode.SelectRead);
-                        if (connectionClosedOrDataAvailable && socket.Available == 0)
-                        {
-                            // connection with SSH server was closed or connection was reset
-                            break;
-                        }
-                    }
-                    catch (ObjectDisposedException)
-                    {
-                        // The socket was disposed by either:
-                        // * a call to Disconnect()
-                        // * a call to Dispose()
-                        // * a SSH_MSG_DISCONNECT received from server
-                        break;
-                    }
-
                     var message = ReceiveMessage(socket);
                     if (message is null)
                     {
@@ -2102,25 +1972,12 @@ namespace Renci.SshNet
         /// <param name="exp">The <see cref="Exception"/>.</param>
         private void RaiseError(Exception exp)
         {
-            var connectionException = exp as SshConnectionException;
-
             _logger.LogInformation(exp, "[{SessionId}] Raised exception", SessionIdHex);
 
-            if (_isDisconnecting)
+            if (_isDisconnecting && exp is SshConnectionException or ObjectDisposedException)
             {
-                // a connection exception which is raised while isDisconnecting is normal and
-                // should be ignored
-                if (connectionException != null)
-                {
-                    return;
-                }
-
-                // any timeout while disconnecting can be caused by loss of connectivity
-                // altogether and should be ignored
-                if (exp is SocketException socketException && socketException.SocketErrorCode == SocketError.TimedOut)
-                {
-                    return;
-                }
+                // Such an exception raised while isDisconnecting is expected and can be ignored.
+                return;
             }
 
             // "save" exception and set exception wait handle to ensure any waits are interrupted
@@ -2129,10 +1986,10 @@ namespace Renci.SshNet
 
             ErrorOccured?.Invoke(this, new ExceptionEventArgs(exp));
 
-            if (connectionException != null)
+            if (exp is SshConnectionException connectionException)
             {
                 _logger.LogInformation(exp, "[{SessionId}] Disconnecting after exception", SessionIdHex);
-                Disconnect(connectionException.DisconnectReason, exp.ToString());
+                Disconnect(connectionException.DisconnectReason, exp.Message);
             }
         }
 

+ 0 - 6
test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs

@@ -8,12 +8,6 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public class AbstractionsTest
     {
-        [TestMethod]
-        public void SocketAbstraction_CanWrite_ShouldReturnFalseWhenSocketIsNull()
-        {
-            Assert.IsFalse(SocketAbstraction.CanWrite(null));
-        }
-
         [TestMethod]
         public void CryptoAbstraction_GenerateRandom_ShouldPerformNoOpWhenDataIsZeroLength()
         {

+ 10 - 28
test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs

@@ -64,8 +64,6 @@ namespace Renci.SshNet.Tests.Classes
 
             var connectionException = (SshConnectionException)exception;
             Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason);
-            Assert.IsNull(connectionException.InnerException);
-            Assert.AreEqual("An established connection was aborted by the server.", connectionException.Message);
         }
 
         [TestMethod]
@@ -137,45 +135,29 @@ namespace Renci.SshNet.Tests.Classes
         public void ISession_WaitOnHandle_WaitHandle_ShouldThrowSshConnectionException()
         {
             var session = (ISession)Session;
-            var waitHandle = new ManualResetEvent(false);
+            using var waitHandle = new ManualResetEvent(false);
 
-            try
-            {
-                session.WaitOnHandle(waitHandle);
-                Assert.Fail();
-            }
-            catch (SshConnectionException ex)
-            {
-                Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
-                Assert.IsNull(ex.InnerException);
-                Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
-            }
+            var ex = Assert.ThrowsExactly<SshConnectionException>(() => session.WaitOnHandle(waitHandle));
+
+            Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
         }
 
         [TestMethod]
         public void ISession_WaitOnHandle_WaitHandleAndTimeout_ShouldThrowSshConnectionException()
         {
             var session = (ISession)Session;
-            var waitHandle = new ManualResetEvent(false);
+            using var waitHandle = new ManualResetEvent(false);
 
-            try
-            {
-                session.WaitOnHandle(waitHandle, Timeout.InfiniteTimeSpan);
-                Assert.Fail();
-            }
-            catch (SshConnectionException ex)
-            {
-                Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
-                Assert.IsNull(ex.InnerException);
-                Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
-            }
+            var ex = Assert.ThrowsExactly<SshConnectionException>(() => session.WaitOnHandle(waitHandle));
+
+            Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
         }
 
         [TestMethod]
         public void ISession_TryWait_WaitHandleAndTimeout_ShouldReturnDisconnected()
         {
             var session = (ISession)Session;
-            var waitHandle = new ManualResetEvent(false);
+            using var waitHandle = new ManualResetEvent(false);
 
             var result = session.TryWait(waitHandle, Timeout.InfiniteTimeSpan);
 
@@ -186,7 +168,7 @@ namespace Renci.SshNet.Tests.Classes
         public void ISession_TryWait_WaitHandleAndTimeoutAndException_ShouldReturnDisconnected()
         {
             var session = (ISession)Session;
-            var waitHandle = new ManualResetEvent(false);
+            using var waitHandle = new ManualResetEvent(false);
 
             var result = session.TryWait(waitHandle, Timeout.InfiniteTimeSpan, out var exception);
 

+ 11 - 4
test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs

@@ -91,13 +91,20 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void ReceiveOnServerSocketShouldReturnZero()
+        public void ServerShouldBeDisconnected()
         {
-            var buffer = new byte[1];
+            try
+            {
+                var buffer = new byte[1];
 
-            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+                var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
 
-            Assert.AreEqual(0, actual);
+                Assert.AreEqual(0, actual); // FIN
+            }
+            catch (SocketException sx)
+            {
+                Assert.AreEqual(SocketError.ConnectionReset, sx.SocketErrorCode); // RST
+            }
         }
 
         [TestMethod]