Quellcode durchsuchen

Handle timeout correctly on Socks5 Proxy (#1342)

* Add timeouts when reading from sockets in Socks5Connector

* Add a Socks5 timeout test for a connection reply

---------

Co-authored-by: Wojciech Nagórski <wojtpl2@gmail.com>
Co-authored-by: Rob Hague <rob.hague00@gmail.com>
BoronBGP vor 1 Jahr
Ursprung
Commit
c8b527397a

+ 9 - 9
src/Renci.SshNet/Connection/Socks5Connector.cs

@@ -43,13 +43,13 @@ namespace Renci.SshNet.Connection
                 };
             SocketAbstraction.Send(socket, greeting);
 
-            var socksVersion = SocketReadByte(socket);
+            var socksVersion = SocketReadByte(socket, connectionInfo.Timeout);
             if (socksVersion != 0x05)
             {
                 throw new ProxyException(string.Format("SOCKS Version '{0}' is not supported.", socksVersion));
             }
 
-            var authenticationMethod = SocketReadByte(socket);
+            var authenticationMethod = SocketReadByte(socket, connectionInfo.Timeout);
             switch (authenticationMethod)
             {
                 case 0x00:
@@ -86,13 +86,13 @@ namespace Renci.SshNet.Connection
             SocketAbstraction.Send(socket, connectionRequest);
 
             // Read Server SOCKS5 version
-            if (SocketReadByte(socket) != 5)
+            if (SocketReadByte(socket, connectionInfo.Timeout) != 5)
             {
                 throw new ProxyException("SOCKS5: Version 5 is expected.");
             }
 
             // Read response code
-            var status = SocketReadByte(socket);
+            var status = SocketReadByte(socket, connectionInfo.Timeout);
 
             switch (status)
             {
@@ -119,21 +119,21 @@ namespace Renci.SshNet.Connection
             }
 
             // Read reserved byte
-            if (SocketReadByte(socket) != 0)
+            if (SocketReadByte(socket, connectionInfo.Timeout) != 0)
             {
                 throw new ProxyException("SOCKS5: 0 byte is expected.");
             }
 
-            var addressType = SocketReadByte(socket);
+            var addressType = SocketReadByte(socket, connectionInfo.Timeout);
             switch (addressType)
             {
                 case 0x01:
                     var ipv4 = new byte[4];
-                    _ = SocketRead(socket, ipv4, 0, 4);
+                    _ = SocketRead(socket, ipv4, 0, 4, connectionInfo.Timeout);
                     break;
                 case 0x04:
                     var ipv6 = new byte[16];
-                    _ =SocketRead(socket, ipv6, 0, 16);
+                    _ =SocketRead(socket, ipv6, 0, 16, connectionInfo.Timeout);
                     break;
                 default:
                     throw new ProxyException(string.Format("Address type '{0}' is not supported.", addressType));
@@ -142,7 +142,7 @@ namespace Renci.SshNet.Connection
             var port = new byte[2];
 
             // Read 2 bytes to be ignored
-            _ = SocketRead(socket, port, 0, 2);
+            _ = SocketRead(socket, port, 0, 2, connectionInfo.Timeout);
         }
 
         /// <summary>

+ 116 - 0
test/Renci.SshNet.Tests/Classes/Connection/Socks5ConnectorTest_Connect_TimeoutConnectionReply.cs

@@ -0,0 +1,116 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Moq;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes.Connection
+{
+    [TestClass]
+    public class Socks5ConnectorTest_Connect_TimeoutConnectionReply : Socks5ConnectorTestBase
+    {
+        private ConnectionInfo _connectionInfo;
+        private Exception _actualException;
+        private AsyncSocketListener _proxyServer;
+        private Socket _clientSocket;
+        private List<byte> _bytesReceivedByProxy;
+        private Stopwatch _stopWatch;
+
+        protected override void SetupData()
+        {
+            base.SetupData();
+
+            var random = new Random();
+
+            _connectionInfo = CreateConnectionInfo("proxyUser", "proxyPwd");
+            _connectionInfo.Timeout = TimeSpan.FromMilliseconds(random.Next(50, 200));
+            _stopWatch = new Stopwatch();
+            _bytesReceivedByProxy = new List<byte>();
+
+            _clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+
+            _proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
+            _proxyServer.BytesReceived += (bytesReceived, socket) => {
+                _bytesReceivedByProxy.AddRange(bytesReceived);
+
+                if (_bytesReceivedByProxy.Count == 4) {
+                    _ = socket.Send(new byte[]
+                        {
+                                    // SOCKS version
+                                    0x05,
+                                    // Require no authentication
+                                    0x00
+                        });
+                }
+            };
+            _proxyServer.Start();
+        }
+
+        protected override void SetupMocks()
+        {
+            _ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+                                 .Returns(_clientSocket);
+        }
+
+        protected override void TearDown()
+        {
+            base.TearDown();
+
+            _proxyServer?.Dispose();
+            _clientSocket?.Dispose();
+        }
+
+        protected override void Act()
+        {
+            _stopWatch.Start();
+
+            try
+            {
+                _ = Connector.Connect(_connectionInfo);
+                Assert.Fail();
+            }
+            catch (SocketException ex) {
+                _actualException = ex;
+            }
+            catch (SshOperationTimeoutException ex) {
+                _actualException = ex;
+            }
+            finally
+            {
+                _stopWatch.Stop();
+            }
+        }
+
+        [TestMethod]
+        public void ConnectShouldHaveThrownSshOperationTimeoutException() {
+            Assert.IsNull(_actualException.InnerException);
+            Assert.IsInstanceOfType<SshOperationTimeoutException>(_actualException);
+        }
+
+        [TestMethod]
+        public void ConnectShouldHaveRespectedTimeout()
+        {
+            var errorText = string.Format("Elapsed: {0}, Timeout: {1}",
+                                          _stopWatch.ElapsedMilliseconds,
+                                          _connectionInfo.Timeout.TotalMilliseconds);
+
+            // Compare elapsed time with configured timeout, allowing for a margin of error
+            Assert.IsTrue(_stopWatch.ElapsedMilliseconds >= _connectionInfo.Timeout.TotalMilliseconds - 10, errorText);
+            Assert.IsTrue(_stopWatch.ElapsedMilliseconds < _connectionInfo.Timeout.TotalMilliseconds + 100, errorText);
+        }
+
+        [TestMethod]
+        public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
+        {
+            SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
+                                     Times.Once());
+        }
+    }
+}