Pārlūkot izejas kodu

Fix lots of issues in Silverlight socket methods.
Fix SocketReadLine (in both .NET and Silverlight) to return null when remote server has shut down the socket.
Skip any lines in response from SSH server before the protocol identification string.
Fixes issue #2223.

Gert Driesen 11 gadi atpakaļ
vecāks
revīzija
e76c06251d

+ 161 - 97
Renci.SshClient/Renci.SshNet.Silverlight/Session.SilverlightShared.cs

@@ -1,164 +1,212 @@
 using System;
+using System.Collections.Generic;
+using System.Globalization;
 using System.Linq;
 using System.Net;
 using System.Net.Sockets;
 using System.Threading;
 using Renci.SshNet.Common;
 using Renci.SshNet.Messages.Transport;
-using System.Collections.Generic;
 
 namespace Renci.SshNet
 {
     public partial class Session
     {
-        private readonly AutoResetEvent _autoEvent = new AutoResetEvent(false);
+        private const byte Null = 0x00;
+        private const byte CarriageReturn = 0x0d;
+        private const byte LineFeed = 0x0a;
+
+        private readonly AutoResetEvent _connectEvent = new AutoResetEvent(false);
         private readonly AutoResetEvent _sendEvent = new AutoResetEvent(false);
         private readonly AutoResetEvent _receiveEvent = new AutoResetEvent(false);
-        private bool _isConnected;
 
         /// <summary>
         /// Gets a value indicating whether the socket is connected.
         /// </summary>
-        /// <value>
-        /// <c>true</c> if the socket is connected; otherwise, <c>false</c>.
-        /// </value>
+        /// <param name="isConnected"><c>true</c> if the socket is connected; otherwise, <c>false</c></param>
         partial void IsSocketConnected(ref bool isConnected)
         {
-            isConnected = (this._socket != null && this._socket.Connected && this._isConnected);
+            isConnected = (_socket != null && _socket.Connected);
         }
 
+        /// <summary>
+        /// Establishes a socket connection to the specified host and port.
+        /// </summary>
+        /// <param name="host">The host name of the server to connect to.</param>
+        /// <param name="port">The port to connect to.</param>
+        /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="Renci.SshNet.ConnectionInfo.Timeout"/>.</exception>
+        /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
         partial void SocketConnect(string host, int port)
         {
-            var ep = new DnsEndPoint(host, port);
-            this._socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            var timeout = ConnectionInfo.Timeout;
+            var ipAddress = host.GetIPAddress();
+            var ep = new IPEndPoint(ipAddress, port);
 
-            var args = new SocketAsyncEventArgs();
-            args.UserToken = this._socket;
-            args.RemoteEndPoint = ep;
-            args.Completed += OnConnect;
+            _socket = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
 
-            this._socket.ConnectAsync(args);
-            this._autoEvent.WaitOne(this.ConnectionInfo.Timeout);
+            var args = CreateSocketAsyncEventArgs(_connectEvent);
+            if (_socket.ConnectAsync(args))
+            {
+                if (!_connectEvent.WaitOne(timeout))
+                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
+                        "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
+            }
 
             if (args.SocketError != SocketError.Success)
-                throw new SocketException((int)args.SocketError);
+                throw new SocketException((int) args.SocketError);
         }
 
+        /// <summary>
+        /// Closes the socket.
+        /// </summary>
+        /// <remarks>
+        /// This method will wait up to <c>10</c> seconds to send any remaining data.
+        /// </remarks>
         partial void SocketDisconnect()
         {
-            this._socket.Close(10000);
+            _socket.Close(10);
         }
 
-        partial void SocketReadLine(ref string response)
+        /// <summary>
+        /// Performs a blocking read on the socket until a line is read.
+        /// </summary>
+        /// <param name="response">The line read from the socket, or <c>null</c> when the remote server has shutdown and all data has been received.</param>
+        /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until a line is read.</param>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
+        partial void SocketReadLine(ref string response, TimeSpan timeout)
         {
             var encoding = new ASCIIEncoding();
-
-            //  Read data one byte at a time to find end of line and leave any unhandled information in the buffer to be processed later
             var buffer = new List<byte>();
-
             var data = new byte[1];
+
+            // read data one byte at a time to find end of line and leave any unhandled information in the buffer
+            // to be processed by subsequent invocations
             do
             {
-                var args = new SocketAsyncEventArgs();
-                args.SetBuffer(data, 0, data.Length);
-                args.UserToken = this._socket;
-                args.RemoteEndPoint = this._socket.RemoteEndPoint;
-                args.Completed += OnReceive;
-                this._socket.ReceiveAsync(args);
+                var args = CreateSocketAsyncEventArgs(_receiveEvent, data, 0, data.Length);
+                if (_socket.ReceiveAsync(args))
+                {
+                    if (!_receiveEvent.WaitOne(timeout))
+                        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
+                            "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
+                }
 
-                if (!this._receiveEvent.WaitOne(this.ConnectionInfo.Timeout))
-                    throw new SshOperationTimeoutException("Socket read operation has timed out");
+                if (args.SocketError != SocketError.Success)
+                    throw new SocketException((int) args.SocketError);
 
-                //  If zero bytes received then exit
                 if (args.BytesTransferred == 0)
+                    // the remote server shut down the socket
                     break;
 
                 buffer.Add(data[0]);
             }
-            while (!(buffer.Count > 0 && (buffer[buffer.Count - 1] == 0x0A || buffer[buffer.Count - 1] == 0x00)));
+            while (!(buffer.Count > 0 && (buffer[buffer.Count - 1] == LineFeed || buffer[buffer.Count - 1] == Null)));
 
-            // Return an empty version string if the buffer consists of a 0x00 character.
-            if (buffer.Count > 0 && buffer[buffer.Count - 1] == 0x00)
-            {
-                response = string.Empty;
-            }
-            else if (buffer.Count == 0) 
+            if (buffer.Count == 0)
+                response = null;
+            else if (buffer.Count == 1 && buffer[buffer.Count - 1] == 0x00)
+                // return an empty version string if the buffer consists of only a 0x00 character
                 response = string.Empty;
-            else if (buffer.Count > 1 && buffer[buffer.Count - 2] == 0x0D)
+            else if (buffer.Count > 1 && buffer[buffer.Count - 2] == CarriageReturn)
+                // strip trailing CRLF
                 response = encoding.GetString(buffer.ToArray(), 0, buffer.Count - 2);
-            else
+            else if (buffer.Count > 1 && buffer[buffer.Count - 1] == LineFeed)
+                // strip trailing LF
                 response = encoding.GetString(buffer.ToArray(), 0, buffer.Count - 1);
+            else
+                response = encoding.GetString(buffer.ToArray(), 0, buffer.Count);
         }
 
-        partial void SocketRead(int length, ref byte[] buffer)
+        /// <summary>
+        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
+        /// </summary>
+        /// <param name="length">The number of bytes to read.</param>
+        /// <param name="buffer">The buffer to read to.</param>
+        /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until <paramref name="length"/> bytes a read.</param>
+        /// <exception cref="SshConnectionException">The socket is closed.</exception>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        /// <exception cref="SocketException">The read failed.</exception>
+        partial void SocketRead(int length, ref byte[] buffer, TimeSpan timeout)
         {
-            var receivedTotal = 0;  // how many bytes is already received
+            var totalBytesReceived = 0;  // how many bytes are already received
 
             do
             {
-                var args = new SocketAsyncEventArgs();
-                args.SetBuffer(buffer, receivedTotal, length - receivedTotal);
-                args.UserToken = this._socket;
-                args.RemoteEndPoint = this._socket.RemoteEndPoint;
-                args.Completed += OnReceive;
-                this._socket.ReceiveAsync(args);
-
-                this._receiveEvent.WaitOne(this.ConnectionInfo.Timeout);
-
-                if (args.SocketError == SocketError.WouldBlock ||
-                    args.SocketError == SocketError.IOPending ||
-                    args.SocketError == SocketError.NoBufferSpaceAvailable)
-                {
-                    // socket buffer is probably empty, wait and try again
-                    Thread.Sleep(30);
-                    continue;
-                }
-                if (args.SocketError != SocketError.Success)
+                var args = CreateSocketAsyncEventArgs(_receiveEvent, buffer, totalBytesReceived,
+                    length - totalBytesReceived);
+                if (_socket.ReceiveAsync(args))
                 {
-                    throw new SocketException((int)args.SocketError);
+                    if (!_receiveEvent.WaitOne(timeout))
+                        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
+                            "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
                 }
 
-                var receivedBytes = args.BytesTransferred;
-                if (receivedBytes > 0)
+                switch (args.SocketError)
                 {
-                    receivedTotal += receivedBytes;
-                    continue;
+                    case SocketError.WouldBlock:
+                    case SocketError.IOPending:
+                    case SocketError.NoBufferSpaceAvailable:
+                        // socket buffer is probably full, wait and try again
+                        Thread.Sleep(30);
+                        break;
+                    case SocketError.Success:
+                        var bytesReceived = args.BytesTransferred;
+                        if (bytesReceived > 0)
+                        {
+                            totalBytesReceived += bytesReceived;
+                            continue;
+                        }
+
+                        if (_isDisconnecting)
+                            throw new SshConnectionException(
+                                "An established connection was aborted by the software in your host machine.",
+                                DisconnectReason.ConnectionLost);
+                        throw new SshConnectionException("An established connection was aborted by the server.",
+                            DisconnectReason.ConnectionLost);
+                    default:
+                        throw new SocketException((int) args.SocketError);
                 }
-                throw new SshConnectionException("An established connection was aborted by the software in your host machine.", DisconnectReason.ConnectionLost);
-            } while (receivedTotal < length);
+            } while (totalBytesReceived < length);
         }
 
+        /// <summary>
+        /// Writes the specified data to the server.
+        /// </summary>
+        /// <param name="data">The data to write to the server.</param>
+        /// <exception cref="SshOperationTimeoutException">The write has timed-out.</exception>
+        /// <exception cref="SocketException">The write failed.</exception>
         partial void SocketWrite(byte[] data)
         {
-            if (this._isConnected)
-            {
-                var args = new SocketAsyncEventArgs();
-                args.SetBuffer(data, 0, data.Length);
-                args.UserToken = this._socket;
-                args.RemoteEndPoint = this._socket.RemoteEndPoint;
-                args.Completed += OnSend;
-
-                this._socket.SendAsync(args);
-            }
-            else
-                throw new SocketException((int)SocketError.NotConnected);
-
-        }
-
-        private void OnConnect(object sender, SocketAsyncEventArgs e)
-        {
-            this._autoEvent.Set();
-            this._isConnected = (e.SocketError == SocketError.Success);
-        }
+            var timeout = ConnectionInfo.Timeout;
+            var totalBytesSent = 0;  // how many bytes are already sent
+            var totalBytesToSend = data.Length;
 
-        private void OnSend(object sender, SocketAsyncEventArgs e)
-        {
-            this._sendEvent.Set();
-        }
+            do
+            {
+                var args = CreateSocketAsyncEventArgs(_sendEvent, data, 0, totalBytesToSend - totalBytesSent);
+                if (_socket.SendAsync(args))
+                {
+                    if (!_sendEvent.WaitOne(timeout))
+                        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
+                            "Socket write operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
+                }
 
-        private void OnReceive(object sender, SocketAsyncEventArgs e)
-        {
-            this._receiveEvent.Set();
+                switch (args.SocketError)
+                {
+                    case SocketError.WouldBlock:
+                    case SocketError.IOPending:
+                    case SocketError.NoBufferSpaceAvailable:
+                        // socket buffer is probably full, wait and try again
+                        Thread.Sleep(30);
+                        break;
+                    case SocketError.Success:
+                        totalBytesSent += args.BytesTransferred;
+                        break;
+                    default:
+                        throw new SocketException((int) args.SocketError);
+}
+                } while (totalBytesSent < totalBytesToSend);
         }
 
         partial void ExecuteThread(Action action)
@@ -168,9 +216,9 @@ namespace Renci.SshNet
 
         partial void InternalRegisterMessage(string messageName)
         {
-            lock (this._messagesMetadata)
+            lock (_messagesMetadata)
             {
-                foreach (var item in from m in this._messagesMetadata where m.Name == messageName select m)
+                foreach (var item in from m in _messagesMetadata where m.Name == messageName select m)
                 {
                     item.Enabled = true;
                     item.Activated = true;
@@ -180,14 +228,30 @@ namespace Renci.SshNet
 
         partial void InternalUnRegisterMessage(string messageName)
         {
-            lock (this._messagesMetadata)
+            lock (_messagesMetadata)
             {
-                foreach (var item in from m in this._messagesMetadata where m.Name == messageName select m)
+                foreach (var item in from m in _messagesMetadata where m.Name == messageName select m)
                 {
                     item.Enabled = false;
                     item.Activated = false;
                 }
             }
         }
+
+        private SocketAsyncEventArgs CreateSocketAsyncEventArgs(EventWaitHandle waitHandle)
+        {
+            var args = new SocketAsyncEventArgs();
+            args.UserToken = _socket;
+            args.RemoteEndPoint = _socket.RemoteEndPoint;
+            args.Completed += (sender, eventArgs) => waitHandle.Set();
+            return args;
+        }
+
+        private SocketAsyncEventArgs CreateSocketAsyncEventArgs(EventWaitHandle waitHandle, byte[] data, int offset, int count)
+        {
+            var args = CreateSocketAsyncEventArgs(waitHandle);
+            args.SetBuffer(data, offset, count);
+            return args;
+        }
     }
 }

+ 27 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest.HttpProxy.cs

@@ -10,6 +10,33 @@ namespace Renci.SshNet.Tests.Classes
 {
     public partial class SessionTest
     {
+        [TestMethod]
+        public void ConnectShouldThrowProxyExceptionWhenHttpProxyResponseDoesNotContainStatusLine()
+        {
+            var proxyEndPoint = new IPEndPoint(IPAddress.Loopback, 8123);
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+
+            using (var proxyStub = new HttpProxyStub(proxyEndPoint))
+            {
+                proxyStub.Responses.Add(Encoding.ASCII.GetBytes("Whatever\r\n"));
+                proxyStub.Start();
+
+                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon")))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (ProxyException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual("HTTP response does not contain status line.", ex.Message);
+                    }
+                }
+            }
+        }
+
         [TestMethod]
         public void ConnectShouldThrowProxyExceptionWhenHttpProxyReturnsHttpStatusOtherThan200()
         {

+ 219 - 1
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest.cs

@@ -1,4 +1,9 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Globalization;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Renci.SshNet.Common;
 using Renci.SshNet.Messages;
 using Renci.SshNet.Tests.Common;
@@ -11,10 +16,206 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public partial class SessionTest : TestBase
     {
+        [TestMethod]
+        public void ConnectShouldSkipLinesBeforeProtocolIdentificationString()
+        {
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            var connectionInfo = CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5));
+
+            using (var serverStub = new AsyncSocketListener(serverEndPoint))
+            {
+                serverStub.Connected += (socket) =>
+                {
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub\r\n"));
+                    socket.Shutdown(SocketShutdown.Send);
+                };
+                serverStub.Start();
+
+                using (var session = new Session(connectionInfo))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (SshConnectionException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual("Server version '666' is not supported.", ex.Message);
+
+                        Assert.AreEqual("SSH-666-SshStub", connectionInfo.ServerVersion);
+                    }
+                }
+            }
+        }
+
+        [TestMethod]
+        public void ConnectShouldSupportProtocolIdentificationStringThatDoesNotEndWithCrlf()
+        {
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            var connectionInfo = CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5));
+
+            using (var serverStub = new AsyncSocketListener(serverEndPoint))
+            {
+                serverStub.Connected += (socket) =>
+                {
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub"));
+                    socket.Shutdown(SocketShutdown.Send);
+                };
+                serverStub.Start();
+
+                using (var session = new Session(connectionInfo))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (SshConnectionException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual("Server version '666' is not supported.", ex.Message);
+
+                        Assert.AreEqual("SSH-666-SshStub", connectionInfo.ServerVersion);
+                    }
+                }
+            }
+        }
+
+        [TestMethod]
+        public void ConnectShouldThrowSshOperationExceptionWhenServerDoesNotRespondWithinConnectionTimeout()
+        {
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            var timeout = TimeSpan.FromMilliseconds(500);
+            Socket clientSocket = null;
+
+            using (var serverStub = new AsyncSocketListener(serverEndPoint))
+            {
+                serverStub.Connected += (socket) =>
+                {
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    clientSocket = socket;
+                };
+                serverStub.Start();
+
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromMilliseconds(500))))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (SshOperationTimeoutException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds), ex.Message);
+
+                        Assert.IsNotNull(clientSocket);
+                        Assert.IsTrue(clientSocket.Connected);
+
+                        // shut down socket
+                        clientSocket.Shutdown(SocketShutdown.Send);
+                    }
+                }
+            }
+        }
+
+        [TestMethod]
+        public void ConnectShouldSshConnectionExceptionWhenServerResponseDoesNotContainProtocolIdentificationString()
+        {
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+
+            // response ends with CRLF
+            using (var serverStub = new AsyncSocketListener(serverEndPoint))
+            {
+                serverStub.Connected += (socket) =>
+                {
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    socket.Shutdown(SocketShutdown.Send);
+                };
+                serverStub.Start();
+
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5))))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (SshConnectionException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual("Server response does not contain SSH protocol identification.", ex.Message);
+                    }
+                }
+            }
+
+            // response does not end with CRLF
+            using (var serverStub = new AsyncSocketListener(serverEndPoint))
+            {
+                serverStub.Connected += (socket) =>
+                {
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner"));
+                    socket.Shutdown(SocketShutdown.Send);
+                };
+                serverStub.Start();
+
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5))))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (SshConnectionException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual("Server response does not contain SSH protocol identification.", ex.Message);
+                    }
+                }
+            }
+
+            // last line is empty
+            using (var serverStub = new AsyncSocketListener(serverEndPoint))
+            {
+                serverStub.Connected += (socket) =>
+                {
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Shutdown(SocketShutdown.Send);
+                };
+                serverStub.Start();
+
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5))))
+                {
+                    try
+                    {
+                        session.Connect();
+                        Assert.Fail();
+                    }
+                    catch (SshConnectionException ex)
+                    {
+                        Assert.IsNull(ex.InnerException);
+                        Assert.AreEqual("Server response does not contain SSH protocol identification.", ex.Message);
+                    }
+                }
+            }
+        }
+
+
         /// <summary>
         ///A test for SessionSemaphore
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void SessionSemaphoreTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -28,6 +229,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for IsConnected
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void IsConnectedTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -41,6 +243,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for ClientInitMessage
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void ClientInitMessageTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -54,6 +257,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for UnRegisterMessage
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void UnRegisterMessageTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -67,6 +271,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for RegisterMessage
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void RegisterMessageTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -80,6 +285,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for Dispose
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void DisposeTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -92,6 +298,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for Disconnect
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void DisconnectTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -104,6 +311,7 @@ namespace Renci.SshNet.Tests.Classes
         ///A test for Connect
         ///</summary>
         [TestMethod()]
