瀏覽代碼

Only keep reference to session if the connection was established successfully.
Fixes issue #338.

Gert Driesen 8 年之前
父節點
當前提交
8141113259

+ 4 - 1
src/Renci.SshNet.Tests.NET35/Renci.SshNet.Tests.NET35.csproj

@@ -87,6 +87,9 @@
     <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs">
       <Link>Classes\BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs</Link>
     </Compile>
+    <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Connect_OnConnectedThrowsException.cs">
+      <Link>Classes\BaseClientTest_Connect_OnConnectedThrowsException.cs</Link>
+    </Compile>
     <Compile Include="..\Renci.SshNet.Tests\Classes\BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs">
       <Link>Classes\BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs</Link>
     </Compile>
@@ -1704,7 +1707,7 @@
   <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
   <ProjectExtensions>
     <VisualStudio>
-      <UserProperties ProjectLinkReference="c45379b9-17b1-4e89-bc2e-6d41726413e8" ProjectLinkerExcludeFilter="\\?desktop(\\.*)?$;\\?silverlight(\\.*)?$;\.desktop;\.silverlight;\.xaml;^service references(\\.*)?$;\.clientconfig;^web references(\\.*)?$" />
+      <UserProperties ProjectLinkerExcludeFilter="\\?desktop(\\.*)?$;\\?silverlight(\\.*)?$;\.desktop;\.silverlight;\.xaml;^service references(\\.*)?$;\.clientconfig;^web references(\\.*)?$" ProjectLinkReference="c45379b9-17b1-4e89-bc2e-6d41726413e8" />
     </VisualStudio>
   </ProjectExtensions>
   <!-- To modify your build process, add your task inside one of the targets below and uncomment it. 

+ 181 - 0
src/Renci.SshNet.Tests/Classes/BaseClientTest_Connect_OnConnectedThrowsException.cs

@@ -0,0 +1,181 @@
+using System;
+using System.Reflection;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Security;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class BaseClientTest_Connect_OnConnectedThrowsException
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private MyClient _client;
+        private ConnectionInfo _connectionInfo;
+        private ApplicationException _onConnectException;
+        private ApplicationException _actualException;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_client != null)
+            {
+                _sessionMock.Setup(p => p.OnDisconnecting());
+                _sessionMock.Setup(p => p.Dispose());
+                _client.Dispose();
+            }
+        }
+
+        private void SetupData()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _onConnectException = new ApplicationException();
+        }
+
+        private void CreateMocks()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+        }
+
+        private void SetupMocks()
+        {
+            _serviceFactoryMock.Setup(p => p.CreateSession(_connectionInfo))
+                               .Returns(_sessionMock.Object);
+            _sessionMock.Setup(p => p.Connect());
+            _sessionMock.Setup(p => p.Dispose());
+        }
+
+        protected void Arrange()
+        {
+            SetupData();
+            CreateMocks();
+            SetupMocks();
+
+            _client = new MyClient(_connectionInfo, false, _serviceFactoryMock.Object)
+                {
+                    OnConnectedException = _onConnectException
+                };
+        }
+
+        protected void Act()
+        {
+            try
+            {
+                _client.Connect();
+                Assert.Fail();
+            }
+            catch (ApplicationException ex)
+            {
+                _actualException = ex;
+            }
+        }
+
+        [TestMethod]
+        public void ConnectShouldRethrowExceptionThrownByOnConnect()
+        {
+            Assert.IsNotNull(_actualException);
+            Assert.AreSame(_onConnectException, _actualException);
+        }
+
+        [TestMethod]
+        public void CreateSessionOnServiceFactoryShouldBeInvokedOnce()
+        {
+            _serviceFactoryMock.Verify(p => p.CreateSession(_connectionInfo), Times.Once);
+        }
+
+        [TestMethod]
+        public void ConnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Connect(), Times.Once);
+        }
+
+        [TestMethod]
+        public void DisposeOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Dispose(), Times.Once);
+        }
+
+        [TestMethod]
+        public void ErrorOccuredOnSessionShouldNoLongerBeSignaledViaErrorOccurredOnBaseClient()
+        {
+            var errorOccurredSignalCount = 0;
+
+            _client.ErrorOccurred += (sender, args) => Interlocked.Increment(ref errorOccurredSignalCount);
+
+            _sessionMock.Raise(p => p.ErrorOccured += null, new ExceptionEventArgs(new Exception()));
+
+            Assert.AreEqual(0, errorOccurredSignalCount);
+        }
+
+        [TestMethod]
+        public void HostKeyReceivedOnSessionShouldNoLongerBeSignaledViaHostKeyReceivedOnBaseClient()
+        {
+            var hostKeyReceivedSignalCount = 0;
+
+            _client.HostKeyReceived += (sender, args) => Interlocked.Increment(ref hostKeyReceivedSignalCount);
+
+            _sessionMock.Raise(p => p.HostKeyReceived += null, new HostKeyEventArgs(GetKeyHostAlgorithm()));
+
+            Assert.AreEqual(0, hostKeyReceivedSignalCount);
+        }
+
+        [TestMethod]
+        public void SessionShouldBeNull()
+        {
+            Assert.IsNull(_client.Session);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(_client.IsConnected);
+        }
+
+        private static KeyHostAlgorithm GetKeyHostAlgorithm()
+        {
+            var executingAssembly = Assembly.GetExecutingAssembly();
+
+            using (var s = executingAssembly.GetManifestResourceStream(string.Format("Renci.SshNet.Tests.Data.{0}", "Key.RSA.txt")))
+            {
+                var privateKey = new PrivateKeyFile(s);
+                return (KeyHostAlgorithm) privateKey.HostKey;
+            }
+        }
+
+        private class MyClient : BaseClient
+        {
+            private int _onConnectedCount;
+
+            public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
+            {
+            }
+
+            public Exception OnConnectedException { get; set; }
+
+            protected override void OnConnected()
+            {
+                base.OnConnected();
+
+                Interlocked.Increment(ref _onConnectedCount);
+
+                if (OnConnectedException != null)
+                {
+                    throw OnConnectedException;
+                }
+            }
+        }
+
+
+    }
+}

