Browse Source

Added simple ServiceFactory to allow unit testing of SshClient.
When disposing SshClient, make sure to stop forwarded ports before the session is closed.

Gert Driesen 11 years ago
parent
commit
9366658bc4
25 changed files with 474 additions and 100 deletions
  1. 2 2
      Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTestBase.cs
  2. 23 3
      Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed.cs
  3. 60 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted.cs
  4. 0 4
      Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortRemoteTest_Dispose_PortDisposed.cs
  5. 75 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Disconnect_ForwardedPortStarted.cs
  6. 97 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Dispose_ForwardedPortStarted.cs
  7. 3 0
      Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj
  8. 27 5
      Renci.SshClient/Renci.SshNet/BaseClient.cs
  9. 1 1
      Renci.SshClient/Renci.SshNet/ClientAuthentication.cs
  10. 4 4
      Renci.SshClient/Renci.SshNet/ConnectionInfo.cs
  11. 2 5
      Renci.SshClient/Renci.SshNet/ForwardedPort.cs
  12. 6 8
      Renci.SshClient/Renci.SshNet/ForwardedPortDynamic.NET.cs
  13. 31 22
      Renci.SshClient/Renci.SshNet/IConnectionInfo.cs
  14. 1 1
      Renci.SshClient/Renci.SshNet/IForwardedPort.cs
  15. 7 0
      Renci.SshClient/Renci.SshNet/IServiceFactory.cs
  16. 39 6
      Renci.SshClient/Renci.SshNet/ISession.cs
  17. 1 1
      Renci.SshClient/Renci.SshNet/Netconf/NetConfSession.cs
  18. 2 0
      Renci.SshClient/Renci.SshNet/Renci.SshNet.csproj
  19. 7 7
      Renci.SshClient/Renci.SshNet/ScpClient.NET.cs
  20. 7 7
      Renci.SshClient/Renci.SshNet/ScpClient.cs
  21. 10 0
      Renci.SshClient/Renci.SshNet/ServiceFactory.cs
  22. 3 3
      Renci.SshClient/Renci.SshNet/Session.cs
  23. 1 1
      Renci.SshClient/Renci.SshNet/Sftp/SftpSession.cs
  24. 3 3
      Renci.SshClient/Renci.SshNet/Shell.cs
  25. 62 17
      Renci.SshClient/Renci.SshNet/SshClient.cs

+ 2 - 2
Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTestBase.cs

@@ -10,7 +10,7 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public abstract class ClientAuthenticationTestBase : TestBase
     {
-        internal Mock<IConnectionInfo> ConnectionInfoMock { get; private set; }
+        internal Mock<IConnectionInfoInternal> ConnectionInfoMock { get; private set; }
         internal Mock<ISession> SessionMock { get; private set; }
         internal Mock<IAuthenticationMethod> NoneAuthenticationMethodMock { get; private set; }
         internal Mock<IAuthenticationMethod> PasswordAuthenticationMethodMock { get; private set; }
@@ -20,7 +20,7 @@ namespace Renci.SshNet.Tests.Classes
 
         protected void CreateMocks()
         {
-            ConnectionInfoMock = new Mock<IConnectionInfo>(MockBehavior.Strict);
+            ConnectionInfoMock = new Mock<IConnectionInfoInternal>(MockBehavior.Strict);
             SessionMock = new Mock<ISession>(MockBehavior.Strict);
             NoneAuthenticationMethodMock = new Mock<IAuthenticationMethod>(MockBehavior.Strict);
             PasswordAuthenticationMethodMock = new Mock<IAuthenticationMethod>(MockBehavior.Strict);

+ 23 - 3
Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed.cs

@@ -1,6 +1,8 @@
 using System;
 using System.Collections.Generic;
+using System.Net;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
 using Renci.SshNet.Common;
 
 namespace Renci.SshNet.Tests.Classes
@@ -8,6 +10,8 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public class ForwardedPortLocalTest_Dispose_PortDisposed
     {
+        private Mock<ISession> _sessionMock;
+        private Mock<IConnectionInfo> _connectionInfoMock;
         private ForwardedPortLocal _forwardedPort;
         private IList<EventArgs> _closingRegister;
         private IList<ExceptionEventArgs> _exceptionRegister;
@@ -34,9 +38,19 @@ namespace Renci.SshNet.Tests.Classes
             _closingRegister = new List<EventArgs>();
             _exceptionRegister = new List<ExceptionEventArgs>();
 
-            _forwardedPort = new ForwardedPortLocal("boundHost", "host", 22);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _connectionInfoMock = new Mock<IConnectionInfo>(MockBehavior.Strict);
+
+            var sequence = new MockSequence();
+            _sessionMock.InSequence(sequence).Setup(p => p.IsConnected).Returns(true);
+            _sessionMock.InSequence(sequence).Setup(p => p.ConnectionInfo).Returns(_connectionInfoMock.Object);
+            _connectionInfoMock.InSequence(sequence).Setup(p => p.Timeout).Returns(TimeSpan.FromSeconds(30));
+
+            _forwardedPort = new ForwardedPortLocal(IPAddress.Loopback.ToString(), "host", 22);
             _forwardedPort.Closing += (sender, args) => _closingRegister.Add(args);
             _forwardedPort.Exception += (sender, args) => _exceptionRegister.Add(args);
+            _forwardedPort.Session = _sessionMock.Object;
+            _forwardedPort.Start();
             _forwardedPort.Dispose();
         }
 
@@ -46,9 +60,9 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void ClosingShouldNotHaveFired()
+        public void ClosingShouldHaveFiredOnce()
         {
-            Assert.AreEqual(0, _closingRegister.Count);
+            Assert.AreEqual(1, _closingRegister.Count);
         }
 
         [TestMethod]
@@ -56,5 +70,11 @@ namespace Renci.SshNet.Tests.Classes
         {
             Assert.AreEqual(0, _exceptionRegister.Count);
         }
+
+        [TestMethod]
+        public void SessionShouldBeNull()
+        {
+            Assert.IsNull(_forwardedPort.Session);
+        }
     }
 }

