2
0
Эх сурвалжийг харах

Ensure keep-alive timer is created when KeepAliveInterval is set after the connection has been established.
Fixes issue #334.

Gert Driesen 8 жил өмнө
parent
commit
66dcb0bfcb

+ 13 - 1
src/Renci.SshNet.Tests.NET35/Renci.SshNet.Tests.NET35.csproj

@@ -78,9 +78,21 @@
     <Compile Include="..\..\test\Renci.SshNet.Shared.Tests\ForwardedPortStatusTest_Stopping.cs">
       <Link>Classes\ForwardedPortStatusTest_Stopping.cs</Link>
     </Compile>
+    <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Connected_KeepAliveInterval_NegativeOne.cs">
+      <Link>Classes\BaseClientTest_Connected_KeepAliveInterval_NegativeOne.cs</Link>
+    </Compile>
+    <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Connected_KeepAliveInterval_NotNegativeOne.cs">
+      <Link>Classes\BaseClientTest_Connected_KeepAliveInterval_NotNegativeOne.cs</Link>
+    </Compile>
     <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs">
       <Link>Classes\BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs</Link>
     </Compile>
+    <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs">
+      <Link>Classes\BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs</Link>
+    </Compile>
+    <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_NotConnected_KeepAliveInterval_NotNegativeOne.cs">
+      <Link>Classes\BaseClientTest_NotConnected_KeepAliveInterval_NotNegativeOne.cs</Link>
+    </Compile>
     <Compile Include="..\Renci.SshNet.Tests\Classes\Channels\ChannelDirectTcpipTest.cs">
       <Link>Classes\Channels\ChannelDirectTcpipTest.cs</Link>
     </Compile>
@@ -1692,7 +1704,7 @@
   <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
   <ProjectExtensions>
     <VisualStudio>
-      <UserProperties ProjectLinkerExcludeFilter="\\?desktop(\\.*)?$;\\?silverlight(\\.*)?$;\.desktop;\.silverlight;\.xaml;^service references(\\.*)?$;\.clientconfig;^web references(\\.*)?$" ProjectLinkReference="c45379b9-17b1-4e89-bc2e-6d41726413e8" />
+      <UserProperties ProjectLinkReference="c45379b9-17b1-4e89-bc2e-6d41726413e8" ProjectLinkerExcludeFilter="\\?desktop(\\.*)?$;\\?silverlight(\\.*)?$;\.desktop;\.silverlight;\.xaml;^service references(\\.*)?$;\.clientconfig;^web references(\\.*)?$" />
     </VisualStudio>
   </ProjectExtensions>
   <!-- To modify your build process, add your task inside one of the targets below and uncomment it. 

+ 121 - 0
src/Renci.SshNet.Tests/Classes/BaseClientTest_Connected_KeepAliveInterval_NegativeOne.cs

@@ -0,0 +1,121 @@
+using System;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class BaseClientTest_Connected_KeepAliveInterval_NegativeOne
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private BaseClient _client;
+        private ConnectionInfo _connectionInfo;
+        private TimeSpan _keepAliveInterval;
+        private int _keepAliveCount;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_client != null)
+            {
+                _sessionMock.Setup(p => p.OnDisconnecting());
+                _sessionMock.Setup(p => p.Dispose());
+                _client.Dispose();
+            }
+        }
+
+        private void SetupData()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _keepAliveInterval = TimeSpan.FromMilliseconds(100d);
+            _keepAliveCount = 0;
+        }
+
+        private void CreateMocks()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+        }
+
+        private void SetupMocks()
+        {
+            _serviceFactoryMock.Setup(p => p.CreateSession(_connectionInfo))
+                               .Returns(_sessionMock.Object);
+            _sessionMock.Setup(p => p.Connect());
+            _sessionMock.Setup(p => p.IsConnected).Returns(true);
+            _sessionMock.Setup(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()))
+                        .Returns(true)
+                        .Callback(() => Interlocked.Increment(ref _keepAliveCount));
+        }
+
+        protected void Arrange()
+        {
+            SetupData();
+            CreateMocks();
+            SetupMocks();
+
+            _client = new MyClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _client.Connect();
+            _client.KeepAliveInterval = _keepAliveInterval;
+        }
+
+        protected void Act()
+        {
+            // allow keep-alive to be sent once
+            Thread.Sleep(150);
+
+            // disable keep-alive
+            _client.KeepAliveInterval = TimeSpan.FromMilliseconds(-1);
+        }
+
+        [TestMethod]
+        public void KeepAliveIntervalShouldReturnConfiguredValue()
+        {
+            Assert.AreEqual(TimeSpan.FromMilliseconds(-1), _client.KeepAliveInterval);
+        }
+
+        [TestMethod]
+        public void CreateSessionOnServiceFactoryShouldBeInvokedOnce()
+        {
+            _serviceFactoryMock.Verify(p => p.CreateSession(_connectionInfo), Times.Once);
+        }
+
+        [TestMethod]
+        public void ConnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Connect(), Times.Once);
+        }
+
+        [TestMethod]
+        public void IsConnectedOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.IsConnected, Times.Once);
+        }
+
+        [TestMethod]
+        public void SendMessageOnSessionShouldBeInvokedThreeTimes()
+        {
+            // allow keep-alive to be sent once
+            Thread.Sleep(100);
+
+            _sessionMock.Verify(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()), Times.Exactly(1));
+        }
+
+        private class MyClient : BaseClient
+        {
+            public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
+            {
+            }
+        }
+    }
+}

