ソースを参照

Added simple ServiceFactory to allow unit testing of SshClient.
When disposing SshClient, make sure to stop forwarded ports before the session is closed.

Gert Driesen 11 年 前
コミット
9366658bc4
25 ファイル変更474 行追加100 行削除
  1. 2 2
      Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTestBase.cs
  2. 23 3
      Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed.cs
  3. 60 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted.cs
  4. 0 4
      Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortRemoteTest_Dispose_PortDisposed.cs
  5. 75 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Disconnect_ForwardedPortStarted.cs
  6. 97 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Dispose_ForwardedPortStarted.cs
  7. 3 0
      Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj
  8. 27 5
      Renci.SshClient/Renci.SshNet/BaseClient.cs
  9. 1 1
      Renci.SshClient/Renci.SshNet/ClientAuthentication.cs
  10. 4 4
      Renci.SshClient/Renci.SshNet/ConnectionInfo.cs
  11. 2 5
      Renci.SshClient/Renci.SshNet/ForwardedPort.cs
  12. 6 8
      Renci.SshClient/Renci.SshNet/ForwardedPortDynamic.NET.cs
  13. 31 22
      Renci.SshClient/Renci.SshNet/IConnectionInfo.cs
  14. 1 1
      Renci.SshClient/Renci.SshNet/IForwardedPort.cs
  15. 7 0
      Renci.SshClient/Renci.SshNet/IServiceFactory.cs
  16. 39 6
      Renci.SshClient/Renci.SshNet/ISession.cs
  17. 1 1
      Renci.SshClient/Renci.SshNet/Netconf/NetConfSession.cs
  18. 2 0
      Renci.SshClient/Renci.SshNet/Renci.SshNet.csproj
  19. 7 7
      Renci.SshClient/Renci.SshNet/ScpClient.NET.cs
  20. 7 7
      Renci.SshClient/Renci.SshNet/ScpClient.cs
  21. 10 0
      Renci.SshClient/Renci.SshNet/ServiceFactory.cs
  22. 3 3
      Renci.SshClient/Renci.SshNet/Session.cs
  23. 1 1
      Renci.SshClient/Renci.SshNet/Sftp/SftpSession.cs
  24. 3 3
      Renci.SshClient/Renci.SshNet/Shell.cs
  25. 62 17
      Renci.SshClient/Renci.SshNet/SshClient.cs

+ 2 - 2
Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTestBase.cs

@@ -10,7 +10,7 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public abstract class ClientAuthenticationTestBase : TestBase
     {
-        internal Mock<IConnectionInfo> ConnectionInfoMock { get; private set; }
+        internal Mock<IConnectionInfoInternal> ConnectionInfoMock { get; private set; }
         internal Mock<ISession> SessionMock { get; private set; }
         internal Mock<IAuthenticationMethod> NoneAuthenticationMethodMock { get; private set; }
         internal Mock<IAuthenticationMethod> PasswordAuthenticationMethodMock { get; private set; }
@@ -20,7 +20,7 @@ namespace Renci.SshNet.Tests.Classes
 
         protected void CreateMocks()
         {
-            ConnectionInfoMock = new Mock<IConnectionInfo>(MockBehavior.Strict);
+            ConnectionInfoMock = new Mock<IConnectionInfoInternal>(MockBehavior.Strict);
             SessionMock = new Mock<ISession>(MockBehavior.Strict);
             NoneAuthenticationMethodMock = new Mock<IAuthenticationMethod>(MockBehavior.Strict);
             PasswordAuthenticationMethodMock = new Mock<IAuthenticationMethod>(MockBehavior.Strict);

+ 23 - 3
Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed.cs

@@ -1,6 +1,8 @@
 using System;
 using System.Collections.Generic;
+using System.Net;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
 using Renci.SshNet.Common;
 
 namespace Renci.SshNet.Tests.Classes
@@ -8,6 +10,8 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public class ForwardedPortLocalTest_Dispose_PortDisposed
     {
+        private Mock<ISession> _sessionMock;
+        private Mock<IConnectionInfo> _connectionInfoMock;
         private ForwardedPortLocal _forwardedPort;
         private IList<EventArgs> _closingRegister;
         private IList<ExceptionEventArgs> _exceptionRegister;
@@ -34,9 +38,19 @@ namespace Renci.SshNet.Tests.Classes
             _closingRegister = new List<EventArgs>();
             _exceptionRegister = new List<ExceptionEventArgs>();
 
-            _forwardedPort = new ForwardedPortLocal("boundHost", "host", 22);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _connectionInfoMock = new Mock<IConnectionInfo>(MockBehavior.Strict);
+
+            var sequence = new MockSequence();
+            _sessionMock.InSequence(sequence).Setup(p => p.IsConnected).Returns(true);
+            _sessionMock.InSequence(sequence).Setup(p => p.ConnectionInfo).Returns(_connectionInfoMock.Object);
+            _connectionInfoMock.InSequence(sequence).Setup(p => p.Timeout).Returns(TimeSpan.FromSeconds(30));
+
+            _forwardedPort = new ForwardedPortLocal(IPAddress.Loopback.ToString(), "host", 22);
             _forwardedPort.Closing += (sender, args) => _closingRegister.Add(args);
             _forwardedPort.Exception += (sender, args) => _exceptionRegister.Add(args);
+            _forwardedPort.Session = _sessionMock.Object;
+            _forwardedPort.Start();
             _forwardedPort.Dispose();
         }
 
@@ -46,9 +60,9 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void ClosingShouldNotHaveFired()
+        public void ClosingShouldHaveFiredOnce()
         {
-            Assert.AreEqual(0, _closingRegister.Count);
+            Assert.AreEqual(1, _closingRegister.Count);
         }
 
         [TestMethod]
@@ -56,5 +70,11 @@ namespace Renci.SshNet.Tests.Classes
         {
             Assert.AreEqual(0, _exceptionRegister.Count);
         }
+
+        [TestMethod]
+        public void SessionShouldBeNull()
+        {
+            Assert.IsNull(_forwardedPort.Session);
+        }
     }
 }