+ 60 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted.cs

@@ -0,0 +1,60 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted
+    {
+        private ForwardedPortLocal _forwardedPort;
+        private IList<EventArgs> _closingRegister;
+        private IList<ExceptionEventArgs> _exceptionRegister;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_forwardedPort != null)
+            {
+                _forwardedPort.Dispose();
+                _forwardedPort = null;
+            }
+        }
+
+        protected void Arrange()
+        {
+            _closingRegister = new List<EventArgs>();
+            _exceptionRegister = new List<ExceptionEventArgs>();
+
+            _forwardedPort = new ForwardedPortLocal("boundHost", "host", 22);
+            _forwardedPort.Closing += (sender, args) => _closingRegister.Add(args);
+            _forwardedPort.Exception += (sender, args) => _exceptionRegister.Add(args);
+            _forwardedPort.Dispose();
+        }
+
+        protected void Act()
+        {
+            _forwardedPort.Dispose();
+        }
+
+        [TestMethod]
+        public void ClosingShouldNotHaveFired()
+        {
+            Assert.AreEqual(0, _closingRegister.Count);
+        }
+
+        [TestMethod]
+        public void ExceptionShouldNotHaveFired()
+        {
+            Assert.AreEqual(0, _exceptionRegister.Count);
+        }
+    }
+}

+ 0 - 4
Renci.SshClient/Renci.SshNet.Tests/Classes/ForwardedPortRemoteTest_Dispose_PortDisposed.cs

@@ -1,12 +1,8 @@
 using System;
 using System.Collections.Generic;
-using System.Globalization;
 using System.Net;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Moq;
-using Renci.SshNet.Channels;
 using Renci.SshNet.Common;
-using Renci.SshNet.Messages.Connection;
 
 namespace Renci.SshNet.Tests.Classes
 {

+ 75 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Disconnect_ForwardedPortStarted.cs

@@ -0,0 +1,75 @@
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SshClientTest_Disconnect_ForwardedPortStarted
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private Mock<ForwardedPort> _forwardedPortMock;
+        private SshClient _sshClient;
+        private ConnectionInfo _connectionInfo;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+        }
+
+        protected void Arrange()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new NoneAuthenticationMethod("userauth"));
+
+            var sequence = new MockSequence();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _forwardedPortMock = new Mock<ForwardedPort>(MockBehavior.Strict);
+
+            _serviceFactoryMock.InSequence(sequence).Setup(p => p.CreateSession(_connectionInfo)).Returns(_sessionMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.Connect());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Start());
+            _sessionMock.InSequence(sequence).Setup(p => p.OnDisconnecting());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Stop());
+            _sessionMock.InSequence(sequence).Setup(p => p.Disconnect());
+
+            _sshClient = new SshClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _sshClient.Connect();
+            _sshClient.AddForwardedPort(_forwardedPortMock.Object);
+
+            _forwardedPortMock.Object.Start();
+        }
+
+        protected void Act()
+        {
+            _sshClient.Disconnect();
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeStopped()
+        {
+            _forwardedPortMock.Verify(p => p.Stop(), Times.Once);
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeRemovedFromSshClient()
+        {
+            Assert.IsFalse(_sshClient.ForwardedPorts.Any());
+        }
+
+        [TestMethod]
+        public void DisconnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Disconnect(), Times.Once);
+        }
+    }
+}

+ 97 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SshClientTest_Dispose_ForwardedPortStarted.cs