+ 116 - 0
src/Renci.SshNet.Tests/Classes/BaseClientTest_Connected_KeepAliveInterval_NotNegativeOne.cs

@@ -0,0 +1,116 @@
+using System;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class BaseClientTest_Connected_KeepAliveInterval_NotNegativeOne
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private BaseClient _client;
+        private ConnectionInfo _connectionInfo;
+        private TimeSpan _keepAliveInterval;
+        private int _keepAliveCount;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_client != null)
+            {
+                _sessionMock.Setup(p => p.OnDisconnecting());
+                _sessionMock.Setup(p => p.Dispose());
+                _client.Dispose();
+            }
+        }
+
+        private void SetupData()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _keepAliveInterval = TimeSpan.FromMilliseconds(50d);
+            _keepAliveCount = 0;
+        }
+
+        private void CreateMocks()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+        }
+
+        private void SetupMocks()
+        {
+            _serviceFactoryMock.Setup(p => p.CreateSession(_connectionInfo))
+                               .Returns(_sessionMock.Object);
+            _sessionMock.Setup(p => p.Connect());
+            _sessionMock.Setup(p => p.IsConnected).Returns(true);
+            _sessionMock.Setup(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()))
+                        .Returns(true)
+                        .Callback(() => Interlocked.Increment(ref _keepAliveCount));
+        }
+
+        protected void Arrange()
+        {
+            SetupData();
+            CreateMocks();
+            SetupMocks();
+
+            _client = new MyClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _client.Connect();
+        }
+
+        protected void Act()
+        {
+            _client.KeepAliveInterval = _keepAliveInterval;
+
+            // allow keep-alive to be sent a few times
+            Thread.Sleep(195);
+        }
+
+        [TestMethod]
+        public void KeepAliveIntervalShouldReturnConfiguredValue()
+        {
+            Assert.AreEqual(_keepAliveInterval, _client.KeepAliveInterval);
+        }
+
+        [TestMethod]
+        public void CreateSessionOnServiceFactoryShouldBeInvokedOnce()
+        {
+            _serviceFactoryMock.Verify(p => p.CreateSession(_connectionInfo), Times.Once);
+        }
+
+        [TestMethod]
+        public void ConnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Connect(), Times.Once);
+        }
+
+        [TestMethod]
+        public void IsConnectedOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.IsConnected, Times.Once);
+        }
+
+        [TestMethod]
+        public void SendMessageOnSessionShouldBeInvokedThreeTimes()
+        {
+            _sessionMock.Verify(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()), Times.Exactly(3));
+        }
+
+        private class MyClient : BaseClient
+        {
+            public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
+            {
+            }
+        }
+    }
+}

+ 20 - 7
src/Renci.SshNet.Tests/Classes/BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs

@@ -34,25 +34,38 @@ namespace Renci.SshNet.Tests.Classes
             }
         }
 
-        protected void Arrange()
+        private void SetupData()
         {
             _connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"));
             _keepAliveSent = new ManualResetEvent(false);
+        }
 
+        private void CreateMocks()
+        {
             _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
             _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+        }
 
+        private void SetupMocks()
+        {
             _mockSequence = new MockSequence();
 
             _serviceFactoryMock.InSequence(_mockSequence).Setup(p => p.CreateSession(_connectionInfo)).Returns(_sessionMock.Object);
             _sessionMock.InSequence(_mockSequence).Setup(p => p.Connect());
             _sessionMock.InSequence(_mockSequence).Setup(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()))