+ 60 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted.cs

@@ -0,0 +1,60 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted
+    {
+        private ForwardedPortLocal _forwardedPort;
+        private IList<EventArgs> _closingRegister;
+        private IList<ExceptionEventArgs> _exceptionRegister;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_forwardedPort != null)
+            {
+                _forwardedPort.Dispose();
+                _forwardedPort = null;
+            }
+        }
+
+        protected void Arrange()
+        {
+            _closingRegister = new List<EventArgs>();
+            _exceptionRegister = new List<ExceptionEventArgs>();
+
+            _forwardedPort = new ForwardedPortLocal("boundHost", "host", 22);
+            _forwardedPort.Closing += (sender, args) => _closingRegister.Add(args);
+            _forwardedPort.Exception += (sender, args) => _exceptionRegister.Add(args);
+            _forwardedPort.Dispose();
+        }
+
+        protected void Act()
+        {
+            _forwardedPort.Dispose();
+        }
+
+        [TestMethod]
+        public void ClosingShouldNotHaveFired()
+        {
+            Assert.AreEqual(0, _closingRegister.Count);
+        }
+
+        [TestMethod]
+        public void ExceptionShouldNotHaveFired()
+        {
+            Assert.AreEqual(0, _exceptionRegister.Count);
+        }
+    }
+}

+ 0 - 4
Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortRemoteTest_Dispose_PortDisposed.cs

@@ -1,12 +1,8 @@
 using System;
 using System.Collections.Generic;
-using System.Globalization;
 using System.Net;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Moq;
-using Renci.SshNet.Channels;
 using Renci.SshNet.Common;
-using Renci.SshNet.Messages.Connection;
 
 namespace Renci.SshNet.Tests.Classes
 {

+ 75 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Disconnect_ForwardedPortStarted.cs

@@ -0,0 +1,75 @@
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SshClientTest_Disconnect_ForwardedPortStarted
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private Mock<ForwardedPort> _forwardedPortMock;
+        private SshClient _sshClient;
+        private ConnectionInfo _connectionInfo;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+        }
+
+        protected void Arrange()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new NoneAuthenticationMethod("userauth"));
+
+            var sequence = new MockSequence();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _forwardedPortMock = new Mock<ForwardedPort>(MockBehavior.Strict);
+
+            _serviceFactoryMock.InSequence(sequence).Setup(p => p.CreateSession(_connectionInfo)).Returns(_sessionMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.Connect());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Start());
+            _sessionMock.InSequence(sequence).Setup(p => p.OnDisconnecting());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Stop());
+            _sessionMock.InSequence(sequence).Setup(p => p.Disconnect());
+
+            _sshClient = new SshClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _sshClient.Connect();
+            _sshClient.AddForwardedPort(_forwardedPortMock.Object);
+
+            _forwardedPortMock.Object.Start();
+        }
+
+        protected void Act()
+        {
+            _sshClient.Disconnect();
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeStopped()
+        {
+            _forwardedPortMock.Verify(p => p.Stop(), Times.Once);
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeRemovedFromSshClient()
+        {
+            Assert.IsFalse(_sshClient.ForwardedPorts.Any());
+        }
+
+        [TestMethod]
+        public void DisconnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Disconnect(), Times.Once);
+        }
+    }
+}