+        [Ignore]
         public void ConnectTest()
         {
             ConnectionInfo connectionInfo = null; // TODO: Initialize to an appropriate value
@@ -112,5 +320,15 @@ namespace Renci.SshNet.Tests.Classes
             Assert.Inconclusive("A method that does not return a value cannot be verified.");
         }
 
+        private static ConnectionInfo CreateConnectionInfo(IPEndPoint serverEndPoint, TimeSpan timeout)
+        {
+            var connectionInfo = new ConnectionInfo(
+                serverEndPoint.Address.ToString(),
+                serverEndPoint.Port,
+                "eric",
+                new NoneAuthenticationMethod("eric"));
+            connectionInfo.Timeout = timeout;
+            return connectionInfo;
+        }
     }
 }

+ 83 - 52
Renci.SshClient/Renci.SshNet/Session.NET.cs

@@ -1,4 +1,5 @@
-using System.Linq;
+using System.Globalization;
+using System.Linq;
 using System;
 using System.Net.Sockets;
 using System.Net;
@@ -12,6 +13,10 @@ namespace Renci.SshNet
 {
     public partial class Session
     {
+        private const byte Null = 0x00;
+        private const byte CarriageReturn = 0x0d;
+        private const byte LineFeed = 0x0a;
+
         private readonly TraceSource _log =
 #if DEBUG
             new TraceSource("SshNet.Logging", SourceLevels.All);
@@ -27,9 +32,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Gets a value indicating whether the socket is connected.
         /// </summary>
-        /// <value>
-        /// <c>true</c> if the socket is connected; otherwise, <c>false</c>.
-        /// </value>
+        /// <param name="isConnected"><c>true</c> if the socket is connected; otherwise, <c>false</c></param>
         /// <remarks>
         /// <para>
         /// As a first check we verify whether <see cref="Socket.Connected"/> is
@@ -52,7 +55,7 @@ namespace Renci.SshNet
         /// </para>
         /// <para>
         /// <c>Conclusion:</c> when the return value is <c>true</c> - but no data is available for reading - then
-        ///  the socket is no longer connected.
+        /// the socket is no longer connected.
         /// </para>
         /// <para>
         /// When a <see cref="Socket"/> is used from multiple threads, there's a race condition
@@ -86,81 +89,102 @@ namespace Renci.SshNet
             }
         }
 