+ 3 - 5
src/Renci.SshNet.Tests/Classes/SubsystemSession_Connect_Connected.cs

@@ -1,7 +1,6 @@
 using System;
 using System.Collections.Generic;
 using System.Globalization;
-using System.Text;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Moq;
 using Renci.SshNet.Channels;
@@ -46,10 +45,9 @@ namespace Renci.SshNet.Tests.Classes
             _channelMock.InSequence(_sequence).Setup(p => p.SendSubsystemRequest(_subsystemName)).Returns(true);
             _channelMock.InSequence(_sequence).Setup(p => p.IsOpen).Returns(true);
 
-            _subsystemSession = new SubsystemSessionStub(
-                _sessionMock.Object,
-                _subsystemName,
-                _operationTimeout);
+            _subsystemSession = new SubsystemSessionStub(_sessionMock.Object,
+                                                         _subsystemName,
+                                                         _operationTimeout);
             _subsystemSession.Disconnected += (sender, args) => _disconnectedRegister.Add(args);
             _subsystemSession.ErrorOccurred += (sender, args) => _errorOccurredRegister.Add(args);
             _subsystemSession.Connect();

+ 1 - 0
src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -91,6 +91,7 @@
     <Compile Include="Classes\BaseClientTest_Connected_KeepAliveInterval_NotNegativeOne.cs" />
     <Compile Include="Classes\BaseClientTest_Connected_KeepAliveInterval_NegativeOne.cs" />
     <Compile Include="Classes\BaseClientTest_Connected_KeepAlivesNotSentConcurrently.cs" />
+    <Compile Include="Classes\BaseClientTest_Connect_OnConnectedThrowsException.cs" />
     <Compile Include="Classes\BaseClientTest_Disconnected_KeepAliveInterval_NotNegativeOne.cs" />
     <Compile Include="Classes\BaseClientTest_NotConnected_KeepAliveInterval_NotNegativeOne.cs" />
     <Compile Include="Classes\Channels\ChannelDirectTcpipTest.cs" />

+ 82 - 25
src/Renci.SshNet/BaseClient.cs

@@ -74,7 +74,8 @@ namespace Renci.SshNet
             get
             {
                 CheckDisposed();
-                return Session != null && Session.IsConnected;
+
+                return IsSessionConnected();
             }
         }
 
@@ -114,7 +115,7 @@ namespace Renci.SshNet
 
                         _keepAliveTimer.Change(value, value);
                     }
-                    else if (IsConnected)
+                    else if (IsSessionConnected())
                     {
                         // if timer has not yet been created and the client is already connected,
                         // then we need to create the timer now
@@ -216,16 +217,26 @@ namespace Renci.SshNet
             // forwarded port with a client instead of with a session
             //
             // To be discussed with Oleg (or whoever is interested)
-            if (Session != null && Session.IsConnected)
+            if (IsSessionConnected())
                 throw new InvalidOperationException("The client is already connected.");
 
             OnConnecting();
-            Session = _serviceFactory.CreateSession(ConnectionInfo);
-            Session.HostKeyReceived += Session_HostKeyReceived;
-            Session.ErrorOccured += Session_ErrorOccured;
-            Session.Connect();
+
+            Session = CreateAndConnectSession();
+            try
+            {
+                // Even though the method we invoke makes you believe otherwise, at this point only
+                // the SSH session itself is connected.
+                OnConnected();
+            }
+            catch
+            {
+                // Only dispose the session as Disconnect() would have side-effects (such as remove forwarded
+                // ports in SshClient).
+                DisposeSession();
+                throw;
+            }
             StartKeepAliveTimer();
-            OnConnected();
         }
 
         /// <summary>
@@ -240,20 +251,11 @@ namespace Renci.SshNet
 
             OnDisconnecting();
 
-            // stop sending keep-alive messages before we close the
-            // session
+            // stop sending keep-alive messages before we close the session
             StopKeepAliveTimer();
 