+ 97 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Dispose_ForwardedPortStarted.cs

@@ -0,0 +1,97 @@
+using System;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SshClientTest_Dispose_ForwardedPortStarted
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private Mock<ForwardedPort> _forwardedPortMock;
+        private SshClient _sshClient;
+        private ConnectionInfo _connectionInfo;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+        }
+
+        protected void Arrange()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new NoneAuthenticationMethod("userauth"));
+
+            var sequence = new MockSequence();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _forwardedPortMock = new Mock<ForwardedPort>(MockBehavior.Strict);
+
+            _serviceFactoryMock.InSequence(sequence).Setup(p => p.CreateSession(_connectionInfo)).Returns(_sessionMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.Connect());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Start());
+            _sessionMock.InSequence(sequence).Setup(p => p.OnDisconnecting());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Stop());
+            _sessionMock.InSequence(sequence).Setup(p => p.Disconnect());
+            _sessionMock.InSequence(sequence).Setup(p => p.Dispose());
+
+            _sshClient = new SshClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _sshClient.Connect();
+            _sshClient.AddForwardedPort(_forwardedPortMock.Object);
+
+            _forwardedPortMock.Object.Start();
+        }
+
+        protected void Act()
+        {
+            _sshClient.Dispose();
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeStopped()
+        {
+            _forwardedPortMock.Verify(p => p.Stop(), Times.Once);
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeRemovedFromSshClient()
+        {
+            Assert.IsFalse(_sshClient.ForwardedPorts.Any());
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldThrowObjectDisposedException()
+        {
+            try
+            {
+                var connected = _sshClient.IsConnected;
+                Assert.Fail("IsConnected should have thrown {0} but returned {1}.",
+                    typeof (ObjectDisposedException).FullName, connected);
+            }
+            catch (ObjectDisposedException)
+            {
+            }
+        }
+
+        [TestMethod]
+        public void DisconnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Disconnect(), Times.Once);
+        }
+
+        [TestMethod]
+        public void DisposeOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Disconnect(), Times.Once);
+        }
+    }
+}

+ 3 - 0
Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -113,6 +113,7 @@
     <Compile Include="Classes\ForwardedPortDynamicTest_Start_PortStopped.cs" />
     <Compile Include="Classes\ForwardedPortDynamicTest_Start_PortStarted.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortDisposed.cs" />
+    <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortNeverStarted.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortStarted_ChannelNotBound.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortStopped.cs" />
@@ -305,6 +306,8 @@
     <Compile Include="Classes\ShellTestTest.cs" />
     <Compile Include="Classes\ShellStreamTest.cs" />
     <Compile Include="Classes\SshClientTest.cs" />
+    <Compile Include="Classes\SshClientTest_Disconnect_ForwardedPortStarted.cs" />
+    <Compile Include="Classes\SshClientTest_Dispose_ForwardedPortStarted.cs" />
     <Compile Include="Classes\SshCommandTest.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest.NET40.cs" />
     <Compile Include="Classes\SshCommandTest.NET40.cs" />

+ 27 - 5
Renci.SshClient/Renci.SshNet/BaseClient.cs

@@ -14,6 +14,8 @@ namespace Renci.SshNet
         /// Holds value indicating whether the connection info is owned by this client.
         /// </summary>
         private readonly bool _ownsConnectionInfo;
+
+        private readonly IServiceFactory _serviceFactory;
         private TimeSpan _keepAliveInterval;
         private Timer _keepAliveTimer;
         private ConnectionInfo _connectionInfo;
@@ -21,7 +23,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Gets current session.
         /// </summary>
-        protected Session Session { get; private set; }
+        internal ISession Session { get; private set; }
 
         /// <summary>
         /// Gets the connection info.
@@ -81,7 +83,7 @@ namespace Renci.SshNet
                 if (value == _keepAliveInterval)
                     return;
 