+        /// <summary>
+        /// Establishes a socket connection to the specified host and port.
+        /// </summary>
+        /// <param name="host">The host name of the server to connect to.</param>
+        /// <param name="port">The port to connect to.</param>
+        /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="Renci.SshNet.ConnectionInfo.Timeout"/>.</exception>
+        /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
         partial void SocketConnect(string host, int port)
         {
             const int socketBufferSize = 2 * MaximumSshPacketSize;
 
-            var addr = host.GetIPAddress();
+            var ipAddress = host.GetIPAddress();
+            var timeout = ConnectionInfo.Timeout;
+            var ep = new IPEndPoint(ipAddress, port);
 
-            var ep = new IPEndPoint(addr, port);
-            this._socket = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+            _socket = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+            _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
+            _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.SendBuffer, socketBufferSize);
+            _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveBuffer, socketBufferSize);
 
-            this._socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
-            this._socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.SendBuffer, socketBufferSize);
-            this._socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveBuffer, socketBufferSize);
+            Log(string.Format("Initiating connect to '{0}:{1}'.", ConnectionInfo.Host, ConnectionInfo.Port));
 
-            this.Log(string.Format("Initiating connect to '{0}:{1}'.", this.ConnectionInfo.Host, this.ConnectionInfo.Port));
+            var connectResult = _socket.BeginConnect(ep, null, null);
+            if (!connectResult.AsyncWaitHandle.WaitOne(timeout, false))
+                throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
+                    "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
 
