فهرست منبع

fix ConnectAsync not respecting the connection timeout (#1502)

mus65 1 سال پیش
والد
کامیت
28e674228f
2فایلهای تغییر یافته به همراه84 افزوده شده و 1 حذف شده
  1. 11 1
      src/Renci.SshNet/BaseClient.cs
  2. 73 0
      test/Renci.SshNet.Tests/Classes/BaseClientTest_ConnectAsync_Timeout.cs

+ 11 - 1
src/Renci.SshNet/BaseClient.cs

@@ -307,7 +307,17 @@ namespace Renci.SshNet
                     DisposeSession(session);
                 }
 
-                Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false);
+                using var timeoutCancellationTokenSource = new CancellationTokenSource(ConnectionInfo.Timeout);
+                using var linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCancellationTokenSource.Token);
+
+                try
+                {
+                    Session = await CreateAndConnectSessionAsync(linkedCancellationTokenSource.Token).ConfigureAwait(false);
+                }
+                catch (OperationCanceledException ex) when (timeoutCancellationTokenSource.IsCancellationRequested)
+                {
+                    throw new SshOperationTimeoutException("Connection has timed out.", ex);
+                }
             }
 
             try

+ 73 - 0
test/Renci.SshNet.Tests/Classes/BaseClientTest_ConnectAsync_Timeout.cs

@@ -0,0 +1,73 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Moq;
+
+#if !NET8_0_OR_GREATER
+using Renci.SshNet.Abstractions;
+#endif
+using Renci.SshNet.Common;
+using Renci.SshNet.Connection;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class BaseClientTest_ConnectAsync_Timeout
+    {
+        private BaseClient _client;
+
+        [TestInitialize]
+        public void Init()
+        {
+            var sessionMock = new Mock<ISession>();
+            var serviceFactoryMock = new Mock<IServiceFactory>();
+            var socketFactoryMock = new Mock<ISocketFactory>();
+
+            sessionMock.Setup(p => p.ConnectAsync(It.IsAny<CancellationToken>()))
+                .Returns<CancellationToken>(c => Task.Delay(Timeout.Infinite, c));
+
+            serviceFactoryMock.Setup(p => p.CreateSocketFactory())
+                                               .Returns(socketFactoryMock.Object);
+
+            var connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"))
+            {
+                Timeout = TimeSpan.FromSeconds(1)
+            };
+
+            serviceFactoryMock.Setup(p => p.CreateSession(connectionInfo, socketFactoryMock.Object))
+                                   .Returns(sessionMock.Object);
+
+            _client = new MyClient(connectionInfo, false, serviceFactoryMock.Object);
+        }
+
+        [TestMethod]
+        public async Task ConnectAsyncWithTimeoutThrowsSshTimeoutException()
+        {
+            await Assert.ThrowsExceptionAsync<SshOperationTimeoutException>(() => _client.ConnectAsync(CancellationToken.None));
+        }
+
+        [TestMethod]
+        public async Task ConnectAsyncWithCancelledTokenThrowsOperationCancelledException()
+        {
+            using var cancellationTokenSource = new CancellationTokenSource();
+            await cancellationTokenSource.CancelAsync();
+            await Assert.ThrowsExceptionAsync<OperationCanceledException>(() => _client.ConnectAsync(cancellationTokenSource.Token));
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            _client?.Dispose();
+        }
+
+        private class MyClient : BaseClient
+        {
+            public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
+            {
+            }
+        }
+    }
+}