@@ -0,0 +1,97 @@
+using System;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SshClientTest_Dispose_ForwardedPortStarted
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private Mock<ForwardedPort> _forwardedPortMock;
+        private SshClient _sshClient;
+        private ConnectionInfo _connectionInfo;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+        }
+
+        protected void Arrange()
+        {
+            _connectionInfo = new ConnectionInfo("host", "user", new NoneAuthenticationMethod("userauth"));
+
+            var sequence = new MockSequence();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _forwardedPortMock = new Mock<ForwardedPort>(MockBehavior.Strict);
+
+            _serviceFactoryMock.InSequence(sequence).Setup(p => p.CreateSession(_connectionInfo)).Returns(_sessionMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.Connect());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Start());
+            _sessionMock.InSequence(sequence).Setup(p => p.OnDisconnecting());
+            _forwardedPortMock.InSequence(sequence).Setup(p => p.Stop());
+            _sessionMock.InSequence(sequence).Setup(p => p.Disconnect());
+            _sessionMock.InSequence(sequence).Setup(p => p.Dispose());
+
+            _sshClient = new SshClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _sshClient.Connect();
+            _sshClient.AddForwardedPort(_forwardedPortMock.Object);
+
+            _forwardedPortMock.Object.Start();
+        }
+
+        protected void Act()
+        {
+            _sshClient.Dispose();
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeStopped()
+        {
+            _forwardedPortMock.Verify(p => p.Stop(), Times.Once);
+        }
+
+        [TestMethod]
+        public void ForwardedPortShouldBeRemovedFromSshClient()
+        {
+            Assert.IsFalse(_sshClient.ForwardedPorts.Any());
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldThrowObjectDisposedException()
+        {
+            try
+            {
+                var connected = _sshClient.IsConnected;
+                Assert.Fail("IsConnected should have thrown {0} but returned {1}.",
+                    typeof (ObjectDisposedException).FullName, connected);
+            }
+            catch (ObjectDisposedException)
+            {
+            }
+        }
+
+        [TestMethod]
+        public void DisconnectOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Disconnect(), Times.Once);
+        }
+
+        [TestMethod]
+        public void DisposeOnSessionShouldBeInvokedOnce()
+        {
+            _sessionMock.Verify(p => p.Disconnect(), Times.Once);
+        }
+    }
+}

+ 3 - 0
Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -113,6 +113,7 @@
     <Compile Include="Classes\ForwardedPortDynamicTest_Start_PortStopped.cs" />
     <Compile Include="Classes\ForwardedPortDynamicTest_Start_PortStarted.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortDisposed.cs" />
+    <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortDisposed_NeverStarted.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortNeverStarted.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortStarted_ChannelNotBound.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest_Dispose_PortStopped.cs" />
@@ -305,6 +306,8 @@
     <Compile Include="Classes\ShellTestTest.cs" />
     <Compile Include="Classes\ShellStreamTest.cs" />
     <Compile Include="Classes\SshClientTest.cs" />
+    <Compile Include="Classes\SshClientTest_Disconnect_ForwardedPortStarted.cs" />
+    <Compile Include="Classes\SshClientTest_Dispose_ForwardedPortStarted.cs" />
     <Compile Include="Classes\SshCommandTest.cs" />
     <Compile Include="Classes\ForwardedPortLocalTest.NET40.cs" />
     <Compile Include="Classes\SshCommandTest.NET40.cs" />

+ 27 - 5
Renci.SshClient/Renci.SshNet/BaseClient.cs

@@ -14,6 +14,8 @@ namespace Renci.SshNet
         /// Holds value indicating whether the connection info is owned by this client.
         /// </summary>
         private readonly bool _ownsConnectionInfo;
+
+        private readonly IServiceFactory _serviceFactory;
         private TimeSpan _keepAliveInterval;
         private Timer _keepAliveTimer;
         private ConnectionInfo _connectionInfo;
@@ -21,7 +23,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Gets current session.
         /// </summary>
-        protected Session Session { get; private set; }
+        internal ISession Session { get; private set; }
 
         /// <summary>
         /// Gets the connection info.
@@ -81,7 +83,7 @@ namespace Renci.SshNet
                 if (value == _keepAliveInterval)
                     return;
 