-            //  Connect socket with specified timeout
-            var connectResult = this._socket.BeginConnect(ep, null, null);
-
-            if (!connectResult.AsyncWaitHandle.WaitOne(this.ConnectionInfo.Timeout, false))
-            {
-                throw new SshOperationTimeoutException("Connection Could Not Be Established");
-            }
-
-            this._socket.EndConnect(connectResult);
+            _socket.EndConnect(connectResult);
         }
 
+        /// <summary>
+        /// Closes the socket and allows the socket to be reused after the current connection is closed.
+        /// </summary>
+        /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
         partial void SocketDisconnect()
         {
             _socket.Disconnect(true);
         }
 
-        partial void SocketReadLine(ref string response)
+        /// <summary>
+        /// Performs a blocking read on the socket until a line is read.
+        /// </summary>
+        /// <param name="response">The line read from the socket, or <c>null</c> when the remote server has shutdown and all data has been received.</param>
+        /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until a line is read.</param>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
+        partial void SocketReadLine(ref string response, TimeSpan timeout)
         {
             var encoding = new ASCIIEncoding();
-
-            //  Read data one byte at a time to find end of line and leave any unhandled information in the buffer to be processed later
             var buffer = new List<byte>();
-
             var data = new byte[1];
+
+            // read data one byte at a time to find end of line and leave any unhandled information in the buffer
+            // to be processed by subsequent invocations
             do
             {
-                var asyncResult = this._socket.BeginReceive(data, 0, data.Length, SocketFlags.None, null, null);
+                var asyncResult = _socket.BeginReceive(data, 0, data.Length, SocketFlags.None, null, null);
+                if (!asyncResult.AsyncWaitHandle.WaitOne(timeout))
+                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
+                        "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
 
-                if (!asyncResult.AsyncWaitHandle.WaitOne(this.ConnectionInfo.Timeout))
-                    throw new SshOperationTimeoutException("Socket read operation has timed out");
+                var received = _socket.EndReceive(asyncResult);
 
-                var received = this._socket.EndReceive(asyncResult);
-
-                //  If zero bytes received then exit
                 if (received == 0)
+                    // the remote server shut down the socket
                     break;
 
                 buffer.Add(data[0]);
             }
-            while (!(buffer.Count > 0 && (buffer[buffer.Count - 1] == 0x0A || buffer[buffer.Count - 1] == 0x00)));
+            while (!(buffer.Count > 0 && (buffer[buffer.Count - 1] == LineFeed || buffer[buffer.Count - 1] == Null)));
 
-            // Return an empty version string if the buffer consists of a 0x00 character.
-            if (buffer.Count > 0 && buffer[buffer.Count - 1] == 0x00)
-            {
+            if (buffer.Count == 0)
+                response = null;
+            else if (buffer.Count == 1 && buffer[buffer.Count - 1] == 0x00)
+                // return an empty version string if the buffer consists of only a 0x00 character
                 response = string.Empty;
-            }
-            else if (buffer.Count > 1 && buffer[buffer.Count - 2] == 0x0D)
+            else if (buffer.Count > 1 && buffer[buffer.Count - 2] == CarriageReturn)
+                // strip trailing CRLF
                 response = encoding.GetString(buffer.Take(buffer.Count - 2).ToArray());
-            else
+            else if (buffer.Count > 1 && buffer[buffer.Count - 1] == LineFeed)
+                // strip trailing LF
                 response = encoding.GetString(buffer.Take(buffer.Count - 1).ToArray());
+            else
+                response = encoding.GetString(buffer.ToArray());
         }
 
         /// <summary>
