فهرست منبع

Fix tests after recent changes.

drieseng 9 سال پیش
والد
کامیت
cfdaa75902

+ 0 - 11
src/Renci.SshNet.Silverlight/Session.SilverlightShared.cs

@@ -13,17 +13,6 @@ namespace Renci.SshNet
             isConnected = (_socket != null && _socket.Connected);
         }
 
-        /// <summary>
-        /// Closes the socket.
-        /// </summary>
-        /// <remarks>
-        /// This method will wait up to <c>10</c> seconds to send any remaining data.
-        /// </remarks>
-        partial void SocketDisconnect()
-        {
-            _socket.Close(10);
-        }
-
         partial void InternalRegisterMessage(string messageName)
         {
             lock (_messagesMetadata)

+ 4 - 0
src/Renci.SshNet.Tests/Classes/Channels/ChannelDirectTcpipTest.cs

@@ -18,6 +18,7 @@ namespace Renci.SshNet.Tests.Classes.Channels
     {
         private Mock<ISession> _sessionMock;
         private Mock<IForwardedPort> _forwardedPortMock;
+        private Mock<IConnectionInfo> _connectionInfoMock;
         private uint _localChannelNumber;
         private uint _localWindowSize;
         private uint _localPacketSize;
@@ -44,6 +45,7 @@ namespace Renci.SshNet.Tests.Classes.Channels
 
             _sessionMock = new Mock<ISession>(MockBehavior.Strict);
             _forwardedPortMock = new Mock<IForwardedPort>(MockBehavior.Strict);
+            _connectionInfoMock = new Mock<IConnectionInfo>(MockBehavior.Strict);
         }
 
         [TestMethod]
@@ -158,6 +160,8 @@ namespace Renci.SshNet.Tests.Classes.Channels
                             _remoteWindowSize, _remotePacketSize, _remoteChannelNumber))));
             _sessionMock.Setup(p => p.WaitOnHandle(It.IsAny<EventWaitHandle>()))
                 .Callback<WaitHandle>(p => p.WaitOne(-1));
+            _sessionMock.Setup(p => p.ConnectionInfo).Returns(_connectionInfoMock.Object);
+            _connectionInfoMock.Setup(p => p.Timeout).Returns(TimeSpan.FromSeconds(60));
             _sessionMock.Setup(p => p.TrySendMessage(It.IsAny<ChannelEofMessage>()))
                 .Returns(true)
                 .Callback<Message>(

+ 4 - 0
src/Renci.SshNet.Tests/Classes/Channels/ChannelDirectTcpipTest_Close_SessionIsConnectedAndChannelIsOpen.cs

@@ -16,6 +16,7 @@ namespace Renci.SshNet.Tests.Classes.Channels
     {
         private Mock<ISession> _sessionMock;
         private Mock<IForwardedPort> _forwardedPortMock;
+        private Mock<IConnectionInfo> _connectionInfoMock;
         private ChannelDirectTcpip _channel;
         private uint _localChannelNumber;
         private uint _localWindowSize;
@@ -72,6 +73,7 @@ namespace Renci.SshNet.Tests.Classes.Channels
 
             _sessionMock = new Mock<ISession>(MockBehavior.Strict);
             _forwardedPortMock = new Mock<IForwardedPort>(MockBehavior.Strict);
+            _connectionInfoMock = new Mock<IConnectionInfo>(MockBehavior.Strict);
 
             var sequence = new MockSequence();
             _sessionMock.InSequence(sequence).Setup(p => p.IsConnected).Returns(true);
@@ -92,6 +94,8 @@ namespace Renci.SshNet.Tests.Classes.Channels
                                     _remoteChannelNumber)));
                         w.WaitOne();
                     });
+            _sessionMock.InSequence(sequence).Setup(p => p.ConnectionInfo).Returns(_connectionInfoMock.Object);
+            _connectionInfoMock.InSequence(sequence).Setup(p => p.Timeout).Returns(TimeSpan.FromSeconds(60));
             _sessionMock.InSequence(sequence).Setup(p => p.IsConnected).Returns(true);
             _sessionMock.InSequence(sequence)
                 .Setup(

+ 145 - 9
src/Renci.SshNet/Abstractions/SocketAbstraction.cs

@@ -75,6 +75,68 @@ namespace Renci.SshNet.Abstractions
 #endif
         }
 