-                if (value == Session.InfiniteTimeSpan)
+                if (value == SshNet.Session.InfiniteTimeSpan)
                 {
                     // stop the timer when the value is -1 milliseconds
                     StopKeepAliveTimer();
@@ -127,13 +129,33 @@ namespace Renci.SshNet
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
         protected BaseClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
+            : this(connectionInfo, ownsConnectionInfo, new ServiceFactory())
+        {
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="BaseClient"/> class.
+        /// </summary>
+        /// <param name="connectionInfo">The connection info.</param>
+        /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is null.</exception>
+        /// <remarks>
+        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// connection info will be disposed when this instance is disposed.
+        /// </remarks>
+        internal BaseClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory)
         {
             if (connectionInfo == null)
                 throw new ArgumentNullException("connectionInfo");
+            if (serviceFactory == null)
+                throw new ArgumentNullException("serviceFactory");
 
             ConnectionInfo = connectionInfo;
             _ownsConnectionInfo = ownsConnectionInfo;
-            _keepAliveInterval = Session.InfiniteTimeSpan;
+            _serviceFactory = serviceFactory;
+            _keepAliveInterval = SshNet.Session.InfiniteTimeSpan;
         }
 
         /// <summary>
@@ -169,7 +191,7 @@ namespace Renci.SshNet
                 throw new InvalidOperationException("The client is already connected.");
 
             OnConnecting();
-            Session = new Session(ConnectionInfo);
+            Session = _serviceFactory.CreateSession(ConnectionInfo);
             Session.HostKeyReceived += Session_HostKeyReceived;
             Session.ErrorOccured += Session_ErrorOccured;
             Session.Connect();
@@ -353,7 +375,7 @@ namespace Renci.SshNet
         /// </remarks>
         private void StartKeepAliveTimer()
         {
-            if (_keepAliveInterval == Session.InfiniteTimeSpan)
+            if (_keepAliveInterval == SshNet.Session.InfiniteTimeSpan)
                 return;
 
             if (_keepAliveTimer == null)

+ 1 - 1
Renci.SshClient/Renci.SshNet/ClientAuthentication.cs

@@ -7,7 +7,7 @@ namespace Renci.SshNet
 {
     internal class ClientAuthentication
     {
-        public void Authenticate(IConnectionInfo connectionInfo, ISession session)
+        public void Authenticate(IConnectionInfoInternal connectionInfo, ISession session)
         {
             if (connectionInfo == null)
                 throw new ArgumentNullException("connectionInfo");

+ 4 - 4
Renci.SshClient/Renci.SshNet/ConnectionInfo.cs

@@ -19,7 +19,7 @@ namespace Renci.SshNet
     /// This class is NOT thread-safe. Do not use the same <see cref="ConnectionInfo"/> with multiple
     /// client instances.
     /// </remarks>
-    public class ConnectionInfo : IConnectionInfo
+    public class ConnectionInfo : IConnectionInfoInternal
     {
         internal static int DEFAULT_PORT = 22;
 
@@ -417,7 +417,7 @@ namespace Renci.SshNet
         /// </summary>
         /// <param name="sender">The session in which the banner message was received.</param>
         /// <param name="e">The banner message.{</param>
-        void IConnectionInfo.UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e)
+        void IConnectionInfoInternal.UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e)
         {
             var authenticationBanner = AuthenticationBanner;
             if (authenticationBanner != null)
@@ -427,12 +427,12 @@ namespace Renci.SshNet
             }
         }
 
-        IAuthenticationMethod IConnectionInfo.CreateNoneAuthenticationMethod()
+        IAuthenticationMethod IConnectionInfoInternal.CreateNoneAuthenticationMethod()
         {
             return new NoneAuthenticationMethod(Username);
         }
 
-        IEnumerable<IAuthenticationMethod> IConnectionInfo.AuthenticationMethods
+        IEnumerable<IAuthenticationMethod> IConnectionInfoInternal.AuthenticationMethods
         {
             get { return AuthenticationMethods.Cast<IAuthenticationMethod>(); }
         }

+ 2 - 5
Renci.SshClient/Renci.SshNet/ForwardedPort.cs

@@ -69,7 +69,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Stops port forwarding.
         /// </summary>
-        public void Stop()
+        public virtual void Stop()
         {
             CheckDisposed();
 
@@ -111,12 +111,9 @@ namespace Renci.SshNet
                 {
                     Session.ErrorOccured -= Session_ErrorOccured;
                     StopPort(Session.ConnectionInfo.Timeout);
+                    Session = null;
                 }
             }
-            else
-            {
-                StopPort(TimeSpan.Zero);
-            }
         }
 
         /// <summary>

+ 6 - 8
Renci.SshClient/Renci.SshNet/ForwardedPortDynamic.NET.cs

@@ -60,6 +60,12 @@ namespace Renci.SshNet
                     }
                     finally
                     {
+                        if (Session != null)
+                        {
+                            Session.ErrorOccured -= Session_ErrorOccured;
+                            Session.Disconnected -= Session_Disconnected;
+                        }
+
                         // mark listener stopped
                         _listenerCompleted.Set();
                     }
@@ -185,9 +191,6 @@ namespace Renci.SshNet
             if (!IsStarted)
                 return;
 
-            Session.ErrorOccured -= Session_ErrorOccured;
-            Session.Disconnected -= Session_Disconnected;
-
             // close listener socket
             _listener.Close();
             // wait for listener loop to finish
@@ -229,11 +232,6 @@ namespace Renci.SshNet
         {
             if (disposing)
             {
-                if (Session != null)
-                {
-                    Session.ErrorOccured -= Session_ErrorOccured;
-                    Session.Disconnected -= Session_Disconnected;
-                }
                 if (_listener != null)
                 {
                     _listener.Dispose();

+ 31 - 22
Renci.SshClient/Renci.SshNet/IConnectionInfo.cs

@@ -1,11 +1,40 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Renci.SshNet.Common;
 using Renci.SshNet.Messages.Authentication;
 using Renci.SshNet.Messages.Connection;
 
 namespace Renci.SshNet
 {
+    internal interface IConnectionInfoInternal : IConnectionInfo
+    {
+        /// <summary>
+        /// Signals that an authentication banner message was received from the server.
+        /// </summary>
+        /// <param name="sender">The session in which the banner message was received.</param>
+        /// <param name="e">The banner message.{</param>
+        void UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e);
+
+        /// <summary>
+        /// Gets the supported authentication methods for this connection.
+        /// </summary>
+        /// <value>
+        /// The supported authentication methods for this connection.
+        /// </value>
+        IEnumerable<IAuthenticationMethod> AuthenticationMethods { get; }
+
+        /// <summary>
+        /// Creates a <see cref="NoneAuthenticationMethod"/> for the credentials represented
+        /// by the current <see cref="IConnectionInfo"/>.
+        /// </summary>
+        /// <returns>
+        /// A <see cref="NoneAuthenticationMethod"/> for the credentials represented by the
+        /// current <see cref="IConnectionInfo"/>.
+        /// </returns>
+        IAuthenticationMethod CreateNoneAuthenticationMethod();
+    }
+
     /// <summary>
     /// Represents remote connection information.
     /// </summary>
@@ -47,28 +76,8 @@ namespace Renci.SshNet
         TimeSpan Timeout { get; }
 
         /// <summary>
-        /// Gets the supported authentication methods for this connection.
-        /// </summary>
-        /// <value>
-        /// The supported authentication methods for this connection.
-        /// </value>
-        IEnumerable<IAuthenticationMethod> AuthenticationMethods { get; }
-
-        /// <summary>
-        /// Signals that an authentication banner message was received from the server.
-        /// </summary>
-        /// <param name="sender">The session in which the banner message was received.</param>
-        /// <param name="e">The banner message.{</param>
-        void UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e);
-
-        /// <summary>
-        /// Creates a <see cref="NoneAuthenticationMethod"/> for the credentials represented
-        /// by the current <see cref="IConnectionInfo"/>.
+        /// Occurs when authentication banner is sent by the server.
         /// </summary>
-        /// <returns>
-        /// A <see cref="NoneAuthenticationMethod"/> for the credentials represented by the
-        /// current <see cref="IConnectionInfo"/>.
-        /// </returns>
-        IAuthenticationMethod CreateNoneAuthenticationMethod();
+        event EventHandler<AuthenticationBannerEventArgs> AuthenticationBanner;
     }
 }

+ 1 - 1
Renci.SshClient/Renci.SshNet/IForwardedPort.cs

@@ -5,7 +5,7 @@ namespace Renci.SshNet
     /// <summary>
     /// Supports port forwarding functionality.
     /// </summary>
-    internal interface IForwardedPort
+    public interface IForwardedPort
     {
         /// <summary>
         /// The <see cref="Closing"/> event occurs as the forwarded port is being stopped.

+ 7 - 0
Renci.SshClient/Renci.SshNet/IServiceFactory.cs

@@ -0,0 +1,7 @@
+namespace Renci.SshNet
+{
+    internal interface IServiceFactory
+    {
+        ISession CreateSession(ConnectionInfo connectionInfo);
+    }
+}

+ 39 - 6
Renci.SshClient/Renci.SshNet/ISession.cs

@@ -12,12 +12,12 @@ namespace Renci.SshNet
     /// <summary>
     /// Provides functionality to connect and interact with SSH server.
     /// </summary>
-    internal interface ISession
+    internal interface ISession : IDisposable
     {
-        ///// <summary>
-        ///// Gets or sets the connection info.
-        ///// </summary>
-        ///// <value>The connection info.</value>
+        /// <summary>
+        /// Gets or sets the connection info.
+        /// </summary>
+        /// <value>The connection info.</value>
         IConnectionInfo ConnectionInfo { get; }
 
         /// <summary>
@@ -53,6 +53,15 @@ namespace Renci.SshNet
         /// </value>
         WaitHandle MessageListenerCompleted { get; }
 
+        /// <summary>
+        /// Connects to the server.
+        /// </summary>
+        /// <exception cref="SocketException">Socket connection to the SSH server or proxy server could not be established, or an error occurred while resolving the hostname.</exception>
+        /// <exception cref="SshConnectionException">SSH session could not be established.</exception>
+        /// <exception cref="SshAuthenticationException">Authentication of SSH session failed.</exception>
+        /// <exception cref="ProxyException">Failed to establish proxy connection.</exception>
+        void Connect();
+
         /// <summary>
         /// Create a new SSH session channel.
         /// </summary>
@@ -77,12 +86,31 @@ namespace Renci.SshNet
         /// </returns>
         IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize);
 
-       /// <summary>
+        /// <summary>
+        /// Disconnects from the server.
+        /// </summary>
+        /// <remarks>
+        /// This sends a <b>SSH_MSG_DISCONNECT</b> message to the server, waits for the
+        /// server to close the socket on its end and subsequently closes the client socket.
+        /// </remarks>
+        void Disconnect();
+
+        /// <summary>
+        /// Called when client is disconnecting from the server.
+        /// </summary>
+        void OnDisconnecting();
+
+        /// <summary>
         /// Registers SSH message with the session.
         /// </summary>
         /// <param name="messageName">The name of the message to register with the session.</param>
         void RegisterMessage(string messageName);
 
+        /// <summary>
+        /// Sends "keep alive" message to keep connection alive.
+        /// </summary>
+        void SendKeepAlive();
+
         /// <summary>
         /// Sends a message to the server.
         /// </summary>
@@ -192,6 +220,11 @@ namespace Renci.SshNet
         /// </summary>
         event EventHandler<ExceptionEventArgs> ErrorOccured;
 
+        /// <summary>
+        /// Occurs when host key received.
+        /// </summary>
+        event EventHandler<HostKeyEventArgs> HostKeyReceived;
+
         /// <summary>
         /// Occurs when <see cref="RequestSuccessMessage"/> message received
         /// </summary>

+ 1 - 1
Renci.SshClient/Renci.SshNet/Netconf/NetConfSession.cs

@@ -39,7 +39,7 @@ namespace Renci.SshNet.NetConf
         /// </summary>
         /// <param name="session">The session.</param>
         /// <param name="operationTimeout">The operation timeout.</param>
-        public NetConfSession(Session session, TimeSpan operationTimeout)
+        public NetConfSession(ISession session, TimeSpan operationTimeout)
             : base(session, "netconf", operationTimeout, Encoding.UTF8)
         {
             ClientCapabilities = new XmlDocument();

+ 2 - 0
Renci.SshClient/Renci.SshNet/Renci.SshNet.csproj

@@ -155,6 +155,7 @@
     <Compile Include="IAuthenticationMethod.cs" />
     <Compile Include="IConnectionInfo.cs" />
     <Compile Include="IForwardedPort.cs" />
+    <Compile Include="IServiceFactory.cs" />
     <Compile Include="ISession.cs" />
     <Compile Include="Messages\Transport\KeyExchangeEcdhInitMessage.cs" />
     <Compile Include="Messages\Transport\KeyExchangeEcdhReplyMessage.cs" />
@@ -171,6 +172,7 @@
     <Compile Include="Security\GroupExchangeHashData.cs" />
     <Compile Include="Security\KeyExchangeDiffieHellmanGroupExchangeShaBase.cs" />
     <Compile Include="Security\KeyExchangeEllipticCurveDiffieHellman.cs" />
+    <Compile Include="ServiceFactory.cs" />
     <Compile Include="ShellStream.NET40.cs" />
     <Compile Include="ExpectAsyncResult.cs" />
     <Compile Include="Security\KeyExchangeDiffieHellmanGroupSha1.cs" />

+ 7 - 7
Renci.SshClient/Renci.SshNet/ScpClient.NET.cs

@@ -29,7 +29,7 @@ namespace Renci.SshNet
                 throw new ArgumentException("path");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -64,7 +64,7 @@ namespace Renci.SshNet
                 throw new ArgumentException("path");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -106,7 +106,7 @@ namespace Renci.SshNet
                 throw new ArgumentNullException("fileInfo");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -141,7 +141,7 @@ namespace Renci.SshNet
                 throw new ArgumentNullException("directoryInfo");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -161,7 +161,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalUpload(ChannelSession channel, Stream input, FileInfo fileInfo, string filename)
+        private void InternalUpload(IChannelSession channel, Stream input, FileInfo fileInfo, string filename)
         {
             this.InternalSetTimestamp(channel, input, fileInfo.LastWriteTimeUtc, fileInfo.LastAccessTimeUtc);
             using (var source = fileInfo.OpenRead())
@@ -170,7 +170,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalUpload(ChannelSession channel, Stream input, DirectoryInfo directoryInfo)
+        private void InternalUpload(IChannelSession channel, Stream input, DirectoryInfo directoryInfo)
         {
             //  Upload files
             var files = directoryInfo.GetFiles();
@@ -194,7 +194,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalDownload(ChannelSession channel, Stream input, FileSystemInfo fileSystemInfo)
+        private void InternalDownload(IChannelSession channel, Stream input, FileSystemInfo fileSystemInfo)
         {
             DateTime modifiedTime = DateTime.Now;
             DateTime accessedTime = DateTime.Now;

+ 7 - 7
Renci.SshClient/Renci.SshNet/ScpClient.cs

@@ -152,7 +152,7 @@ namespace Renci.SshNet
         public void Upload(Stream source, string path)
         {
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -198,7 +198,7 @@ namespace Renci.SshNet
                 throw new ArgumentNullException("destination");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -235,7 +235,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalSetTimestamp(ChannelSession channel, Stream input, DateTime lastWriteTime, DateTime lastAccessime)
+        private void InternalSetTimestamp(IChannelSession channel, Stream input, DateTime lastWriteTime, DateTime lastAccessime)
         {
             var zeroTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
             var modificationSeconds = (long)(lastWriteTime - zeroTime).TotalSeconds;
@@ -244,7 +244,7 @@ namespace Renci.SshNet
             this.CheckReturnCode(input);
         }
 
-        private void InternalUpload(ChannelSession channel, Stream input, Stream source, string filename)
+        private void InternalUpload(IChannelSession channel, Stream input, Stream source, string filename)
         {
             var length = source.Length;
 
@@ -271,7 +271,7 @@ namespace Renci.SshNet
             this.CheckReturnCode(input);
         }
 
-        private void InternalDownload(ChannelSession channel, Stream input, Stream output, string filename, long length)
+        private void InternalDownload(IChannelSession channel, Stream input, Stream output, string filename, long length)
         {
             var buffer = new byte[Math.Min(length, this.BufferSize)];
             var needToRead = length;
@@ -315,12 +315,12 @@ namespace Renci.SshNet
             }
         }
 
-        private void SendConfirmation(ChannelSession channel)
+        private void SendConfirmation(IChannelSession channel)
         {
             this.SendData(channel, new byte[] { 0 });
         }
 
-        private void SendConfirmation(ChannelSession channel, byte errorCode, string message)
+        private void SendConfirmation(IChannelSession channel, byte errorCode, string message)
         {
             this.SendData(channel, new[] { errorCode });
             this.SendData(channel, string.Format("{0}\n", message));

+ 10 - 0
Renci.SshClient/Renci.SshNet/ServiceFactory.cs

@@ -0,0 +1,10 @@
+namespace Renci.SshNet
+{
+    internal class ServiceFactory : IServiceFactory
+    {
+        public ISession CreateSession(ConnectionInfo connectionInfo)
+        {
+            return new Session(connectionInfo);
+        }
+    }
+}

+ 3 - 3
Renci.SshClient/Renci.SshNet/Session.cs

@@ -23,7 +23,7 @@ namespace Renci.SshNet
     /// <summary>
     /// Provides functionality to connect and interact with SSH server.
     /// </summary>
-    public partial class Session : IDisposable, ISession
+    public partial class Session : ISession
     {
         /// <summary>
         /// Specifies an infinite waiting period.
@@ -703,7 +703,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Sends "keep alive" message to keep connection alive.
         /// </summary>
-        internal void SendKeepAlive()
+        void ISession.SendKeepAlive()
         {
             this.SendMessage(new IgnoreMessage());
         }
@@ -1405,7 +1405,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Called when client is disconnecting from the server.
         /// </summary>
-        internal void OnDisconnecting()
+        void ISession.OnDisconnecting()
         {
             _isDisconnecting = true;
         }

+ 1 - 1
Renci.SshClient/Renci.SshNet/Sftp/SftpSession.cs

@@ -56,7 +56,7 @@ namespace Renci.SshNet.Sftp
             }
         }
 
-        public SftpSession(Session session, TimeSpan operationTimeout, Encoding encoding)
+        public SftpSession(ISession session, TimeSpan operationTimeout, Encoding encoding)
             : base(session, "sftp", operationTimeout, encoding)
         {
         }

+ 3 - 3
Renci.SshClient/Renci.SshNet/Shell.cs

@@ -13,7 +13,7 @@ namespace Renci.SshNet
     /// </summary>
     public partial class Shell : IDisposable
     {
-        private readonly Session _session;
+        private readonly ISession _session;
 
         private IChannelSession _channel;
 
@@ -88,7 +88,7 @@ namespace Renci.SshNet
         /// <param name="height">The height.</param>
         /// <param name="terminalModes">The terminal modes.</param>
         /// <param name="bufferSize">Size of the buffer for output stream.</param>
-        internal Shell(Session session, Stream input, Stream output, Stream extendedOutput, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModes, int bufferSize)
+        internal Shell(ISession session, Stream input, Stream output, Stream extendedOutput, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModes, int bufferSize)
         {
             this._session = session;
             this._input = input;
@@ -119,7 +119,7 @@ namespace Renci.SshNet
                 this.Starting(this, new EventArgs());
             }
 
-            this._channel = this._session.CreateClientChannel<ChannelSession>();
+            this._channel = this._session.CreateChannelSession();
             this._channel.DataReceived += Channel_DataReceived;
             this._channel.ExtendedDataReceived += Channel_ExtendedDataReceived;
             this._channel.Closed += Channel_Closed;

+ 62 - 17
Renci.SshClient/Renci.SshNet/SshClient.cs

@@ -17,6 +17,14 @@ namespace Renci.SshNet
         /// </summary>
         private readonly List<ForwardedPort> _forwardedPorts;
 
+        /// <summary>
+        /// Holds a value indicating whether the current instance is disposed.
+        /// </summary>
+        /// <value>
+        /// <c>true</c> if the current instance is disposed; otherwise, <c>false</c>.
+        /// </value>
+        private bool _isDisposed;
+
         private Stream _inputStream;
 
         /// <summary>
@@ -127,8 +135,26 @@ namespace Renci.SshNet
         /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
-        private SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
-            : base(connectionInfo, ownsConnectionInfo)
+        internal SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
+            : base(connectionInfo, ownsConnectionInfo, new ServiceFactory())
+        {
+            _forwardedPorts = new List<ForwardedPort>();
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="SshClient"/> class.
+        /// </summary>
+        /// <param name="connectionInfo">The connection info.</param>
+        /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is null.</exception>
+        /// <remarks>
+        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// connection info will be disposed when this instance is disposed.
+        /// </remarks>
+        internal SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory)
+            : base(connectionInfo, ownsConnectionInfo, serviceFactory)
         {
             _forwardedPorts = new List<ForwardedPort>();
         }
@@ -164,11 +190,7 @@ namespace Renci.SshNet
             if (port == null)
                 throw new ArgumentNullException("port");
 
-            if (port.Session != null && port.Session != Session)
-                throw new InvalidOperationException("Forwarded port is already added to a different client.");
-
-            port.Session = Session;
-
+            AttachForwardedPort(port);
             _forwardedPorts.Add(port);
         }
 
@@ -185,11 +207,23 @@ namespace Renci.SshNet
             //  Stop port forwarding before removing it
             port.Stop();
 
-            port.Session = null;
-
+            DetachForwardedPort(port);
             _forwardedPorts.Remove(port);
         }
 
+        private void AttachForwardedPort(ForwardedPort port)
+        {
+            if (port.Session != null && port.Session != Session)
+                throw new InvalidOperationException("Forwarded port is already added to a different client.");
+
+            port.Session = Session;
+        }
+
+        private static void DetachForwardedPort(ForwardedPort port)
+        {
+            port.Session = null;
+        }
+
         /// <summary>
         /// Creates the command to be executed.
         /// </summary>
@@ -399,11 +433,12 @@ namespace Renci.SshNet
         {
             base.OnDisconnected();
 
-            foreach (var forwardedPort in _forwardedPorts.ToArray())
+            for (var i = _forwardedPorts.Count - 1; i >= 0; i--)
             {
-                RemoveForwardedPort(forwardedPort);
+                var port = _forwardedPorts[i];
+                DetachForwardedPort(port);
+                _forwardedPorts.RemoveAt(i);
             }
-
         }
 
         /// <summary>
@@ -412,13 +447,23 @@ namespace Renci.SshNet
         /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged ResourceMessages.</param>
         protected override void Dispose(bool disposing)
         {
-            base.Dispose(disposing);
-
-            if (_inputStream != null)
+            if (!_isDisposed)
             {
-                _inputStream.Dispose();
-                _inputStream = null;
+                if (disposing)
+                {
+                    Disconnect();
+
+                    if (_inputStream != null)
+                    {
+                        _inputStream.Dispose();
+                        _inputStream = null;
+                    }
+                }
             }
+
+            base.Dispose(disposing);
+
+            _isDisposed = true;
         }
     }
 }