-            // disconnect and dispose the SSH session
-            if (Session != null)
-            {
-                // a new session is created in Connect(), so we should dispose and
-                // dereference the current session here
-                Session.ErrorOccured -= Session_ErrorOccured;
-                Session.HostKeyReceived -= Session_HostKeyReceived;
-                Session.Dispose();
-                Session = null;
-            }
+            // dispose the SSH session
+            DisposeSession();
 
             OnDisconnected();
         }
@@ -293,8 +295,11 @@ namespace Renci.SshNet
         /// </summary>
         protected virtual void OnDisconnecting()
         {
-            if (Session != null)
-                Session.OnDisconnecting();
+            var session = Session;
+            if (session != null)
+            {
+                session.OnDisconnecting();
+            }
         }
 
         /// <summary>
@@ -398,8 +403,10 @@ namespace Renci.SshNet
 
         private void SendKeepAliveMessage()
         {
+            var session = Session;
+
             // do nothing if we have disposed or disconnected
-            if (Session == null)
+            if (session == null)
                 return;
 
             // do not send multiple keep-alive messages concurrently
@@ -407,7 +414,7 @@ namespace Renci.SshNet
             {
                 try
                 {
-                    Session.TrySendMessage(new IgnoreMessage());
+                    session.TrySendMessage(new IgnoreMessage());
                 }
                 finally
                 {
@@ -445,7 +452,57 @@ namespace Renci.SshNet
         /// </returns>
         private Timer CreateKeepAliveTimer(TimeSpan dueTime, TimeSpan period)
         {
-            return new Timer(state => SendKeepAliveMessage(), null, dueTime, period);
+            return new Timer(state => SendKeepAliveMessage(), Session, dueTime, period);
+        }
+
+        private ISession CreateAndConnectSession()
+        {
+            var session = _serviceFactory.CreateSession(ConnectionInfo);
+            session.HostKeyReceived += Session_HostKeyReceived;
+            session.ErrorOccured += Session_ErrorOccured;
+
+            try
+            {
+                session.Connect();
+                return session;
+            }
+            catch
+            {
+                DisposeSession(session);
+                throw;
+            }
+        }
+
+        private void DisposeSession(ISession session)
+        {
+            session.ErrorOccured -= Session_ErrorOccured;
+            session.HostKeyReceived -= Session_HostKeyReceived;
+            session.Dispose();
+        }
+
+        /// <summary>
+        /// Disposes the SSH session, and assigns <c>null</c> to <see cref="Session"/>.
+        /// </summary>
+        private void DisposeSession()
+        {
+            var session = Session;
+            if (session != null)
+            {
+                Session = null;
+                DisposeSession(session);
+            }
+        }
+
+        /// <summary>
+        /// Returns a value indicating whether the SSH session is established.
+        /// </summary>
+        /// <returns>
+        /// <c>true</c> if the SSH session is established; otherwise, <c>false</c>.
+        /// </returns>
+        private bool IsSessionConnected()
+        {
+            var session = Session;
+            return session != null && session.IsConnected;
         }
     }
 }

+ 16 - 2
src/Renci.SshNet/NetConfClient.cs

@@ -219,8 +219,7 @@ namespace Renci.SshNet
         {
             base.OnConnected();
 
-            _netConfSession = ServiceFactory.CreateNetConfSession(Session, _operationTimeout);
-            _netConfSession.Connect();
+            _netConfSession = CreateAndConnectNetConfSession();
         }
 
         /// <summary>
@@ -250,5 +249,20 @@ namespace Renci.SshNet
                 }
             }
         }
+
+        private INetConfSession CreateAndConnectNetConfSession()
+        {
+            var netConfSession = ServiceFactory.CreateNetConfSession(Session, _operationTimeout);
+            try
+            {
+                netConfSession.Connect();
+                return netConfSession;
+            }
+            catch
+            {
+                netConfSession.Dispose();
+                throw;
+            }
+        }
     }
 }

+ 19 - 5
src/Renci.SshNet/SftpClient.cs

@@ -2151,11 +2151,7 @@ namespace Renci.SshNet
         {
             base.OnConnected();
 
-            _sftpSession = ServiceFactory.CreateSftpSession(Session,
-                                                            _operationTimeout,
-                                                            ConnectionInfo.Encoding,
-                                                            ServiceFactory.CreateSftpResponseFactory());
-            _sftpSession.Connect();
+            _sftpSession = CreateAndConnectToSftpSession();
         }
 
         /// <summary>
@@ -2193,5 +2189,23 @@ namespace Renci.SshNet
                 }
             }
         }
+
+        private ISftpSession CreateAndConnectToSftpSession()
+        {
+            var sftpSession = ServiceFactory.CreateSftpSession(Session,
+                                                               _operationTimeout,
+                                                               ConnectionInfo.Encoding,
+                                                               ServiceFactory.CreateSftpResponseFactory());
+            try
+            {
+                sftpSession.Connect();
+                return sftpSession;
+            }
+            catch
+            {
+                sftpSession.Dispose();
+                throw;
+            }
+        }
     }
 }