+        public static void ClearReadBuffer(Socket socket)
+        {
+            try
+            {
+                var buffer = new byte[256];
+                int bytesReceived;
+
+                do
+                {
+                    bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, TimeSpan.FromSeconds(2));
+                } while (bytesReceived > 0);
+            }
+            catch
+            {
+                // ignore any exceptions
+            }
+        }
+
+        public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
+        {
+#if FEATURE_SOCKET_SYNC
+            return socket.Receive(buffer, offset, size, SocketFlags.None);
+#elif FEATURE_SOCKET_EAP
+            var receiveCompleted = new ManualResetEvent(false);
+            var sendReceiveToken = new PartialSendReceiveToken(socket, receiveCompleted);
+            var args = new SocketAsyncEventArgs
+            {
+                RemoteEndPoint = socket.RemoteEndPoint,
+                UserToken = sendReceiveToken
+            };
+            args.Completed += ReceiveCompleted;
+            args.SetBuffer(buffer, offset, size);
+
+            try
+            {
+                if (socket.ReceiveAsync(args))
+                {
+                    if (!receiveCompleted.WaitOne(timeout))
+                        throw new SshOperationTimeoutException(
+                            string.Format(
+                                CultureInfo.InvariantCulture,
+                                "Socket read operation has timed out after {0:F0} milliseconds.",
+                                timeout.TotalMilliseconds));
+                }
+
+                if (args.SocketError != SocketError.Success)
+                    throw new SocketException((int) args.SocketError);
+
+                return args.BytesTransferred;
+            }
+            finally
+            {
+                // initialize token to avoid the waithandle getting used after it's disposed
+                args.UserToken = null;
+                args.Dispose();
+                receiveCompleted.Dispose();
+            }
+#else
+#error Receiving data from a Socket is not implemented.
+#endif
+        }
+
         /// <summary>
         /// Receives data from a bound <see cref="Socket"/>into a receive buffer.
         /// </summary>
@@ -96,10 +158,35 @@ namespace Renci.SshNet.Abstractions
         public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
         {
 #if FEATURE_SOCKET_SYNC
-            return socket.Receive(buffer, offset, size, SocketFlags.None);
+            var totalBytesRead = 0;
+            var totalBytesToRead = size;
+
+            do
+            {
+                try
+                {
+                    var bytesRead = socket.Receive(buffer, offset + totalBytesRead, totalBytesToRead - totalBytesRead, SocketFlags.None);
+                    if (bytesRead == 0)
+                        return 0;
+
+                    totalBytesRead += bytesRead;
+                }
+                catch (SocketException ex)
+                {
+                    if (IsErrorResumable(ex.SocketErrorCode))
+                    {
+                        ThreadAbstraction.Sleep(30);
+                        continue;
+                    }
+                    throw;
+                }
+            }
+            while (totalBytesRead < totalBytesToRead);
+
+            return totalBytesRead;
 #elif FEATURE_SOCKET_EAP
             var receiveCompleted = new ManualResetEvent(false);
-            var sendReceiveToken = new SendReceiveToken(socket, buffer, offset, size, receiveCompleted);
+            var sendReceiveToken = new BlockingSendReceiveToken(socket, buffer, offset, size, receiveCompleted);
 
             var args = new SocketAsyncEventArgs
             {
@@ -131,7 +218,7 @@ namespace Renci.SshNet.Abstractions
                 receiveCompleted.Dispose();
             }
 #else
-            #error Receiving data from a Socket is not implemented.
+#error Receiving data from a Socket is not implemented.
 #endif
         }
 
@@ -171,7 +258,7 @@ namespace Renci.SshNet.Abstractions
             } while (totalBytesSent < totalBytesToSend);
 #elif FEATURE_SOCKET_EAP
             var sendCompleted = new ManualResetEvent(false);
-            var sendReceiveToken = new SendReceiveToken(socket, data, offset, size, sendCompleted);
+            var sendReceiveToken = new BlockingSendReceiveToken(socket, data, offset, size, sendCompleted);
             var socketAsyncSendArgs = new SocketAsyncEventArgs
             {
                 RemoteEndPoint = socket.RemoteEndPoint,
@@ -203,7 +290,7 @@ namespace Renci.SshNet.Abstractions
                 sendCompleted.Dispose();
             }
 #else
-            #error Receiving data from a Socket is not implemented.
+#error Sending data to a Socket is not implemented.
 #endif
         }
 