-                .Returns(true)
-                .Callback(() =>
-                    {
-                        Thread.Sleep(300);
-                        _keepAliveSent.Set();
-                    });
+                        .Returns(true)
+                        .Callback(() =>
+                        {
+                            Thread.Sleep(300);
+                            _keepAliveSent.Set();
+                        });
+        }
+
+        protected void Arrange()
+        {
+            SetupData();
+            CreateMocks();
+            SetupMocks();
 
             _client = new MyClient(_connectionInfo, false, _serviceFactoryMock.Object)
                 {

+ 115 - 0
src/Renci.SshNet.Tests/Classes/BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs

@@ -0,0 +1,115 @@
+using System;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private BaseClient _client;
+        private ConnectionInfo _connectionInfo;
+        private TimeSpan _keepAliveInterval;
+        private int _keepAliveCount;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_client != null)
+            {
+                _sessionMock.Setup(p => p.OnDisconnecting());
+                _sessionMock.Setup(p => p.Dispose());
+                _client.Dispose();
+            }
+        }
+
+        private void SetupData()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _keepAliveInterval = TimeSpan.FromMilliseconds(50d);
+            _keepAliveCount = 0;
+        }
+
+        private void CreateMocks()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+        }
+
+        private void SetupMocks()
+        {
+            _serviceFactoryMock.Setup(p => p.CreateSession(_connectionInfo))
+                               .Returns(_sessionMock.Object);
+            _sessionMock.Setup(p => p.Connect());
+            _sessionMock.Setup(p => p.IsConnected).Returns(false);
+            _sessionMock.Setup(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()))
+                        .Returns(true);
+        }
+
+        protected void Arrange()
+        {
+            SetupData();
+            CreateMocks();
+            SetupMocks();
+
+            _client = new MyClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _client.Connect();
+        }
+
+        protected void Act()
+        {
+            _client.KeepAliveInterval = _keepAliveInterval;
+
+            // allow keep-alive to be sent a few times
+            Thread.Sleep(195);
+        }
+
+        [TestMethod]
+        public void KeepAliveIntervalShouldReturnConfiguredValue()
+        {
+            Assert.AreEqual(_keepAliveInterval, _client.KeepAliveInterval);
+        }
+
+        [TestMethod]
+        public void CreateSessionOnServiceFactoryShouldBeInvokedOnce()
+        {
+            _serviceFactoryMock.Verify(p => p.CreateSession(_connectionInfo), Times.Once);
+        }
+
+        [TestMethod]
+        public void ConnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Connect(), Times.Once);
+        }
+
+        [TestMethod]
+        public void IsConnectedOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.IsConnected, Times.Once);
+        }
+
+        [TestMethod]
+        public void SendMessageOnSessionShouldNeverBeInvoked()
+        {
+            _sessionMock.Verify(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()), Times.Never);
+        }
+
+        private class MyClient : BaseClient
+        {
+            public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
+            {
+            }
+        }
+    }
+}

+ 102 - 0
src/Renci.SshNet.Tests/Classes/BaseClientTest_NotConnected_KeepAliveInterval_NotNegativeOne.cs

@@ -0,0 +1,102 @@
+using System;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class BaseClientTest_NotConnected_KeepAliveInterval_NotNegativeOne
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private BaseClient _client;
+        private ConnectionInfo _connectionInfo;
+        private TimeSpan _keepAliveInterval;
+        private int _keepAliveCount;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_client != null)
+            {
+                _sessionMock.Setup(p => p.OnDisconnecting());
+                _sessionMock.Setup(p => p.Dispose());
+                _client.Dispose();
+            }
+        }
+
+        private void SetupData()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _keepAliveInterval = TimeSpan.FromMilliseconds(100d);
+            _keepAliveCount = 0;
+        }
+
+        private void CreateMocks()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+        }
+
+        private static void SetupMocks()
+        {
+        }
+
+        protected void Arrange()
+        {
+            SetupData();
+            CreateMocks();
+            SetupMocks();
+
+            _client = new MyClient(_connectionInfo, false, _serviceFactoryMock.Object);
+        }
+
+        protected void Act()
+        {
+            _client.KeepAliveInterval = _keepAliveInterval;
+
+            // allow keep-alive to be sent at least once
+            Thread.Sleep(150);
+        }
+
+        [TestMethod]
+        public void KeepAliveIntervalShouldReturnConfiguredValue()
+        {
+            Assert.AreEqual(_keepAliveInterval, _client.KeepAliveInterval);
+        }
+
+        [TestMethod]
+        public void ConnectShouldActivateKeepAliveIfSessionIs()
+        {
+            _serviceFactoryMock.Setup(p => p.CreateSession(_connectionInfo)).Returns(_sessionMock.Object);
+            _sessionMock.Setup(p => p.Connect());
+            _sessionMock.Setup(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()))
+                        .Returns(true)
+                        .Callback(() => Interlocked.Increment(ref _keepAliveCount));
+
+            _client.Connect();
+
+            // allow keep-alive to be sent twice
+            Thread.Sleep(250);
+
+            // Exactly two keep-alives should be sent
+            _sessionMock.Verify(p => p.TrySendMessage(It.IsAny<IgnoreMessage>()), Times.Exactly(2));
+        }
+
+        private class MyClient : BaseClient
+        {
+            public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
+            {
+            }
+        }
+    }
+}

