浏览代码

Improved accuracy of IsConnected on .NET.

Gert Driesen 11 年之前
父节点
当前提交
2192dc0241

+ 4 - 1
Renci.SshClient/Build/nuget/SSH.NET.nuspec

@@ -13,11 +13,14 @@
         <releaseNotes>2014.4.6-beta2
 ==============
 
+New Features:
+
+    * Improved accuracy of IsConnected on .NET
+
 Fixes:
 
     * Client channels are no longer closed on dispose (issue #1943)
 
-
 2014.4.6-beta1
 ==============
 

+ 252 - 2
Renci.SshClient/Renci.SshNet/Session.NET.cs

@@ -19,19 +19,70 @@ namespace Renci.SshNet
             new TraceSource("SshNet.Logging");
 #endif
 
+        /// <summary>
+        /// Holds the lock object to ensure read access to the socket is synchronized.
+        /// </summary>
+        private readonly object _socketReadLock = new object();
+
         /// <summary>
         /// Gets a value indicating whether the socket is connected.
         /// </summary>
         /// <value>
         /// <c>true</c> if the socket is connected; otherwise, <c>false</c>.
         /// </value>
+        /// <remarks>
+        /// <para>
+        /// As a first check we verify whether <see cref="Socket.Connected"/> is
+        /// <c>true</c>. However, this only returns the state of the socket as of
+        /// the last I/O operation. Therefore we use the combination of Socket.Poll
+        /// with mode SelectRead and Socket.Available to verify if the socket is
+        /// still connected.
+        /// </para>
+        /// <para>
+        /// The MSDN doc mention the following on the return value of <see cref="Socket.Poll(int, SelectMode)"/>
+        /// with mode <see cref="SelectMode.SelectRead"/>:
+        /// <list type="bullet">
+        ///     <item>
+        ///         <description><c>true</c> if data is available for reading;</description>
+        ///     </item>
+        ///     <item>
+        ///         <description><c>true</c> if the connection has been closed, reset, or terminated; otherwise, returns <c>false</c>.</description>
+        ///     </item>
+        /// </list>
+        /// </para>
+        /// <para>
+        /// <c>Conclusion:</c> when the return value is <c>true</c> - but no data is available for reading - then
+        ///  the socket is no longer connected.
+        /// </para>
+        /// <para>
+        /// When a <see cref="Socket"/> is used from multiple threads, there's a race condition
+        /// between the invocation of <see cref="Socket.Poll(int, SelectMode)"/> and the moment
+        /// when the value of <see cref="Socket.Available"/> is obtained. As a workaround, we signal
+        /// when bytes are read from the <see cref="Socket"/>.
+        /// </para>
+        /// </remarks>
         partial void IsSocketConnected(ref bool isConnected)
         {
             isConnected = (_socket != null && _socket.Connected);
             if (isConnected)
             {
-                var connectionClosedOrDataAvailable = _socket.Poll(1000, SelectMode.SelectRead);
-                isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
+                // synchronize this to ensure thread B does not reset the wait handle before
+                // thread A was able to check whether "bytes read from socket" signal was
+                // actually received
+                lock (_socketReadLock)
+                {
+                    _bytesReadFromSocket.Reset();
+                    var connectionClosedOrDataAvailable = _socket.Poll(1000, SelectMode.SelectRead);
+                    isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
+                    if (!isConnected)
+                    {
+                        // the race condition is between the Socket.Poll call and
+                        // Socket.Available, but the event handler - where we signal that
+                        // bytes have been received from the socket - is sometimes invoked
+                        // shortly after
+                        isConnected = _bytesReadFromSocket.WaitOne(500);
+                    }
+                }
             }
         }
 
@@ -120,6 +171,9 @@ namespace Renci.SshNet
                     var receivedBytes = this._socket.Receive(buffer, receivedTotal, length - receivedTotal, SocketFlags.None);
                     if (receivedBytes > 0)
                     {
+                        // signal that bytes have been read from the socket
+                        // this is used to improve accuracy of Session.IsSocketConnected
+                        _bytesReadFromSocket.Set();
                         receivedTotal += receivedBytes;
                         continue;
                     }
@@ -192,5 +246,201 @@ namespace Renci.SshNet
         {
             this._log.TraceEvent(TraceEventType.Verbose, 1, text);
         }
+
+#if ASYNC_SOCKET_READ
+        private void SocketRead(int length, ref byte[] buffer)
+        {
+            var state = new SocketReadState(_socket, length, ref buffer);
+
+            _socket.BeginReceive(buffer, 0, length, SocketFlags.None, SocketReceiveCallback, state);
+
+            var readResult = state.Wait();
+            switch (readResult)
+            {
+                case SocketReadResult.Complete:
+                    break;
+                case SocketReadResult.ConnectionLost:
+                    if (_isDisconnecting)
+                        throw new SshConnectionException(
+                            "An established connection was aborted by the software in your host machine.",
+                            DisconnectReason.ConnectionLost);
+                    throw new SshConnectionException("An established connection was aborted by the server.",
+                        DisconnectReason.ConnectionLost);
+                case SocketReadResult.Failed:
+                    var socketException = state.Exception as SocketException;
+                    if (socketException != null)
+                    {
+                        if (socketException.SocketErrorCode == SocketError.ConnectionAborted)
+                        {
+                            buffer = new byte[length];
+                            this.Disconnect();
+                            return;
+                        }
+                    }
+                    throw state.Exception;
+            }
+        }
+
+        private void SocketReceiveCallback(IAsyncResult ar)
+        {
+            var state = ar.AsyncState as SocketReadState;
+            var socket = state.Socket;
+
+            try
+            {
+                var bytesReceived = socket.EndReceive(ar);
+                if (bytesReceived > 0)
+                {
+                    _bytesReadFromSocket.Set();
+                    state.BytesRead += bytesReceived;
+                    if (state.BytesRead < state.TotalBytesToRead)
+                    {
+                        socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
+                            SocketFlags.None, SocketReceiveCallback, state);
+                    }
+                    else
+                    {
+                        // we received all bytes that we wanted, so lets mark the read
+                        // complete
+                        state.Complete();
+                    }
+                }
+                else
+                {
+                    // the remote host shut down the connection; this could also have been
+                    // triggered by a SSH_MSG_DISCONNECT sent by the client
+                    state.ConnectionLost();
+                }
+            }
+            catch (SocketException ex)
+            {
+                if (ex.SocketErrorCode != SocketError.ConnectionAborted)
+                {
+                    if (ex.SocketErrorCode == SocketError.WouldBlock ||
+                        ex.SocketErrorCode == SocketError.IOPending ||
+                        ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
+                    {
+                        // socket buffer is probably empty, wait and try again
+                        Thread.Sleep(30);
+
+                        socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
+                            SocketFlags.None, SocketReceiveCallback, state);
+                        return;
+                    }
+                }
+
+                state.Fail(ex);
+            }
+            catch (Exception ex)
+            {
+                state.Fail(ex);
+            }
+        }
+
+        private class SocketReadState
+        {
+            private SocketReadResult _result;
+
+            /// <summary>
+            /// WaitHandle to signal that read from socket has completed (either successfully
+            /// or with failure)
+            /// </summary>
+            private EventWaitHandle _socketReadComplete;
+
+            public SocketReadState(Socket socket, int totalBytesToRead, ref byte[] buffer)
+            {
+                Socket = socket;
+                TotalBytesToRead = totalBytesToRead;
+                Buffer = buffer;
+                _socketReadComplete = new ManualResetEvent(false);
+            }
+
+            /// <summary>
+            /// Gets the <see cref="Socket"/> to read from.
+            /// </summary>
+            /// <value>
+            /// The <see cref="Socket"/> to read from.
+            /// </value>
+            public Socket Socket { get; private set; }
+
+            /// <summary>
+            /// Gets or sets the number of bytes that have been read from the <see cref="Socket"/>.
+            /// </summary>
+            /// <value>
+            /// The number of bytes that have been read from the <see cref="Socket"/>.
+            /// </value>
+            public int BytesRead { get; set; }
+
+            /// <summary>
+            /// Gets the total number of bytes to read from the <see cref="Socket"/>.
+            /// </summary>
+            /// <value>
+            /// The total number of bytes to read from the <see cref="Socket"/>.
+            /// </value>
+            public int TotalBytesToRead { get; private set; }
+
+            /// <summary>
+            /// Gets or sets the buffer to hold the bytes that have been read.
+            /// </summary>
+            /// <value>
+            /// The buffer to hold the bytes that have been read.
+            /// </value>
+            public byte[] Buffer { get; private set; }
+
+            /// <summary>
+            /// Gets or sets the exception that was thrown while reading from the
+            /// <see cref="Socket"/>.
+            /// </summary>
+            /// <value>
+            /// The exception that was thrown while reading from the <see cref="Socket"/>,
+            /// or <c>null</c> if no exception was thrown.
+            /// </value>
+            public Exception Exception { get; private set; }
+
+            /// <summary>
+            /// Signals that the total number of bytes has been read successfully.
+            /// </summary>
+            public void Complete()
+            {
+                _result = SocketReadResult.Complete;
+                _socketReadComplete.Set();
+            }
+
+            /// <summary>
+            /// Signals that the socket read failed.
+            /// </summary>
+            /// <param name="cause">The <see cref="Exception"/> that caused the read to fail.</param>
+            public void Fail(Exception cause)
+            {
+                Exception = cause;
+                _result = SocketReadResult.Failed;
+                _socketReadComplete.Set();
+            }
+
+            /// <summary>
+            /// Signals that the connection to the server was lost.
+            /// </summary>
+            public void ConnectionLost()
+            {
+                _result = SocketReadResult.ConnectionLost;
+                _socketReadComplete.Set();
+            }
+
+            public SocketReadResult Wait()
+            {
+                _socketReadComplete.WaitOne();
+                _socketReadComplete.Dispose();
+                _socketReadComplete = null;
+                return _result;
+            }
+        }
+
+        private enum SocketReadResult
+        {
+            Complete,
+            ConnectionLost,
+            Failed
+        }
+#endif
     }
 }

+ 11 - 0
Renci.SshClient/Renci.SshNet/Session.cs

@@ -107,6 +107,11 @@ namespace Renci.SshNet
         /// </summary>
         private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(false);
 
+        /// <summary>
+        /// WaitHandle to signal that bytes have been read from the socket.
+        /// </summary>
+        private EventWaitHandle _bytesReadFromSocket = new ManualResetEvent(false);
+
         /// <summary>
         /// WaitHandle to signal that key exchange is in progress.
         /// </summary>
@@ -2084,6 +2089,12 @@ namespace Renci.SshNet
                         this._keyExchange.Dispose();
                         this._keyExchange = null;
                     }
+
+                    if (_bytesReadFromSocket != null)
+                    {
+                        _bytesReadFromSocket.Dispose();
+                        _bytesReadFromSocket = null;
+                    }
                 }
 
                 // Note disposing has been done.