-                if (value == Session.InfiniteTimeSpan)
+                if (value == SshNet.Session.InfiniteTimeSpan)
                 {
                     // stop the timer when the value is -1 milliseconds
                     StopKeepAliveTimer();
@@ -127,13 +129,33 @@ namespace Renci.SshNet
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
         protected BaseClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
+            : this(connectionInfo, ownsConnectionInfo, new ServiceFactory())
+        {
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="BaseClient"/> class.
+        /// </summary>
+        /// <param name="connectionInfo">The connection info.</param>
+        /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is null.</exception>
+        /// <remarks>
+        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// connection info will be disposed when this instance is disposed.
+        /// </remarks>
+        internal BaseClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory)
         {
             if (connectionInfo == null)
                 throw new ArgumentNullException("connectionInfo");
+            if (serviceFactory == null)
+                throw new ArgumentNullException("serviceFactory");
 
             ConnectionInfo = connectionInfo;
             _ownsConnectionInfo = ownsConnectionInfo;
-            _keepAliveInterval = Session.InfiniteTimeSpan;
+            _serviceFactory = serviceFactory;
+            _keepAliveInterval = SshNet.Session.InfiniteTimeSpan;
         }
 
         /// <summary>
@@ -169,7 +191,7 @@ namespace Renci.SshNet
                 throw new InvalidOperationException("The client is already connected.");
 
             OnConnecting();
-            Session = new Session(ConnectionInfo);
+            Session = _serviceFactory.CreateSession(ConnectionInfo);
             Session.HostKeyReceived += Session_HostKeyReceived;
             Session.ErrorOccured += Session_ErrorOccured;
             Session.Connect();
@@ -353,7 +375,7 @@ namespace Renci.SshNet
         /// </remarks>
         private void StartKeepAliveTimer()
         {
-            if (_keepAliveInterval == Session.InfiniteTimeSpan)
+            if (_keepAliveInterval == SshNet.Session.InfiniteTimeSpan)
                 return;
 
             if (_keepAliveTimer == null)

+ 1 - 1
Renci.SshClient/Renci.SshNet/ClientAuthentication.cs

@@ -7,7 +7,7 @@ namespace Renci.SshNet
 {
     internal class ClientAuthentication
     {
-        public void Authenticate(IConnectionInfo connectionInfo, ISession session)
+        public void Authenticate(IConnectionInfoInternal connectionInfo, ISession session)
         {
             if (connectionInfo == null)
                 throw new ArgumentNullException("connectionInfo");

+ 4 - 4
Renci.SshClient/Renci.SshNet/ConnectionInfo.cs

@@ -19,7 +19,7 @@ namespace Renci.SshNet
     /// This class is NOT thread-safe. Do not use the same <see cref="ConnectionInfo"/> with multiple
     /// client instances.
     /// </remarks>
-    public class ConnectionInfo : IConnectionInfo
+    public class ConnectionInfo : IConnectionInfoInternal
     {
         internal static int DEFAULT_PORT = 22;
 
@@ -417,7 +417,7 @@ namespace Renci.SshNet
         /// </summary>
         /// <param name="sender">The session in which the banner message was received.</param>
         /// <param name="e">The banner message.{</param>
-        void IConnectionInfo.UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e)
+        void IConnectionInfoInternal.UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e)
         {
             var authenticationBanner = AuthenticationBanner;
             if (authenticationBanner != null)
@@ -427,12 +427,12 @@ namespace Renci.SshNet
             }
         }
 
-        IAuthenticationMethod IConnectionInfo.CreateNoneAuthenticationMethod()
+        IAuthenticationMethod IConnectionInfoInternal.CreateNoneAuthenticationMethod()
         {
             return new NoneAuthenticationMethod(Username);
         }
 
-        IEnumerable<IAuthenticationMethod> IConnectionInfo.AuthenticationMethods
+        IEnumerable<IAuthenticationMethod> IConnectionInfoInternal.AuthenticationMethods
         {
             get { return AuthenticationMethods.Cast<IAuthenticationMethod>(); }
         }

+ 2 - 5
Renci.SshClient/Renci.SshNet/ForwardedPort.cs

@@ -69,7 +69,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Stops port forwarding.
         /// </summary>
-        public void Stop()
+        public virtual void Stop()
         {
             CheckDisposed();
 
@@ -111,12 +111,9 @@ namespace Renci.SshNet
                 {
                     Session.ErrorOccured -= Session_ErrorOccured;
                     StopPort(Session.ConnectionInfo.Timeout);
+                    Session = null;
                 }
             }
-            else
-            {
-                StopPort(TimeSpan.Zero);
-            }
         }
 
         /// <summary>

+ 6 - 8
Renci.SshClient/Renci.SshNet/ForwardedPortDynamic.NET.cs

@@ -60,6 +60,12 @@ namespace Renci.SshNet
                     }
                     finally
                     {
+                        if (Session != null)
+                        {
+                            Session.ErrorOccured -= Session_ErrorOccured;
+                            Session.Disconnected -= Session_Disconnected;
+                        }
+
                         // mark listener stopped
                         _listenerCompleted.Set();
                     }
@@ -185,9 +191,6 @@ namespace Renci.SshNet
             if (!IsStarted)
                 return;
 
-            Session.ErrorOccured -= Session_ErrorOccured;
-            Session.Disconnected -= Session_Disconnected;
-
             // close listener socket
             _listener.Close();
             // wait for listener loop to finish
@@ -229,11 +232,6 @@ namespace Renci.SshNet
         {
             if (disposing)
             {
-                if (Session != null)
-                {
-                    Session.ErrorOccured -= Session_ErrorOccured;
-                    Session.Disconnected -= Session_Disconnected;
-                }
                 if (_listener != null)
                 {
                     _listener.Dispose();

+ 31 - 22
Renci.SshClient/Renci.SshNet/IConnectionInfo.cs

@@ -1,11 +1,40 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Renci.SshNet.Common;
 using Renci.SshNet.Messages.Authentication;
 using Renci.SshNet.Messages.Connection;
 
 namespace Renci.SshNet
 {
+    internal interface IConnectionInfoInternal : IConnectionInfo
+    {
+        /// <summary>
+        /// Signals that an authentication banner message was received from the server.
+        /// </summary>
+        /// <param name="sender">The session in which the banner message was received.</param>
+        /// <param name="e">The banner message.{</param>
+        void UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e);
+
+        /// <summary>
+        /// Gets the supported authentication methods for this connection.
+        /// </summary>
+        /// <value>
+        /// The supported authentication methods for this connection.
+        /// </value>
+        IEnumerable<IAuthenticationMethod> AuthenticationMethods { get; }
+
+        /// <summary>
+        /// Creates a <see cref="NoneAuthenticationMethod"/> for the credentials represented
+        /// by the current <see cref="IConnectionInfo"/>.
+        /// </summary>
+        /// <returns>
+        /// A <see cref="NoneAuthenticationMethod"/> for the credentials represented by the
+        /// current <see cref="IConnectionInfo"/>.
+        /// </returns>
+        IAuthenticationMethod CreateNoneAuthenticationMethod();
+    }
+
     /// <summary>
     /// Represents remote connection information.
     /// </summary>
@@ -47,28 +76,8 @@ namespace Renci.SshNet
         TimeSpan Timeout { get; }
 
         /// <summary>
-        /// Gets the supported authentication methods for this connection.
-        /// </summary>
-        /// <value>
-        /// The supported authentication methods for this connection.
-        /// </value>
-        IEnumerable<IAuthenticationMethod> AuthenticationMethods { get; }
-
-        /// <summary>
-        /// Signals that an authentication banner message was received from the server.
-        /// </summary>
-        /// <param name="sender">The session in which the banner message was received.</param>
-        /// <param name="e">The banner message.{</param>
-        void UserAuthenticationBannerReceived(object sender, MessageEventArgs<BannerMessage> e);
-
-        /// <summary>
-        /// Creates a <see cref="NoneAuthenticationMethod"/> for the credentials represented
-        /// by the current <see cref="IConnectionInfo"/>.
+        /// Occurs when authentication banner is sent by the server.
         /// </summary>
-        /// <returns>
-        /// A <see cref="NoneAuthenticationMethod"/> for the credentials represented by the
-        /// current <see cref="IConnectionInfo"/>.
-        /// </returns>
-        IAuthenticationMethod CreateNoneAuthenticationMethod();
+        event EventHandler<AuthenticationBannerEventArgs> AuthenticationBanner;
     }
 }

+ 1 - 1
Renci.SshClient/Renci.SshNet/IForwardedPort.cs

@@ -5,7 +5,7 @@ namespace Renci.SshNet
     /// <summary>
     /// Supports port forwarding functionality.
     /// </summary>
-    internal interface IForwardedPort
+    public interface IForwardedPort
     {
         /// <summary>
         /// The <see cref="Closing"/> event occurs as the forwarded port is being stopped.

+ 7 - 0
Renci.SshClient/Renci.SshNet/IServiceFactory.cs

@@ -0,0 +1,7 @@
+namespace Renci.SshNet
+{
+    internal interface IServiceFactory
+    {
+        ISession CreateSession(ConnectionInfo connectionInfo);
+    }
+}

+ 39 - 6
Renci.SshClient/Renci.SshNet/ISession.cs

@@ -12,12 +12,12 @@ namespace Renci.SshNet
     /// <summary>
     /// Provides functionality to connect and interact with SSH server.
     /// </summary>
-    internal interface ISession
+    internal interface ISession : IDisposable
     {
-        ///// <summary>
-        ///// Gets or sets the connection info.
-        ///// </summary>
-        ///// <value>The connection info.</value>
+        /// <summary>
+        /// Gets or sets the connection info.
+        /// </summary>
+        /// <value>The connection info.</value>
         IConnectionInfo ConnectionInfo { get; }
 
         /// <summary>
@@ -53,6 +53,15 @@ namespace Renci.SshNet
         /// </value>
         WaitHandle MessageListenerCompleted { get; }
 
+        /// <summary>
+        /// Connects to the server.
+        /// </summary>
+        /// <exception cref="SocketException">Socket connection to the SSH server or proxy server could not be established, or an error occurred while resolving the hostname.</exception>
+        /// <exception cref="SshConnectionException">SSH session could not be established.</exception>
+        /// <exception cref="SshAuthenticationException">Authentication of SSH session failed.</exception>
+        /// <exception cref="ProxyException">Failed to establish proxy connection.</exception>
+        void Connect();
+
         /// <summary>
         /// Create a new SSH session channel.
         /// </summary>
@@ -77,12 +86,31 @@ namespace Renci.SshNet
         /// </returns>
         IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize);
 
-       /// <summary>
+        /// <summary>
+        /// Disconnects from the server.
+        /// </summary>
+        /// <remarks>
+        /// This sends a <b>SSH_MSG_DISCONNECT</b> message to the server, waits for the
+        /// server to close the socket on its end and subsequently closes the client socket.
+        /// </remarks>
+        void Disconnect();
+
+        /// <summary>
+        /// Called when client is disconnecting from the server.
+        /// </summary>
+        void OnDisconnecting();
+
+        /// <summary>
         /// Registers SSH message with the session.
         /// </summary>
         /// <param name="messageName">The name of the message to register with the session.</param>
         void RegisterMessage(string messageName);
 
+        /// <summary>
+        /// Sends "keep alive" message to keep connection alive.
+        /// </summary>
+        void SendKeepAlive();
+
         /// <summary>
         /// Sends a message to the server.
         /// </summary>
@@ -192,6 +220,11 @@ namespace Renci.SshNet
         /// </summary>
         event EventHandler<ExceptionEventArgs> ErrorOccured;
 
+        /// <summary>
+        /// Occurs when host key received.
+        /// </summary>
+        event EventHandler<HostKeyEventArgs> HostKeyReceived;
+
         /// <summary>
         /// Occurs when <see cref="RequestSuccessMessage"/> message received
         /// </summary>

+ 1 - 1
Renci.SshClient/Renci.SshNet/Netconf/NetConfSession.cs

@@ -39,7 +39,7 @@ namespace Renci.SshNet.NetConf
         /// </summary>
         /// <param name="session">The session.</param>
         /// <param name="operationTimeout">The operation timeout.</param>
-        public NetConfSession(Session session, TimeSpan operationTimeout)
+        public NetConfSession(ISession session, TimeSpan operationTimeout)
             : base(session, "netconf", operationTimeout, Encoding.UTF8)
         {
             ClientCapabilities = new XmlDocument();

+ 2 - 0
Renci.SshClient/Renci.SshNet/Renci.SshNet.csproj

@@ -155,6 +155,7 @@
     <Compile Include="IAuthenticationMethod.cs" />
     <Compile Include="IConnectionInfo.cs" />
     <Compile Include="IForwardedPort.cs" />
+    <Compile Include="IServiceFactory.cs" />
     <Compile Include="ISession.cs" />
     <Compile Include="Messages\Transport\KeyExchangeEcdhInitMessage.cs" />
     <Compile Include="Messages\Transport\KeyExchangeEcdhReplyMessage.cs" />
@@ -171,6 +172,7 @@
     <Compile Include="Security\GroupExchangeHashData.cs" />
     <Compile Include="Security\KeyExchangeDiffieHellmanGroupExchangeShaBase.cs" />
     <Compile Include="Security\KeyExchangeEllipticCurveDiffieHellman.cs" />
+    <Compile Include="ServiceFactory.cs" />
     <Compile Include="ShellStream.NET40.cs" />
     <Compile Include="ExpectAsyncResult.cs" />
     <Compile Include="Security\KeyExchangeDiffieHellmanGroupSha1.cs" />

+ 7 - 7
Renci.SshClient/Renci.SshNet/ScpClient.NET.cs

@@ -29,7 +29,7 @@ namespace Renci.SshNet
                 throw new ArgumentException("path");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -64,7 +64,7 @@ namespace Renci.SshNet
                 throw new ArgumentException("path");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -106,7 +106,7 @@ namespace Renci.SshNet
                 throw new ArgumentNullException("fileInfo");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -141,7 +141,7 @@ namespace Renci.SshNet
                 throw new ArgumentNullException("directoryInfo");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -161,7 +161,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalUpload(ChannelSession channel, Stream input, FileInfo fileInfo, string filename)
+        private void InternalUpload(IChannelSession channel, Stream input, FileInfo fileInfo, string filename)
         {
             this.InternalSetTimestamp(channel, input, fileInfo.LastWriteTimeUtc, fileInfo.LastAccessTimeUtc);
             using (var source = fileInfo.OpenRead())
@@ -170,7 +170,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalUpload(ChannelSession channel, Stream input, DirectoryInfo directoryInfo)
+        private void InternalUpload(IChannelSession channel, Stream input, DirectoryInfo directoryInfo)
         {
             //  Upload files
             var files = directoryInfo.GetFiles();
@@ -194,7 +194,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalDownload(ChannelSession channel, Stream input, FileSystemInfo fileSystemInfo)
+        private void InternalDownload(IChannelSession channel, Stream input, FileSystemInfo fileSystemInfo)
         {
             DateTime modifiedTime = DateTime.Now;
             DateTime accessedTime = DateTime.Now;

+ 7 - 7
Renci.SshClient/Renci.SshNet/ScpClient.cs

@@ -152,7 +152,7 @@ namespace Renci.SshNet
         public void Upload(Stream source, string path)
         {
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -198,7 +198,7 @@ namespace Renci.SshNet
                 throw new ArgumentNullException("destination");
 
             using (var input = new PipeStream())
-            using (var channel = this.Session.CreateClientChannel<ChannelSession>())
+            using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
                 {
@@ -235,7 +235,7 @@ namespace Renci.SshNet
             }
         }
 
-        private void InternalSetTimestamp(ChannelSession channel, Stream input, DateTime lastWriteTime, DateTime lastAccessime)
+        private void InternalSetTimestamp(IChannelSession channel, Stream input, DateTime lastWriteTime, DateTime lastAccessime)
         {
             var zeroTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
             var modificationSeconds = (long)(lastWriteTime - zeroTime).TotalSeconds;
@@ -244,7 +244,7 @@ namespace Renci.SshNet
             this.CheckReturnCode(input);
         }
 
-        private void InternalUpload(ChannelSession channel, Stream input, Stream source, string filename)
+        private void InternalUpload(IChannelSession channel, Stream input, Stream source, string filename)
         {
             var length = source.Length;
 
@@ -271,7 +271,7 @@ namespace Renci.SshNet
             this.CheckReturnCode(input);
         }
 
-        private void InternalDownload(ChannelSession channel, Stream input, Stream output, string filename, long length)
+        private void InternalDownload(IChannelSession channel, Stream input, Stream output, string filename, long length)
         {
             var buffer = new byte[Math.Min(length, this.BufferSize)];
             var needToRead = length;
@@ -315,12 +315,12 @@ namespace Renci.SshNet
             }
         }
 
-        private void SendConfirmation(ChannelSession channel)
+        private void SendConfirmation(IChannelSession channel)
         {
             this.SendData(channel, new byte[] { 0 });
         }
 
-        private void SendConfirmation(ChannelSession channel, byte errorCode, string message)
+        private void SendConfirmation(IChannelSession channel, byte errorCode, string message)
         {
             this.SendData(channel, new[] { errorCode });
             this.SendData(channel, string.Format("{0}\n", message));

+ 10 - 0
Renci.SshClient/Renci.SshNet/ServiceFactory.cs

@@ -0,0 +1,10 @@
+namespace Renci.SshNet
+{
+    internal class ServiceFactory : IServiceFactory
+    {
+        public ISession CreateSession(ConnectionInfo connectionInfo)
+        {
+            return new Session(connectionInfo);
+        }
+    }
+}

+ 3 - 3
Renci.SshClient/Renci.SshNet/Session.cs

@@ -23,7 +23,7 @@ namespace Renci.SshNet
     /// <summary>
     /// Provides functionality to connect and interact with SSH server.
     /// </summary>
-    public partial class Session : IDisposable, ISession
+    public partial class Session : ISession
     {
         /// <summary>
         /// Specifies an infinite waiting period.
@@ -703,7 +703,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Sends "keep alive" message to keep connection alive.
         /// </summary>
-        internal void SendKeepAlive()
+        void ISession.SendKeepAlive()
         {
             this.SendMessage(new IgnoreMessage());
         }
@@ -1405,7 +1405,7 @@ namespace Renci.SshNet
         /// <summary>
         /// Called when client is disconnecting from the server.
         /// </summary>
-        internal void OnDisconnecting()
+        void ISession.OnDisconnecting()
         {
             _isDisconnecting = true;
         }

+ 1 - 1
Renci.SshClient/Renci.SshNet/Sftp/SftpSession.cs

@@ -56,7 +56,7 @@ namespace Renci.SshNet.Sftp
             }
         }
 
-        public SftpSession(Session session, TimeSpan operationTimeout, Encoding encoding)
+        public SftpSession(ISession session, TimeSpan operationTimeout, Encoding encoding)
             : base(session, "sftp", operationTimeout, encoding)
         {
         }

+ 3 - 3
Renci.SshClient/Renci.SshNet/Shell.cs

@@ -13,7 +13,7 @@ namespace Renci.SshNet
     /// </summary>
     public partial class Shell : IDisposable
     {
-        private readonly Session _session;
+        private readonly ISession _session;
 
         private IChannelSession _channel;
 
@@ -88,7 +88,7 @@ namespace Renci.SshNet
         /// <param name="height">The height.</param>
         /// <param name="terminalModes">The terminal modes.</param>
         /// <param name="bufferSize">Size of the buffer for output stream.</param>
-        internal Shell(Session session, Stream input, Stream output, Stream extendedOutput, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModes, int bufferSize)
+        internal Shell(ISession session, Stream input, Stream output, Stream extendedOutput, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModes, int bufferSize)
         {
             this._session = session;
             this._input = input;
@@ -119,7 +119,7 @@ namespace Renci.SshNet
                 this.Starting(this, new EventArgs());
             }
 
-            this._channel = this._session.CreateClientChannel<ChannelSession>();
+            this._channel = this._session.CreateChannelSession();
             this._channel.DataReceived += Channel_DataReceived;
             this._channel.ExtendedDataReceived += Channel_ExtendedDataReceived;
             this._channel.Closed += Channel_Closed;

+ 62 - 17
Renci.SshClient/Renci.SshNet/SshClient.cs

@@ -17,6 +17,14 @@ namespace Renci.SshNet
         /// </summary>
         private readonly List<ForwardedPort> _forwardedPorts;
 
+        /// <summary>
+        /// Holds a value indicating whether the current instance is disposed.
+        /// </summary>
+        /// <value>
+        /// <c>true</c> if the current instance is disposed; otherwise, <c>false</c>.
+        /// </value>
+        private bool _isDisposed;
+
         private Stream _inputStream;
 
         /// <summary>
@@ -127,8 +135,26 @@ namespace Renci.SshNet
         /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
-        private SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
-            : base(connectionInfo, ownsConnectionInfo)
+        internal SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
+            : base(connectionInfo, ownsConnectionInfo, new ServiceFactory())
+        {
+            _forwardedPorts = new List<ForwardedPort>();
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="SshClient"/> class.
+        /// </summary>
+        /// <param name="connectionInfo">The connection info.</param>
+        /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is null.</exception>
+        /// <remarks>
+        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// connection info will be disposed when this instance is disposed.
+        /// </remarks>
+        internal SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory)
+            : base(connectionInfo, ownsConnectionInfo, serviceFactory)
         {
             _forwardedPorts = new List<ForwardedPort>();
         }
@@ -164,11 +190,7 @@ namespace Renci.SshNet
             if (port == null)
                 throw new ArgumentNullException("port");
 
-            if (port.Session != null && port.Session != Session)
-                throw new InvalidOperationException("Forwarded port is already added to a different client.");
-
-            port.Session = Session;
-
+            AttachForwardedPort(port);
             _forwardedPorts.Add(port);
         }
 
@@ -185,11 +207,23 @@ namespace Renci.SshNet
             //  Stop port forwarding before removing it
             port.Stop();
 
-            port.Session = null;
-
+            DetachForwardedPort(port);
             _forwardedPorts.Remove(port);
         }
 
+        private void AttachForwardedPort(ForwardedPort port)
+        {
+            if (port.Session != null && port.Session != Session)
+                throw new InvalidOperationException("Forwarded port is already added to a different client.");
+
+            port.Session = Session;
+        }
+
+        private static void DetachForwardedPort(ForwardedPort port)
+        {
+            port.Session = null;
+        }
+
         /// <summary>
         /// Creates the command to be executed.
         /// </summary>
@@ -399,11 +433,12 @@ namespace Renci.SshNet
         {
             base.OnDisconnected();
 
-            foreach (var forwardedPort in _forwardedPorts.ToArray())
+            for (var i = _forwardedPorts.Count - 1; i >= 0; i--)
             {
-                RemoveForwardedPort(forwardedPort);
+                var port = _forwardedPorts[i];
+                DetachForwardedPort(port);
+                _forwardedPorts.RemoveAt(i);
             }
-
         }
 
         /// <summary>
@@ -412,13 +447,23 @@ namespace Renci.SshNet
         /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged ResourceMessages.</param>
         protected override void Dispose(bool disposing)
         {
-            base.Dispose(disposing);
-
-            if (_inputStream != null)
+            if (!_isDisposed)
             {
-                _inputStream.Dispose();
-                _inputStream = null;
+                if (disposing)
+                {
+                    Disconnect();
+
+                    if (_inputStream != null)
+                    {
+                        _inputStream.Dispose();
+                        _inputStream = null;
+                    }
+                }
             }
+
+            base.Dispose(disposing);
+
+            _isDisposed = true;
         }
     }
 }