@@ -232,21 +319,26 @@ namespace Renci.SshNet.Abstractions
 #if FEATURE_SOCKET_EAP && !FEATURE_SOCKET_SYNC
         private static void ReceiveCompleted(object sender, SocketAsyncEventArgs e)
         {
-            var sendReceiveToken = (SendReceiveToken) e.UserToken;
+            var sendReceiveToken = (Token) e.UserToken;
             if (sendReceiveToken != null)
                 sendReceiveToken.Process(e);
         }
 
         private static void SendCompleted(object sender, SocketAsyncEventArgs e)
         {
-            var sendReceiveToken = (SendReceiveToken)e.UserToken;
+            var sendReceiveToken = (Token) e.UserToken;
             if (sendReceiveToken != null)
                 sendReceiveToken.Process(e);
         }
 
-        private class SendReceiveToken
+        private interface Token
+        {
+            void Process(SocketAsyncEventArgs args);
+        }
+
+        private class BlockingSendReceiveToken : Token
         {
-            public SendReceiveToken(Socket socket, byte[] buffer, int offset, int size, EventWaitHandle completionWaitHandle)
+            public BlockingSendReceiveToken(Socket socket, byte[] buffer, int offset, int size, EventWaitHandle completionWaitHandle)
             {
                 _socket = socket;
                 _buffer = buffer;
@@ -312,6 +404,50 @@ namespace Renci.SshNet.Abstractions
             private readonly byte[] _buffer;
             private int _offset;
         }
+
+        private class PartialSendReceiveToken : Token
+        {
+            public PartialSendReceiveToken(Socket socket, EventWaitHandle completionWaitHandle)
+            {
+                _socket = socket;
+                _completionWaitHandle = completionWaitHandle;
+            }
+
+            public void Process(SocketAsyncEventArgs args)
+            {
+                if (args.SocketError == SocketError.Success)
+                {
+                    _completionWaitHandle.Set();
+                    return;
+                }
+
+                if (IsErrorResumable(args.SocketError))
+                {
+                    ThreadAbstraction.Sleep(30);
+                    ResumeOperation(args);
+                    return;
+                }
+
+                // we're dealing with a (fatal) error
+                _completionWaitHandle.Set();
+            }
+
+            private void ResumeOperation(SocketAsyncEventArgs args)
+            {
+                switch (args.LastOperation)
+                {
+                    case SocketAsyncOperation.Receive:
+                        _socket.ReceiveAsync(args);
+                        break;
+                    case SocketAsyncOperation.Send:
+                        _socket.SendAsync(args);
+                        break;
+                }
+            }
+
+            private readonly EventWaitHandle _completionWaitHandle;
+            private readonly Socket _socket;
+        }
 #endif // FEATURE_SOCKET_EAP && !FEATURE_SOCKET_SYNC
     }
 }

+ 2 - 2
src/Renci.SshNet/Channels/ChannelDirectTcpip.cs

@@ -11,7 +11,7 @@ namespace Renci.SshNet.Channels
     /// <summary>
     /// Implements "direct-tcpip" SSH channel.
     /// </summary>