-        /// Function to read <paramref name="length"/> amount of data before returning, or throwing an exception.
+        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
         /// </summary>
-        /// <param name="length">The amount wanted.</param>
+        /// <param name="length">The number of bytes to read.</param>
         /// <param name="buffer">The buffer to read to.</param>
-        /// <exception cref="SshConnectionException">Happens when the socket is closed.</exception>
-        /// <exception cref="Exception">Unhandled exception.</exception>
-        partial void SocketRead(int length, ref byte[] buffer)
+        /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until <paramref name="length"/> bytes a read.</param>
+        /// <exception cref="SshConnectionException">The socket is closed.</exception>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        /// <exception cref="SocketException">The read failed.</exception>
+        partial void SocketRead(int length, ref byte[] buffer, TimeSpan timeout)
         {
             var receivedTotal = 0;  // how many bytes is already received
 
@@ -168,7 +192,7 @@ namespace Renci.SshNet
             {
                 try
                 {
-                    var receivedBytes = this._socket.Receive(buffer, receivedTotal, length - receivedTotal, SocketFlags.None);
+                    var receivedBytes = _socket.Receive(buffer, receivedTotal, length - receivedTotal, SocketFlags.None);
                     if (receivedBytes > 0)
                     {
                         // signal that bytes have been read from the socket
@@ -186,7 +210,7 @@ namespace Renci.SshNet
                     // 3 - MessageListener then again calls RaiseError().
                     // There the exception is checked for the exception thrown here (ConnectionLost), and if it matches it will not call Session.SendDisconnect().
                     //
-                    // Adding a check for this._isDisconnecting causes ReceiveMessage() to throw SshConnectionException: "Bad packet length {0}".
+                    // Adding a check for _isDisconnecting causes ReceiveMessage() to throw SshConnectionException: "Bad packet length {0}".
                     //
 
                     if (_isDisconnecting)
@@ -198,7 +222,7 @@ namespace Renci.SshNet
                     if (exp.SocketErrorCode == SocketError.ConnectionAborted)
                     {
                         buffer = new byte[length];
-                        this.Disconnect();
+                        Disconnect();
                         return;
                     }
 
@@ -215,16 +239,23 @@ namespace Renci.SshNet
             } while (receivedTotal < length);
         }
 
+        /// <summary>
+        /// Writes the specified data to the server.
+        /// </summary>
+        /// <param name="data">The data to write to the server.</param>
+        /// <exception cref="SshOperationTimeoutException">The write has timed-out.</exception>
+        /// <exception cref="SocketException">The write failed.</exception>
         partial void SocketWrite(byte[] data)
         {
-            var sent = 0;  // how many bytes is already sent
-            var length = data.Length;
+            var totalBytesSent = 0;  // how many bytes are already sent
+            var totalBytesToSend = data.Length;
 
             do
             {
                 try
                 {
-                    sent += this._socket.Send(data, sent, length - sent, SocketFlags.None);
+                    totalBytesSent += _socket.Send(data, totalBytesSent, totalBytesToSend - totalBytesSent,
+                        SocketFlags.None);
                 }
                 catch (SocketException ex)
                 {
@@ -238,13 +269,13 @@ namespace Renci.SshNet
                     else
                         throw;  // any serious error occurr
                 }
-            } while (sent < length);
+            } while (totalBytesSent < totalBytesToSend);
         }
 
         [Conditional("DEBUG")]
         partial void Log(string text)
         {
-            this._log.TraceEvent(TraceEventType.Verbose, 1, text);
+            _log.TraceEvent(TraceEventType.Verbose, 1, text);
         }
 
 #if ASYNC_SOCKET_READ
@@ -273,7 +304,7 @@ namespace Renci.SshNet
                         if (socketException.SocketErrorCode == SocketError.ConnectionAborted)
                         {
                             buffer = new byte[length];
-                            this.Disconnect();
+                            Disconnect();
                             return;
                         }
                     }

