Parcourir la source

Modify ReceiveMessage to return null when connection is closed.
Added TrySocketRead method that returns 0 (zeo) when connection is closed.
Remove SocketRead(int length) overload.

Added IsConnected extension method to Socket.
Modify MessageListener() to use this extension method as condition for the message loop.
Do not bother checking readSockets as the connected check of the socket allows us to combine both the connection closed and socket disposed conditions.

drieseng il y a 9 ans
Parent
commit
9e8ce4a868

+ 7 - 0
src/Renci.SshNet/Common/Extensions.cs

@@ -327,5 +327,12 @@ namespace Renci.SshNet
         {
             return SocketAbstraction.CanWrite(socket);
         }
+
+        internal static bool IsConnected(this Socket socket)
+        {
+            if (socket == null)
+                return false;
+            return socket.Connected;
+        }
     }
 }

+ 3 - 3
src/Renci.SshNet/Session.NET.cs

@@ -61,7 +61,7 @@ namespace Renci.SshNet
             lock (_socketDisposeLock)
             {
 #if FEATURE_SOCKET_POLL
-                if (_socket == null || !_socket.Connected)
+                if (!_socket.IsConnected())
                 {
                     isConnected = false;
                     return;
@@ -69,11 +69,11 @@ namespace Renci.SshNet
 
                 lock (_socketReadLock)
                 {
-                    var connectionClosedOrDataAvailable = _socket.Poll(1, SelectMode.SelectRead);
+                    var connectionClosedOrDataAvailable = _socket.Poll(0, SelectMode.SelectRead);
                     isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
                 }
 #else
-                isConnected = _socket != null && _socket.Connected;
+                isConnected = _socket.IsConnected();
 #endif // FEATURE_SOCKET_POLL
             }
 

+ 78 - 33
src/Renci.SshNet/Session.cs

@@ -829,7 +829,7 @@ namespace Renci.SshNet
             // atomically, and only after the packet has actually been sent
             lock (_socketWriteLock)
             {
-                if (_socket == null || !_socket.Connected)
+                if (!_socket.IsConnected())
                     throw new SshConnectionException("Client not connected.");
 
                 byte[] hash = null;
@@ -914,9 +914,8 @@ namespace Renci.SshNet
         /// Receives the message from the server.
         /// </summary>
         /// <returns>
-        /// The incoming SSH message.
+        /// The incoming SSH message, or <c>null</c> if the connection with the SSH server was closed.
         /// </returns>
-        /// <exception cref="SshConnectionException"></exception>
         /// <remarks>
         /// We need no locking here since all messages are read by a single thread.
         /// </remarks>
@@ -945,7 +944,12 @@ namespace Renci.SshNet
             {
 #endif // FEATURE_SOCKET_POLL
                 //  Read first block - which starts with the packet length
-                var firstBlock = SocketRead(blockSize);
+                var firstBlock = new byte[blockSize];
+                if (TrySocketRead(firstBlock, 0, blockSize) == 0)
+                {
+                    // connection with SSH server was closed
+                    return null;
+                }
 
 #if DEBUG_GERT
                 DiagnosticAbstraction.Log(string.Format("[{0}] FirstBlock [{1}]: {2}", ToHex(SessionId), blockSize, ToHex(firstBlock)));
@@ -988,7 +992,10 @@ namespace Renci.SshNet
 
                 if (bytesToRead > 0)
                 {
-                    SocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead);
+                    if (TrySocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0)
+                    {
+                        return null;
+                    }
                 }
 #if FEATURE_SOCKET_POLL
             }
@@ -1787,20 +1794,6 @@ namespace Renci.SshNet
             _socket.ReceiveBufferSize = socketBufferSize;
         }
 
-        /// <summary>
-        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
-        /// </summary>
-        /// <param name="length">The number of bytes to read.</param>
-        /// <returns>
-        /// The bytes read from the server.
-        /// </returns>
-        private byte[] SocketRead(int length)
-        {
-            var buffer = new byte[length];
-            SocketRead(buffer, 0, length);
-            return buffer;
-        }
-
         /// <summary>
         /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
         /// </summary>