-    internal partial class ChannelDirectTcpip : ClientChannel, IChannelDirectTcpip
+    internal class ChannelDirectTcpip : ClientChannel, IChannelDirectTcpip
     {
         private readonly object _socketLock = new object();
 
@@ -91,7 +91,7 @@ namespace Renci.SshNet.Channels
             {
                 try
                 {
-                    var read = SocketAbstraction.Read(_socket, buffer, 0, buffer.Length, ConnectionInfo.Timeout);
+                    var read = SocketAbstraction.ReadPartial(_socket, buffer, 0, buffer.Length, ConnectionInfo.Timeout);
                     if (read > 0)
                     {
 #if TUNING

+ 2 - 2
src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs

@@ -10,7 +10,7 @@ namespace Renci.SshNet.Channels
     /// <summary>
     /// Implements "forwarded-tcpip" SSH channel.
     /// </summary>
-    internal partial class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip
+    internal class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip
     {
         private readonly object _socketShutdownAndCloseLock = new object();
         private Socket _socket;
@@ -83,7 +83,7 @@ namespace Renci.SshNet.Channels
                 {
                 try
                 {
-                    var read = SocketAbstraction.Read(_socket, buffer, 0, buffer.Length, ConnectionInfo.Timeout);
+                    var read = SocketAbstraction.ReadPartial(_socket, buffer, 0, buffer.Length, ConnectionInfo.Timeout);
                     if (read > 0)
                     {
 #if TUNING

+ 1 - 1
src/Renci.SshNet/Renci.SshNet.csproj

@@ -18,7 +18,7 @@
     <DebugType>full</DebugType>
     <Optimize>false</Optimize>
     <OutputPath>bin\Debug\</OutputPath>
-    <DefineConstants>TRACE;DEBUG;TUNING;FEATURE_RNG_CSP;FEATURE_SOCKET_EAP;FEATURE_SOCKET_POLL;FEATURE_SOCKET_SETSOCKETOPTION;FEATURE_STREAM_APM;FEATURE_DNS_SYNC;FEATURE_THREAD_THREADPOOL;FEATURE_THREAD_SLEEP;FEATURE_HASH_MD5;FEATURE_HASH_SHA1;FEATURE_HASH_SHA256;FEATURE_HASH_SHA384;FEATURE_HASH_SHA512;FEATURE_HASH_RIPEMD160;FEATURE_HMAC_MD5;FEATURE_HMAC_SHA1;FEATURE_HMAC_SHA256;FEATURE_HMAC_SHA384;FEATURE_HMAC_SHA512;FEATURE_HMAC_RIPEMD160;FEATURE_MEMORYSTREAM_GETBUFFER</DefineConstants>
+    <DefineConstants>TRACE;DEBUG;TUNING;FEATURE_RNG_CSP;FEATURE_SOCKET_SYNC;FEATURE_SOCKET_APM;FEATURE_SOCKET_POLL;FEATURE_SOCKET_SETSOCKETOPTION;FEATURE_STREAM_APM;FEATURE_DNS_SYNC;FEATURE_THREAD_THREADPOOL;FEATURE_THREAD_SLEEP;FEATURE_HASH_MD5;FEATURE_HASH_SHA1;FEATURE_HASH_SHA256;FEATURE_HASH_SHA384;FEATURE_HASH_SHA512;FEATURE_HASH_RIPEMD160;FEATURE_HMAC_MD5;FEATURE_HMAC_SHA1;FEATURE_HMAC_SHA256;FEATURE_HMAC_SHA384;FEATURE_HMAC_SHA512;FEATURE_HMAC_RIPEMD160;FEATURE_MEMORYSTREAM_GETBUFFER</DefineConstants>
     <ErrorReport>prompt</ErrorReport>
     <WarningLevel>4</WarningLevel>
     <DocumentationFile>bin\Debug\Renci.SshNet.xml</DocumentationFile>

+ 0 - 10
src/Renci.SshNet/Session.NET.cs

@@ -79,16 +79,6 @@ namespace Renci.SshNet
             }
         }
 
-        /// <summary>
-        /// Closes the socket and allows the socket to be reused after the current connection is closed.
-        /// </summary>
-        /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
-        partial void SocketDisconnect()
-        {
-            // TODO should disconnect instead ?!!
-            _socket.Dispose();
-        }
-
         [Conditional("DEBUG")]
         partial void Log(string text)
         {

+ 8 - 7
src/Renci.SshNet/Session.cs

@@ -1864,12 +1864,6 @@ namespace Renci.SshNet
 //#endif // FEATURE_SOCKET_SETSOCKETOPTION
         }
 
-        /// <summary>
-        /// Closes the socket.
-        /// </summary>
-        /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
-        partial void SocketDisconnect();
-
         /// <summary>
         /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
         /// </summary>
@@ -1958,7 +1952,10 @@ namespace Renci.SshNet
                     if (_socket != null)
                     {
                         if (_socket.Connected)
-                            SocketDisconnect();
+                        {
+                            _socket.Shutdown(SocketShutdown.Send);
+                            SocketAbstraction.ClearReadBuffer(_socket);
+                        }
                         _socket.Dispose();
                         _socket = null;
                     }
@@ -1979,6 +1976,10 @@ namespace Renci.SshNet
                     HandleMessageCore(message);
                 }
             }
+            catch (SocketException ex)
+            {
+                RaiseError(new SshConnectionException(ex.Message, DisconnectReason.ConnectionLost, ex));
+            }
             catch (Exception exp)
             {
                 RaiseError(exp);