+ 55 - 15
Renci.SshClient/Renci.SshNet/Session.cs

@@ -510,18 +510,13 @@ namespace Renci.SshNet
                     while (true)
                     {
                         string serverVersion = string.Empty;
-                        this.SocketReadLine(ref serverVersion);
-
-                        this.ServerVersion = serverVersion;
-                        if (string.IsNullOrEmpty(this.ServerVersion))
-                        {
-                            throw new InvalidOperationException("Server string is null or empty.");
-                        }
-
-                        versionMatch = ServerVersionRe.Match(this.ServerVersion);
-
+                        this.SocketReadLine(ref serverVersion, ConnectionInfo.Timeout);
+                        if (serverVersion == null)
+                            throw new SshConnectionException("Server response does not contain SSH protocol identification.");
+                        versionMatch = ServerVersionRe.Match(serverVersion);
                         if (versionMatch.Success)
                         {
+                            this.ServerVersion = serverVersion;
                             break;
                         }
                     }
@@ -1672,20 +1667,58 @@ namespace Renci.SshNet
         /// </value>
         partial void IsSocketConnected(ref bool isConnected);
 
+        /// <summary>
+        /// Establishes a socket connection to the specified host and port.
+        /// </summary>
+        /// <param name="host">The host name of the server to connect to.</param>
+        /// <param name="port">The port to connect to.</param>
+        /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="Renci.SshNet.ConnectionInfo.Timeout"/>.</exception>
+        /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
         partial void SocketConnect(string host, int port);
 
