Преглед изворни кода

Improve socket related error handling.
Refactor to compile UWP project without using CoreFX.

drieseng пре 9 година
родитељ
комит
9f9864e789

+ 237 - 157
src/Renci.SshNet/ForwardedPortDynamic.NET.cs

@@ -1,6 +1,5 @@
 using System;
 using System.Diagnostics;
-using System.IO;
 using System.Linq;
 using System.Text;
 using System.Net;
@@ -31,9 +30,12 @@ namespace Renci.SshNet
 
             var ep = new IPEndPoint(ip, (int) BoundPort);
 
-            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {Blocking = true};
+            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+            // TODO: decide if we want to have blocking socket
+#if FEATURE_SOCKET_SETSOCKETOPTION
             _listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.DontLinger, true);
             _listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.NoDelay, true);
+#endif //FEATURE_SOCKET_SETSOCKETOPTION
             _listener.Bind(ep);
             _listener.Listen(5);
 
@@ -153,48 +155,23 @@ namespace Renci.SshNet
 
             try
             {
+#if FEATURE_SOCKET_SETSOCKETOPTION
                 remoteSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.DontLinger, true);
                 remoteSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.NoDelay, true);
+#endif //FEATURE_SOCKET_SETSOCKETOPTION
 
                 using (var channel = Session.CreateChannelDirectTcpip())
                 {
                     channel.Exception += Channel_Exception;
 
-                    var version = new byte[1];
-
-                    // create eventhandler which is to be invoked to interrupt a blocking receive
-                    // when we're closing the forwarded port
-                    EventHandler closeClientSocket = (_, args) => CloseSocket(remoteSocket);
-
                     try
                     {
-                        Closing += closeClientSocket;
-
-                        var bytesRead = remoteSocket.Receive(version);
-                        if (bytesRead == 0)
+                        if (!HandleSocks(channel, remoteSocket, Session.ConnectionInfo.Timeout))
                         {
                             CloseSocket(remoteSocket);
                             return;
                         }
 
-                        if (version[0] == 4)
-                        {
-                            HandleSocks4(remoteSocket, channel);
-                        }
-                        else if (version[0] == 5)
-                        {
-                            HandleSocks5(remoteSocket, channel);
-                        }
-                        else
-                        {
-                            throw new NotSupportedException(string.Format("SOCKS version {0} is not supported.",
-                                version[0]));
-                        }
-
-                        // interrupt of blocking receive is now handled by channel (SOCKS4 and SOCKS5)
-                        // or no longer necessary
-                        Closing -= closeClientSocket;
-
                         // start receiving from client socket (and sending to server)
                         channel.Bind();
                     }
@@ -225,6 +202,44 @@ namespace Renci.SshNet
             }
         }
 
+        private bool HandleSocks(IChannelDirectTcpip channel, Socket remoteSocket, TimeSpan timeout)
+        {
+            // create eventhandler which is to be invoked to interrupt a blocking receive
+            // when we're closing the forwarded port
+            EventHandler closeClientSocket = (_, args) => CloseSocket(remoteSocket);
+
+            Closing += closeClientSocket;
+
+            try
+            {
+                var version = SocketAbstraction.ReadByte(remoteSocket, timeout);
+                if (version == -1)
+                {
+                    return false;
+                }
+
+                if (version == 4)
+                {
+                    return HandleSocks4(remoteSocket, channel, timeout);
+                }
+                else if (version == 5)
+                {
+                    return HandleSocks5(remoteSocket, channel, timeout);
+                }
+                else
+                {
+                    throw new NotSupportedException(string.Format("SOCKS version {0} is not supported.", version));
+                }
+            }
+            finally
+            {
+                // interrupt of blocking receive is now handled by channel (SOCKS4 and SOCKS5)
+                // or no longer necessary
+                Closing -= closeClientSocket;
+            }
+
+        }
+
         private static void CloseSocket(Socket socket)
         {
             if (socket.Connected)
@@ -301,169 +316,225 @@ namespace Renci.SshNet
             }
         }
 