@@ -1827,6 +1820,22 @@ namespace Renci.SshNet
             return bytesRead;
         }
 
+        /// <summary>
+        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
+        /// </summary>
+        /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
+        /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
+        /// <param name="length">The number of bytes to read.</param>
+        /// <returns>
+        /// The number of bytes read.
+        /// </returns>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        /// <exception cref="SocketException">The read failed.</exception>
+        private int TrySocketRead(byte[] buffer, int offset, int length)
+        {
+            return SocketAbstraction.Read(_socket, buffer, offset, length, InfiniteTimeSpan);
+        }
+
         /// <summary>
         /// Performs a blocking read on the socket until a line is read.
         /// </summary>
@@ -1920,26 +1929,56 @@ namespace Renci.SshNet
                 var readSockets = new List<Socket> {_socket};
 
                 // remain in message loop until socket is shut down or until we're disconnecting
-                while (true)
+                while (_socket.IsConnected())
                 {
 #if FEATURE_SOCKET_POLL
+                    // if the socket is already disposed when Select is invoked, then a SocketException
+                    // stating "An operation was attempted on something that is not a socket" is thrown;
+                    // we attempt to avoid this exception by having an IsConnected() that can break the
+                    // message loop
+                    //
+                    // note that there's no guarantee that the socket will not be disposed between the
+                    // IsConnected() check and the Select invocation; we can't take a "dispose" lock
+                    // that includes the Select invocation as we want Dispose() to be able to interrupt
+                    // the Select
+
+                    // perform a blocking select to determine whether there's is data available to be
+                    // read; we do not use a blocking read to allow us to use Socket.Poll to determine
+                    // if the connection is still available (in IsSocketConnected
                     Socket.Select(readSockets, null, null, -1);
 
-                    if (readSockets.Count == 0)
+                    // the Select invocation will be interrupted in one of the following conditions:
+                    // * data is available to be read
+                    //   => the socket will not be removed from "readSockets"
+                    // * the socket connection is closed during the Select invocation
+                    //   => the socket will be removed from "readSockets"
+                    // * the socket is disposed during the Select invocation
+                    //   => the socket will not be removed from "readSocket"
+                    // 
+                    // since we handle the second and third condition the same way and Socket.Connected
+                    // allows us to check for both conditions, we use that instead of both checking for
+                    // the removal from "readSockets" and the Connection check
+                    if (!_socket.IsConnected())
+                    {
+                        // connection with SSH server was closed or socket was disposed;
+                        // break out of the message loop
                         break;
-
-                    // when the socket is disposed while a Select is executing, then the
-                    // Select will be interrupted; the socket will not be removed from
-                    // readSocket, and therefore we need to explicitly check if the
-                    // socket is still connected
+                    }
 #endif // FEATURE_SOCKET_POLL
-                    var socket = _socket;
-                    if (socket == null || !socket.Connected)
-                        break;
 
                     var message = ReceiveMessage();
+                    if (message == null)
+                    {
+                        // connection with SSH server was closed;
+                        // break out of the message loop
+                        break;
+                    }
+
                     HandleMessageCore(message);
                 }
+
+                // connection with SSH server was closed
+                RaiseError(CreateConnectionAbortedByServerException());
             }
             catch (SocketException ex)
             {
@@ -2305,7 +2344,13 @@ namespace Renci.SshNet
             _keyExchangeInProgress = false;
         }
 
-#region IDisposable implementation
+        private static SshConnectionException CreateConnectionAbortedByServerException()
+        {
+            return new SshConnectionException("An established connection was aborted by the server.",
+                DisconnectReason.ConnectionLost);
+        }
+
+        #region IDisposable implementation
 
         private bool _disposed;
 
@@ -2396,9 +2441,9 @@ namespace Renci.SshNet
             Dispose(false);
         }
 
-#endregion IDisposable implementation
+        #endregion IDisposable implementation
 
-#region ISession implementation
+        #region ISession implementation
 
         /// <summary>
         /// Gets or sets the connection info.
@@ -2483,6 +2528,6 @@ namespace Renci.SshNet
             return TrySendMessage(message);
         }
 
-#endregion ISession implementation
+        #endregion ISession implementation
     }
 }