+        /// <summary>
+        /// Closes the socket.
+        /// </summary>
+        /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
         partial void SocketDisconnect();
 
-        partial void SocketRead(int length, ref byte[] buffer);
+        /// <summary>
+        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
+        /// </summary>
+        /// <param name="length">The number of bytes to read.</param>
+        /// <param name="buffer">The buffer to read to.</param>
+        /// <exception cref="SshConnectionException">The socket is closed.</exception>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        private void SocketRead(int length, ref byte[] buffer)
+        {
+            SocketRead(length, ref buffer, ConnectionInfo.Timeout);
+        }
+
+        /// <summary>
+        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
+        /// </summary>
+        /// <param name="length">The number of bytes to read.</param>
+        /// <param name="buffer">The buffer to read to.</param>
+        /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until <paramref name="length"/> bytes a read.</param>
+        /// <exception cref="SshConnectionException">The socket is closed.</exception>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        partial void SocketRead(int length, ref byte[] buffer, TimeSpan timeout);
 
-        partial void SocketReadLine(ref string response);
+        /// <summary>
+        /// Performs a blocking read on the socket until a line is read.
+        /// </summary>
+        /// <param name="response">The line read from the socket, or <c>null</c> when the remote server has shutdown and all data has been received.</param>
+        /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until a line is read.</param>
+        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
+        partial void SocketReadLine(ref string response, TimeSpan timeout);
 
         partial void Log(string text);
 
         /// <summary>
         /// Writes the specified data to the server.
         /// </summary>
-        /// <param name="data">The data.</param>
+        /// <param name="data">The data to write to the server.</param>
+        /// <exception cref="SshOperationTimeoutException">The write has timed-out.</exception>
         partial void SocketWrite(byte[] data);
 
         /// <summary>
@@ -1976,7 +2009,10 @@ namespace Renci.SshNet
 
             while (true)
             {
-                this.SocketReadLine(ref response);
+                this.SocketReadLine(ref response, ConnectionInfo.Timeout);
+                if (response == null)
+                    // server shut down socket
+                    break;
 
                 if (statusCode == null)
                 {
@@ -1991,8 +2027,9 @@ namespace Renci.SshNet
                             throw new ProxyException(string.Format("HTTP: Status code {0}, \"{1}\"", httpStatusCode,
                                 reasonPhrase));
                         }
-                        continue;
                     }
+
+                    continue;
                 }
 
                 // continue on parsing message headers coming from the server
@@ -2019,6 +2056,9 @@ namespace Renci.SshNet
                     break;
                 }
             }
+
+            if (statusCode == null)
+                throw new ProxyException("HTTP response does not contain status line.");
         }
 
         /// <summary>