-        private void HandleSocks4(Socket socket, IChannelDirectTcpip channel)
+        private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
         {
-            using (var stream = new NetworkStream(socket))
+            var commandCode = SocketAbstraction.ReadByte(socket, timeout);
+            if (commandCode == 0)
             {
-                var commandCode = stream.ReadByte();
-                //  TODO:   See what need to be done depends on the code
+                // SOCKS client closed connection
+                return false;
+            }
 
-                var portBuffer = new byte[2];
-                stream.Read(portBuffer, 0, portBuffer.Length);
-                var port = (uint)(portBuffer[0] * 256 + portBuffer[1]);
+            //  TODO:   See what need to be done depends on the code
+
+            var portBuffer = new byte[2];
+            if (SocketAbstraction.Read(socket, portBuffer, 0, portBuffer.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                var ipBuffer = new byte[4];
-                stream.Read(ipBuffer, 0, ipBuffer.Length);
-                var ipAddress = new IPAddress(ipBuffer);
+            var port = (uint)(portBuffer[0] * 256 + portBuffer[1]);
 
-                var username = ReadString(stream);
+            var ipBuffer = new byte[4];
+            if (SocketAbstraction.Read(socket, ipBuffer, 0, ipBuffer.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                var host = ipAddress.ToString();
+            var ipAddress = new IPAddress(ipBuffer);
 
-                RaiseRequestReceived(host, port);
+            var username = ReadString(socket, timeout);
+            if (username == null)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                channel.Open(host, port, this, socket);
+            var host = ipAddress.ToString();
 
-                using (var writeStream = new MemoryStream())
-                {
-                    writeStream.WriteByte(0x00);
+            RaiseRequestReceived(host, port);
 
-                    if (channel.IsOpen)
-                    {
-                        writeStream.WriteByte(0x5a);
-                    }
-                    else
-                    {
-                        writeStream.WriteByte(0x5b);
-                    }
+            channel.Open(host, port, this, socket);
 
-                    writeStream.Write(portBuffer, 0, portBuffer.Length);
-                    writeStream.Write(ipBuffer, 0, ipBuffer.Length);
+            SocketAbstraction.SendByte(socket, 0x00);
 
-                    // write buffer to stream
-                    var writeBuffer = writeStream.ToArray();
-                    stream.Write(writeBuffer, 0, writeBuffer.Length);
-                    stream.Flush();
-                }
+            if (channel.IsOpen)
+            {
+                SocketAbstraction.SendByte(socket, 0x5a);
+                SocketAbstraction.Send(socket, portBuffer, 0, portBuffer.Length);
+                SocketAbstraction.Send(socket, ipBuffer, 0, ipBuffer.Length);
+                return true;
             }
+
+            // signal that request was rejected or failed
+            SocketAbstraction.SendByte(socket, 0x5b);
+            return false;
         }
 
-        private void HandleSocks5(Socket socket, IChannelDirectTcpip channel)
+        private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
         {
-            using (var stream = new NetworkStream(socket))
+            var authenticationMethodsCount = SocketAbstraction.ReadByte(socket, timeout);
+            if (authenticationMethodsCount == -1)
             {
-                var authenticationMethodsCount = stream.ReadByte();
+                // SOCKS client closed connection
+                return false;
+            }
 
-                var authenticationMethods = new byte[authenticationMethodsCount];
-                stream.Read(authenticationMethods, 0, authenticationMethods.Length);
+            var authenticationMethods = new byte[authenticationMethodsCount];
+            if (SocketAbstraction.Read(socket, authenticationMethods, 0, authenticationMethods.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                if (authenticationMethods.Min() == 0)
-                {
-                    stream.Write(new byte[] { 0x05, 0x00 }, 0, 2);
-                }
-                else
-                {
-                    stream.Write(new byte[] { 0x05, 0xFF }, 0, 2);
-                }
+            if (authenticationMethods.Min() == 0)
+            {
+                // no user authentication is one of the authentication methods supported
+                // by the SOCKS client
+                SocketAbstraction.Send(socket, new byte[] { 0x05, 0x00 }, 0, 2);
+            }
+            else
+            {
+                // the SOCKS client requires authentication, which we currently do not support
+                SocketAbstraction.Send(socket, new byte[] { 0x05, 0xFF }, 0, 2);
 
-                var version = stream.ReadByte();
+                // we continue business as usual but expect the client to close the connection
+                // so one of the subsequent reads should return -1 signaling that the client
+                // has effectively closed the connection
+            }
 
-                if (version != 5)
-                    throw new ProxyException("SOCKS5: Version 5 is expected.");
+            var version = SocketAbstraction.ReadByte(socket, timeout);
+            if (version == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                var commandCode = stream.ReadByte();
+            if (version != 5)
+                throw new ProxyException("SOCKS5: Version 5 is expected.");
 
-                if (stream.ReadByte() != 0)
-                {
-                    throw new ProxyException("SOCKS5: 0 is expected.");
-                }
+            var commandCode = SocketAbstraction.ReadByte(socket, timeout);
+            if (commandCode == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                var addressType = stream.ReadByte();
+            var reserved = SocketAbstraction.ReadByte(socket, timeout);
+            if (reserved == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                IPAddress ipAddress;
-                byte[] addressBuffer;
-                switch (addressType)
-                {
-                    case 0x01:
-                        {
-                            addressBuffer = new byte[4];
-                            stream.Read(addressBuffer, 0, 4);
+            if (reserved != 0)
+            {
+                throw new ProxyException("SOCKS5: 0 is expected for reserved byte.");
+            }
+
+            var addressType = SocketAbstraction.ReadByte(socket, timeout);
+            if (addressType == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                            ipAddress = new IPAddress(addressBuffer);
+            IPAddress ipAddress;
+            byte[] addressBuffer;
+            switch (addressType)
+            {
+                case 0x01:
+                    {
+                        addressBuffer = new byte[4];
+                        if (SocketAbstraction.Read(socket, addressBuffer, 0, 4, timeout) == 0)
+                        {
+                            // SOCKS client closed connection
+                            return false;
                         }
-                        break;
-                    case 0x03:
+
+                        ipAddress = new IPAddress(addressBuffer);
+                    }
+                    break;
+                case 0x03:
+                    {
+                        var length = SocketAbstraction.ReadByte(socket, timeout);
+                        addressBuffer = new byte[length];
+                        if (SocketAbstraction.Read(socket, addressBuffer, 0, addressBuffer.Length, timeout) == 0)
                         {
-                            var length = stream.ReadByte();
-                            addressBuffer = new byte[length];
-                            stream.Read(addressBuffer, 0, addressBuffer.Length);
+                            // SOCKS client closed connection
+                            return false;
+                        }
 
-                            ipAddress = IPAddress.Parse(SshData.Ascii.GetString(addressBuffer));
+                        ipAddress = IPAddress.Parse(SshData.Ascii.GetString(addressBuffer));
 
-                            //var hostName = new Common.ASCIIEncoding().GetString(addressBuffer);
+                        //var hostName = new Common.ASCIIEncoding().GetString(addressBuffer);
 
-                            //ipAddress = Dns.GetHostEntry(hostName).AddressList[0];
-                        }
-                        break;
-                    case 0x04:
+                        //ipAddress = Dns.GetHostEntry(hostName).AddressList[0];
+                    }
+                    break;
+                case 0x04:
+                    {
+                        addressBuffer = new byte[16];
+                        if (SocketAbstraction.Read(socket, addressBuffer, 0, 16, timeout) == 0)
                         {
-                            addressBuffer = new byte[16];
-                            stream.Read(addressBuffer, 0, 16);
-
-                            ipAddress = new IPAddress(addressBuffer);
+                            // SOCKS client closed connection
+                            return false;
                         }
-                        break;
-                    default:
-                        throw new ProxyException(string.Format("SOCKS5: Address type '{0}' is not supported.", addressType));
-                }
 
-                var portBuffer = new byte[2];
-                stream.Read(portBuffer, 0, portBuffer.Length);
-                var port = (uint)(portBuffer[0] * 256 + portBuffer[1]);
-                var host = ipAddress.ToString();
+                        ipAddress = new IPAddress(addressBuffer);
+                    }
+                    break;
+                default:
+                    throw new ProxyException(string.Format("SOCKS5: Address type '{0}' is not supported.", addressType));
+            }
 
-                RaiseRequestReceived(host, port);
+            var portBuffer = new byte[2];
+            if (SocketAbstraction.Read(socket, portBuffer, 0, portBuffer.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
 
-                channel.Open(host, port, this, socket);
+            var port = (uint)(portBuffer[0] * 256 + portBuffer[1]);
+            var host = ipAddress.ToString();
 
-                using (var writeStream = new MemoryStream())
-                {
-                    writeStream.WriteByte(0x05);
+            RaiseRequestReceived(host, port);
 
-                    if (channel.IsOpen)
-                    {
-                        writeStream.WriteByte(0x00);
-                    }
-                    else
-                    {
-                        writeStream.WriteByte(0x01);
-                    }
+            channel.Open(host, port, this, socket);
 
-                    writeStream.WriteByte(0x00);
+            SocketAbstraction.SendByte(socket, 0x05);
 
-                    if (ipAddress.AddressFamily == AddressFamily.InterNetwork)
-                    {
-                        writeStream.WriteByte(0x01);
-                    }
-                    else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
-                    {
-                        writeStream.WriteByte(0x04);
-                    }
-                    else
-                    {
-                        throw new NotSupportedException("Not supported address family.");
-                    }
+            if (channel.IsOpen)
+            {
+                SocketAbstraction.SendByte(socket, 0x00);
+            }
+            else
+            {
+                SocketAbstraction.SendByte(socket, 0x01);
+            }
 
-                    var addressBytes = ipAddress.GetAddressBytes();
-                    writeStream.Write(addressBytes, 0, addressBytes.Length);
-                    writeStream.Write(portBuffer, 0, portBuffer.Length);
+            SocketAbstraction.SendByte(socket, 0x00);
 
-                    // write buffer to stream
-                    var writeBuffer = writeStream.ToArray();
-                    stream.Write(writeBuffer, 0, writeBuffer.Length);
-                    stream.Flush();
-                }
+            if (ipAddress.AddressFamily == AddressFamily.InterNetwork)
+            {
+                SocketAbstraction.SendByte(socket, 0x01);
+            }
+            else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
+            {
+                SocketAbstraction.SendByte(socket, 0x04);
             }
+            else
+            {
+                throw new NotSupportedException("Not supported address family.");
+            }
+
+            var addressBytes = ipAddress.GetAddressBytes();
+            SocketAbstraction.Send(socket, addressBytes, 0, addressBytes.Length);
+            SocketAbstraction.Send(socket, portBuffer, 0, portBuffer.Length);
+
+            return true;
         }
 
         private void Channel_Exception(object sender, ExceptionEventArgs e)
@@ -471,21 +542,30 @@ namespace Renci.SshNet
             RaiseExceptionEvent(e.Exception);
         }
 
-        private static string ReadString(Stream stream)
+        /// <summary>
+        /// Reads a null terminated string from a socket.
+        /// </summary>
+        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
+        /// <param name="timeout">The timeout to apply to individual reads.</param>
+        /// <returns>
+        /// The <see cref="string"/> read, or <c>null</c> when the socket was closed.
+        /// </returns>
+        private static string ReadString(Socket socket, TimeSpan timeout)
         {
             var text = new StringBuilder();
+            var buffer = new byte[1];
             while (true)
             {
-                var byteRead = stream.ReadByte();
-                if (byteRead == 0)
+                if (SocketAbstraction.Read(socket, buffer, 0, 1, timeout) == 0)
                 {
-                    // end of the string
-                    break;
+                    // SOCKS client closed connection
+                    return null;
                 }
 
-                if (byteRead == -1)
+                var byteRead = buffer[0];
+                if (byteRead == 0)
                 {
-                    // the client shut down the socket
+                    // end of the string
                     break;
                 }
 

+ 8 - 3
src/Renci.SshNet/ForwardedPortLocal.NET.cs

@@ -25,8 +25,11 @@ namespace Renci.SshNet
             var addr = DnsAbstraction.GetHostAddresses(BoundHost)[0];
             var ep = new IPEndPoint(addr, (int) BoundPort);
 
-            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {Blocking = true};
+            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+            // TODO: decide if we want to have blocking socket
+#if FEATURE_SOCKET_SETSOCKETOPTION
             _listener.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.NoDelay, true);
+#endif // FEATURE_SOCKET_SETSOCKETOPTION
             _listener.Bind(ep);
             _listener.Listen(1);
 
@@ -55,9 +58,9 @@ namespace Renci.SshNet
                             asyncResult.AsyncWaitHandle.WaitOne();
                         }
 #elif FEATURE_SOCKET_TAP
-                        #error Accepting new socket connections is not implemented.
+#error Accepting new socket connections is not implemented.
 #else
-                        #error Accepting new socket connections is not implemented.
+#error Accepting new socket connections is not implemented.
 #endif
                     }
                     catch (ObjectDisposedException)
@@ -131,8 +134,10 @@ namespace Renci.SshNet
 
             try
             {
+#if FEATURE_SOCKET_SETSOCKETOPTION
                 clientSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.DontLinger, true);
                 clientSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.NoDelay, true);
+#endif //FEATURE_SOCKET_SETSOCKETOPTION
 
                 var originatorEndPoint = (IPEndPoint) clientSocket.RemoteEndPoint;
 

+ 11 - 3
src/Renci.SshNet/Session.NET.cs

@@ -4,10 +4,12 @@ namespace Renci.SshNet
 {
     public partial class Session
     {
+#if FEATURE_SOCKET_POLL
         /// <summary>
         /// Holds the lock object to ensure read access to the socket is synchronized.
         /// </summary>
         private readonly object _socketReadLock = new object();
+#endif // FEATURE_SOCKET_POLL
 
         /// <summary>
         /// Gets a value indicating whether the socket is connected.
@@ -17,9 +19,12 @@ namespace Renci.SshNet
         /// <para>
         /// As a first check we verify whether <see cref="Socket.Connected"/> is
         /// <c>true</c>. However, this only returns the state of the socket as of
-        /// the last I/O operation. Therefore we use the combination of Socket.Poll
-        /// with mode SelectRead and Socket.Available to verify if the socket is
-        /// still connected.
+        /// the last I/O operation.
+        /// </para>
+#if FEATURE_SOCKET_POLL
+        /// <para>
+        /// Therefore we use the combination of <see cref="Socket.Poll(int, SelectMode)"/> with mode <see cref="SelectMode.SelectRead"/>
+        /// and <see cref="Socket.Available"/> to verify if the socket is still connected.
         /// </para>
         /// <para>
         /// The MSDN doc mention the following on the return value of <see cref="Socket.Poll(int, SelectMode)"/>
@@ -44,9 +49,11 @@ namespace Renci.SshNet
         /// when bytes are read from the <see cref="Socket"/>.
         /// </para>
         /// </remarks>
+#endif
         partial void IsSocketConnected(ref bool isConnected)
         {
             isConnected = (_socket != null && _socket.Connected);
+#if FEATURE_SOCKET_POLL
             if (isConnected)
             {
                 // synchronize this to ensure thread B does not reset the wait handle before
@@ -69,6 +76,7 @@ namespace Renci.SshNet
                     }
                 }
             }
+#endif // FEATURE_SOCKET_POLL
         }
     }
 }