+ 4 - 0
src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -88,7 +88,11 @@
     <Compile Include="..\..\test\Renci.SshNet.Shared.Tests\SshMessageFactoryTest.cs">
       <Link>Classes\SshMessageFactoryTest.cs</Link>
     </Compile>
+    <Compile Include="Classes\BaseClientTest_Connected_KeepAliveInterval_NotNegativeOne.cs" />
+    <Compile Include="Classes\BaseClientTest_Connected_KeepAliveInterval_NegativeOne.cs" />
     <Compile Include="Classes\BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs" />
+    <Compile Include="Classes\BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs" />
+    <Compile Include="Classes\BaseClientTest_NotConnected_KeepAliveInterval_NotNegativeOne.cs" />
     <Compile Include="Classes\Channels\ChannelDirectTcpipTest.cs" />
     <Compile Include="Classes\Channels\ChannelDirectTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs" />
     <Compile Include="Classes\Channels\ChannelSessionTest_Dispose_Disposed.cs" />

+ 31 - 6
src/Renci.SshNet/BaseClient.cs

@@ -107,13 +107,25 @@ namespace Renci.SshNet
                 }
                 else
                 {
-                    // change the due time and interval of the timer if has already
-                    // been created (which means the client is connected)
-                    // 
-                    // if the client is not yet connected, then the timer will be
-                    // created with the new interval when Connect() is invoked
                     if (_keepAliveTimer != null)
+                    {
+                        // change the due time and interval of the timer if has already
+                        // been created (which means the client is connected)
+
                         _keepAliveTimer.Change(value, value);
+                    }
+                    else if (IsConnected)
+                    {
+                        // if timer has not yet been created and the client is already connected,
+                        // then we need to create the timer now
+                        //
+                        // this means that - before connecting - the keep-alive interval was set to
+                        // negative one (-1) and as such we did not create the timer
+                        _keepAliveTimer = CreateKeepAliveTimer(value, value);
+                    }
+
+                    // note that if the client is not yet connected, then the timer will be created with the 
+                    // new interval when Connect() is invoked
                 }
                 _keepAliveInterval = value;
             }
@@ -420,7 +432,20 @@ namespace Renci.SshNet
                 // timer is already started
                 return;
 
-            _keepAliveTimer = new Timer(state => SendKeepAliveMessage(), null, _keepAliveInterval, _keepAliveInterval);
+            _keepAliveTimer = CreateKeepAliveTimer(_keepAliveInterval, _keepAliveInterval);
+        }
+
+        /// <summary>
+        /// Creates a <see cref="Timer"/> with the specified due time and interval.
+        /// </summary>
+        /// <param name="dueTime">The amount of time to delay before the keep-alive message is first sent. Specify negative one (-1) milliseconds to prevent the timer from starting. Specify zero (0) to start the timer immediately.</param>
+        /// <param name="period">The time interval between attempts to send a keep-alive message. Specify negative one (-1) milliseconds to disable periodic signaling.</param>
+        /// <returns>
+        /// A <see cref="Timer"/> with the specified due time and interval.
+        /// </returns>
+        private Timer CreateKeepAliveTimer(TimeSpan dueTime, TimeSpan period)
+        {
+            return new Timer(state => SendKeepAliveMessage(), null, dueTime, period);
         }
     }
 }

+ 1 - 1
src/Renci.SshNet/Session.cs

@@ -34,7 +34,7 @@ namespace Renci.SshNet
         /// Specifies an infinite waiting period.
         /// </summary>
         /// <remarks>
-        /// The value of this field is <c>-1</c> millisecond. 
+        /// The value of this field is <c>-1</c> millisecond.
         /// </remarks>
         internal static readonly TimeSpan InfiniteTimeSpan = new TimeSpan(0, 0, 0, 0, -1);