ソースを参照

Ensure we do not block forever waiting for message listener thread to complete when the server sends a disconnect message, or the message listener thread never actually started. Fixes issue #2591.
Refactored creating of key exchange algorithm to allow unit tests without encyption and decryption.
Added tests for KeyExchangeDhGroupExchangeRequest and IgnoreMessage.
Added large set of tests for Session in both connected and disconnected mode.
Remove SshData.ReadInt64() and SshData.Write(long). Fixes issue #2579.

Gert Driesen 11 年 前
コミット
80702f796d
44 ファイル変更2555 行追加339 行削除
  1. 57 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTest.cs
  2. 1 4
      Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTestBase.cs
  3. 0 12
      Renci.SshClient/Renci.SshNet.Tests/Classes/Common/SshConnectionExceptionTest.cs
  4. 8 4
      Renci.SshClient/Renci.SshNet.Tests/Classes/ConnectionInfoTest.cs
  5. 78 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/ConnectionInfoTest_Authenticate_Failure.cs
  6. 59 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/ConnectionInfoTest_Authenticate_Success.cs
  7. 83 24
      Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/IgnoreMessageTest.cs
  8. 32 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/KeyExchangeDhGroupExchangeGroupBuilder.cs
  9. 52 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/KeyExchangeDhGroupExchangeReplyBuilder.cs
  10. 50 8
      Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/KeyExchangeDhGroupExchangeRequestTest.cs
  11. 8 8
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest.HttpProxy.cs
  12. 98 52
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest.cs
  13. 169 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
  14. 185 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
  15. 168 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs
  16. 148 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_Disconnect.cs
  17. 164 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs
  18. 168 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsDisconnectMessage.cs
  19. 172 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs
  20. 161 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerShutsDownSocket.cs
  21. 191 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs
  22. 95 0
      Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_SocketConnected_BadPacketAndDispose.cs
  23. 1 1
      Renci.SshClient/Renci.SshNet.Tests/Classes/Sftp/Responses/SftpHandleResponseTest.cs
  24. 15 0
      Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj
  25. 2 2
      Renci.SshClient/Renci.SshNet/BaseClient.cs
  26. 1 1
      Renci.SshClient/Renci.SshNet/ClientAuthentication.cs
  27. 0 10
      Renci.SshClient/Renci.SshNet/Common/SshConnectionException.cs
  28. 24 87
      Renci.SshClient/Renci.SshNet/Common/SshData.cs
  29. 93 13
      Renci.SshClient/Renci.SshNet/Common/SshDataStream.cs
  30. 7 2
      Renci.SshClient/Renci.SshNet/ConnectionInfo.cs
  31. 12 0
      Renci.SshClient/Renci.SshNet/IClientAuthentication.cs
  32. 18 0
      Renci.SshClient/Renci.SshNet/IServiceFactory.cs
  33. 0 15
      Renci.SshClient/Renci.SshNet/ISession.cs
  34. 11 4
      Renci.SshClient/Renci.SshNet/Messages/Transport/IgnoreMessage.cs
  35. 3 1
      Renci.SshClient/Renci.SshNet/Messages/Transport/KeyExchangeDhGroupExchangeGroup.cs
  36. 3 1
      Renci.SshClient/Renci.SshNet/Messages/Transport/KeyExchangeDhGroupExchangeReply.cs
  37. 3 1
      Renci.SshClient/Renci.SshNet/Messages/Transport/KeyExchangeDhGroupExchangeRequest.cs
  38. 3 1
      Renci.SshClient/Renci.SshNet/Messages/Transport/ServiceAcceptMessage.cs
  39. 2 0
      Renci.SshClient/Renci.SshNet/Renci.SshNet.csproj
  40. 96 0
      Renci.SshClient/Renci.SshNet/Security/IKeyExchange.cs
  41. 2 1
      Renci.SshClient/Renci.SshNet/Security/KeyExchange.cs
  42. 49 1
      Renci.SshClient/Renci.SshNet/ServiceFactory.cs
  43. 3 8
      Renci.SshClient/Renci.SshNet/Session.NET.cs
  44. 60 78
      Renci.SshClient/Renci.SshNet/Session.cs

+ 57 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTest.cs

@@ -0,0 +1,57 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ClientAuthenticationTest
+    {
+        private ClientAuthentication _clientAuthentication;
+
+        [TestInitialize]
+        public void Init()
+        {
+            _clientAuthentication = new ClientAuthentication();
+        }
+
+        [TestMethod]
+        public void AuthenticateShouldThrowArgumentNullExceptionWhenConnectionInfoIsNull()
+        {
+            IConnectionInfoInternal connectionInfo = null;
+            var session = new Mock<ISession>(MockBehavior.Strict).Object;
+
+            try
+            {
+                _clientAuthentication.Authenticate(connectionInfo, session);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("connectionInfo", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void AuthenticateShouldThrowArgumentNullExceptionWhenSessionIsNull()
+        {
+            var connectionInfo = new Mock<IConnectionInfoInternal>(MockBehavior.Strict).Object;
+            ISession session = null;
+
+            try
+            {
+                _clientAuthentication.Authenticate(connectionInfo, session);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("session", ex.ParamName);
+            }
+        }
+    }
+}

+ 1 - 4
Renci.SshClient/Renci.SshNet.Tests/Classes/ClientAuthenticationTestBase.cs

@@ -1,8 +1,5 @@
-using System;
-using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Moq;
-using Moq.Protected;
-using Renci.SshNet.Common;
 using Renci.SshNet.Tests.Common;
 
 namespace Renci.SshNet.Tests.Classes

+ 0 - 12
Renci.SshClient/Renci.SshNet.Tests/Classes/Common/SshConnectionExceptionTest.cs

@@ -60,18 +60,6 @@ namespace Renci.SshNet.Tests.Classes.Common
             Assert.Inconclusive("TODO: Implement code to verify target");
         }
 
-        /// <summary>
-        ///A test for SshConnectionException Constructor
-        ///</summary>
-        [TestMethod()]
-        public void SshConnectionExceptionConstructorTest4()
-        {
-            string message = string.Empty; // TODO: Initialize to an appropriate value
-            Exception innerException = null; // TODO: Initialize to an appropriate value
-            SshConnectionException target = new SshConnectionException(message, innerException);
-            Assert.Inconclusive("TODO: Implement code to verify target");
-        }
-
         /// <summary>
         ///A test for GetObjectData
         ///</summary>

+ 8 - 4
Renci.SshClient/Renci.SshNet.Tests/Classes/ConnectionInfoTest.cs

@@ -1,6 +1,7 @@
 using System.Globalization;
 using System.Net;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
 using Renci.SshNet.Tests.Common;
 using Renci.SshNet.Tests.Properties;
 using System;
@@ -323,20 +324,23 @@ namespace Renci.SshNet.Tests.Classes
 
         [TestMethod]
         [TestCategory("ConnectionInfo")]
-        public void AuthenticateShouldThrowArgumentNullExceptionWhenSessionIsNull()
+        public void AuthenticateShouldThrowArgumentNullExceptionWhenServiceFactoryIsNull()
         {
-            var ret = new ConnectionInfo(Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, ProxyTypes.None,
+            var connectionInfo = new ConnectionInfo(Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, ProxyTypes.None,
                 Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, Resources.PASSWORD,
                 new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
+            var session = new Mock<ISession>(MockBehavior.Strict).Object;
+            IServiceFactory serviceFactory = null;
 
             try
             {
-                ret.Authenticate(null);
+                connectionInfo.Authenticate(session, serviceFactory);
+                Assert.Fail();
             }
             catch (ArgumentNullException ex)
             {
                 Assert.IsNull(ex.InnerException);
-                Assert.AreEqual("session", ex.ParamName);
+                Assert.AreEqual("serviceFactory", ex.ParamName);
             }
         }
    }

+ 78 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ConnectionInfoTest_Authenticate_Failure.cs

@@ -0,0 +1,78 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Tests.Properties;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ConnectionInfoTest_Authenticate_Failure
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<IClientAuthentication> _clientAuthenticationMock;
+        private Mock<ISession> _sessionMock;
+        private ConnectionInfo _connectionInfo;
+        private SshAuthenticationException _authenticationException;
+        private SshAuthenticationException _actualException;
+
+        [TestInitialize]
+        public void Init()
+        {
+            Arrange();
+            Act();
+        }
+
+        protected void Arrange()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _clientAuthenticationMock = new Mock<IClientAuthentication>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+
+            _connectionInfo = new ConnectionInfo(Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, ProxyTypes.None,
+                Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, Resources.PASSWORD,
+                new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
+            _authenticationException = new SshAuthenticationException();
+
+            _serviceFactoryMock.Setup(p => p.CreateClientAuthentication()).Returns(_clientAuthenticationMock.Object);
+            _clientAuthenticationMock.Setup(p => p.Authenticate(_connectionInfo, _sessionMock.Object))
+                .Throws(_authenticationException);
+        }
+
+        protected void Act()
+        {
+            try
+            {
+                _connectionInfo.Authenticate(_sessionMock.Object, _serviceFactoryMock.Object);
+            }
+            catch (SshAuthenticationException ex)
+            {
+                _actualException = ex;
+            }
+        }
+
+        [TestMethod]
+        public void AuthenticateShouldHaveThrownSshAuthenticationException()
+        {
+            Assert.IsNotNull(_actualException);
+            Assert.AreSame(_authenticationException, _actualException);
+        }
+
+        [TestMethod]
+        public void IsAuthenticatedShouldReturnFalse()
+        {
+            Assert.IsFalse(_connectionInfo.IsAuthenticated);
+        }
+
+        [TestMethod]
+        public void CreateClientAuthenticationOnServiceFactoryShouldBeInvokedOnce()
+        {
+            _serviceFactoryMock.Verify(p => p.CreateClientAuthentication(), Times.Once);
+        }
+
+        [TestMethod]
+        public void AuthenticateOnClientAuthenticationShouldBeInvokedOnce()
+        {
+            _clientAuthenticationMock.Verify(p => p.Authenticate(_connectionInfo, _sessionMock.Object), Times.Once);
+        }
+    }
+}

+ 59 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ConnectionInfoTest_Authenticate_Success.cs

@@ -0,0 +1,59 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Tests.Properties;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ConnectionInfoTest_Authenticate_Success
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<IClientAuthentication> _clientAuthenticationMock;
+        private Mock<ISession> _sessionMock;
+        private ConnectionInfo _connectionInfo;
+
+        [TestInitialize]
+        public void Init()
+        {
+            Arrange();
+            Act();
+        }
+
+        protected void Arrange()
+        {
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _clientAuthenticationMock = new Mock<IClientAuthentication>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+
+            _connectionInfo = new ConnectionInfo(Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, ProxyTypes.None,
+                Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, Resources.PASSWORD,
+                new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
+
+            _serviceFactoryMock.Setup(p => p.CreateClientAuthentication()).Returns(_clientAuthenticationMock.Object);
+            _clientAuthenticationMock.Setup(p => p.Authenticate(_connectionInfo, _sessionMock.Object));
+        }
+
+        protected void Act()
+        {
+            _connectionInfo.Authenticate(_sessionMock.Object, _serviceFactoryMock.Object);
+        }
+
+        [TestMethod]
+        public void IsAuthenticatedShouldReturnTrue()
+        {
+            Assert.IsTrue(_connectionInfo.IsAuthenticated);
+        }
+
+        [TestMethod]
+        public void CreateClientAuthenticationOnServiceFactoryShouldBeInvokedOnce()
+        {
+            _serviceFactoryMock.Verify(p => p.CreateClientAuthentication(), Times.Once);
+        }
+
+        [TestMethod]
+        public void AuthenticateOnClientAuthenticationShouldBeInvokedOnce()
+        {
+            _clientAuthenticationMock.Verify(p => p.Authenticate(_connectionInfo, _sessionMock.Object), Times.Once);
+        }
+    }
+}

+ 83 - 24
Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/IgnoreMessageTest.cs

@@ -1,36 +1,95 @@
-using Renci.SshNet.Messages.Transport;
+using System;
+using System.Linq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
-using System;
-using Renci.SshNet.Tests.Common;
 
 namespace Renci.SshNet.Tests.Classes.Messages.Transport
 {
-    /// <summary>
-    ///This is a test class for IgnoreMessageTest and is intended
-    ///to contain all IgnoreMessageTest Unit Tests
-    ///</summary>
-    [TestClass()]
-    public class IgnoreMessageTest : TestBase
+    [TestClass]
+    public class IgnoreMessageTest
     {
-        /// <summary>
-        ///A test for IgnoreMessage Constructor
-        ///</summary>
-        [TestMethod()]
-        public void IgnoreMessageConstructorTest()
+        private Random _random;
+        private byte[] _data;
+
+        [TestInitialize]
+        public void Init()
+        {
+            _random = new Random();
+            _data = new byte[_random.Next(1, 10)];
+            _random.NextBytes(_data);
+        }
+
+        [TestMethod]
+        public void DefaultConstructor()
+        {
+            var target = new IgnoreMessage();
+            Assert.IsNotNull(target.Data);
+            Assert.AreEqual(0, target.Data.Length);
+        }
+
+        [TestMethod]
+        public void Constructor_Data()
+        {
+            var target = new IgnoreMessage(_data);
+            Assert.AreSame(_data, target.Data);
+        }
+
+        [TestMethod]
+        public void Constructor_Data_ShouldThrowArgumentNullExceptionWhenDataIsNull()
         {
-            IgnoreMessage target = new IgnoreMessage();
-            Assert.Inconclusive("TODO: Implement code to verify target");
+            byte[] data = null;
+
+            try
+            {
+                new IgnoreMessage(data);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("data", ex.ParamName);
+            }
         }
 
-        /// <summary>
-        ///A test for IgnoreMessage Constructor
-        ///</summary>
-        [TestMethod()]
-        public void IgnoreMessageConstructorTest1()
+        [TestMethod]
+        public void GetBytes()
         {
-            byte[] data = null; // TODO: Initialize to an appropriate value
-            IgnoreMessage target = new IgnoreMessage(data);
-            Assert.Inconclusive("TODO: Implement code to verify target");
+            var request = new IgnoreMessage(_data);
+
+            var bytes = request.GetBytes();
+
+            var expectedBytesLength = 0;
+            expectedBytesLength += 1; // Type
+            expectedBytesLength += 4; // Data length
+            expectedBytesLength += _data.Length; // Data
+
+            Assert.AreEqual(expectedBytesLength, bytes.Length);
+
+            var sshDataStream = new SshDataStream(bytes);
+
+            Assert.AreEqual(IgnoreMessage.MessageNumber, sshDataStream.ReadByte());
+            Assert.AreEqual((uint) _data.Length, sshDataStream.ReadUInt32());
+
+            var actualData = new byte[_data.Length];
+            sshDataStream.Read(actualData, 0, actualData.Length);
+            Assert.IsTrue(_data.SequenceEqual(actualData));
+
+            Assert.IsTrue(sshDataStream.IsEndOfData);
+        }
+
+        [TestMethod]
+        public void Load()
+        {
+            var ignoreMessage = new IgnoreMessage(_data);
+            var bytes = ignoreMessage.GetBytes();
+            var target = new IgnoreMessage();
+
+            target.Load(bytes);
+
+            Assert.IsNotNull(target.Data);
+            Assert.AreEqual(_data.Length, target.Data.Length);
+            Assert.IsTrue(target.Data.SequenceEqual(_data));
         }
     }
 }

+ 32 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/KeyExchangeDhGroupExchangeGroupBuilder.cs

@@ -0,0 +1,32 @@
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes.Messages.Transport
+{
+    public class KeyExchangeDhGroupExchangeGroupBuilder
+    {
+        private BigInteger _safePrime;
+        private BigInteger _subGroup;
+
+        public KeyExchangeDhGroupExchangeGroupBuilder WithSafePrime(BigInteger safePrime)
+        {
+            _safePrime = safePrime;
+            return this;
+        }
+
+        public KeyExchangeDhGroupExchangeGroupBuilder WithSubGroup(BigInteger subGroup)
+        {
+            _subGroup = subGroup;
+            return this;
+        }
+
+        public byte[] Build()
+        {
+            var sshDataStream = new SshDataStream(0);
+            sshDataStream.WriteByte(KeyExchangeDhGroupExchangeGroup.MessageNumber);
+            sshDataStream.Write(_safePrime);
+            sshDataStream.Write(_subGroup);
+            return sshDataStream.ToArray();
+        }
+    }
+}

+ 52 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/KeyExchangeDhGroupExchangeReplyBuilder.cs

@@ -0,0 +1,52 @@
+using System.Collections.Generic;
+using System.Text;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes.Messages.Transport
+{
+    public class KeyExchangeDhGroupExchangeReplyBuilder
+    {
+        private byte[] _hostKeyAlgorithm;
+        private byte[] _hostKeys;
+        private BigInteger _f;
+        private byte[] _signature;
+
+        public KeyExchangeDhGroupExchangeReplyBuilder WithHostKey(string hostKeyAlgorithm, params BigInteger[] hostKeys)
+        {
+            _hostKeyAlgorithm = Encoding.UTF8.GetBytes(hostKeyAlgorithm);
+
+            var sshDataStream = new SshDataStream(0);
+            foreach (var hostKey in hostKeys)
+                sshDataStream.Write(hostKey);
+            _hostKeys = sshDataStream.ToArray();
+
+            return this;
+        }
+
+        public KeyExchangeDhGroupExchangeReplyBuilder WithF(BigInteger f)
+        {
+            _f = f;
+            return this;
+        }
+
+        public KeyExchangeDhGroupExchangeReplyBuilder WithSignature(byte[] signature)
+        {
+            _signature = signature;
+            return this;
+        }
+
+        public byte[] Build()
+        {
+            var sshDataStream = new SshDataStream(0);
+            sshDataStream.WriteByte(KeyExchangeDhGroupExchangeReply.MessageNumber);
+            sshDataStream.Write((uint)(4 + _hostKeyAlgorithm.Length + _hostKeys.Length));
+            sshDataStream.Write((uint) _hostKeyAlgorithm.Length);
+            sshDataStream.Write(_hostKeyAlgorithm, 0, _hostKeyAlgorithm.Length);
+            sshDataStream.Write(_hostKeys, 0, _hostKeys.Length);
+            sshDataStream.Write(_f);
+            sshDataStream.WriteBinary(_signature);
+            return sshDataStream.ToArray();
+        }
+    }
+}

+ 50 - 8
Renci.SshClient/Renci.SshNet.Tests/Classes/Messages/Transport/KeyExchangeDhGroupExchangeRequestTest.cs

@@ -1,7 +1,8 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
 using Renci.SshNet.Messages.Transport;
 using Renci.SshNet.Tests.Common;
-using System.Linq;
 
 namespace Renci.SshNet.Tests.Classes.Messages.Transport
 {
@@ -9,20 +10,61 @@ namespace Renci.SshNet.Tests.Classes.Messages.Transport
     /// Represents SSH_MSG_KEX_DH_GEX_REQUEST message.
     /// </summary>
     [TestClass]
-    public class KeyExchangeDhGroupExchangeRequestTest : TestBase
+    public class KeyExchangeDhGroupExchangeRequestTest
     {
+        private uint _minimum;
+        private uint _preferred;
+        private uint _maximum;
+
+        public void Init()
+        {
+            var random = new Random();
+            _minimum = (uint) random.Next(1, int.MaxValue);
+            _preferred = (uint) random.Next(1, int.MaxValue);
+            _maximum = (uint) random.Next(1, int.MaxValue);
+        }
+
+
         [TestMethod]
         [TestCategory("KeyExchangeInitMessage")]
         [Owner("olegkap")]
         [Description("Validates KeyExchangeInitMessage message serialization.")]
         public void Test_KeyExchangeDhGroupExchangeRequest_GetBytes()
         {
-            var m = new KeyExchangeDhGroupExchangeRequest(1024, 1024, 1204);
-            var input = new byte[] { 0x22, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, };
-            var output = m.GetBytes();
+            var request = new KeyExchangeDhGroupExchangeRequest(_minimum, _preferred, _maximum);
+
+            var bytes = request.GetBytes();
+
+            var expectedBytesLength = 0;
+            expectedBytesLength += 1; // Type
+            expectedBytesLength += 4; // Minimum
+            expectedBytesLength += 4; // Preferred
+            expectedBytesLength += 4; // Maximum
+
+            Assert.AreEqual(expectedBytesLength, bytes.Length);
+
+            var sshDataStream = new SshDataStream(bytes);
+
+            Assert.AreEqual(KeyExchangeDhGroupExchangeRequest.MessageNumber, sshDataStream.ReadByte());
+            Assert.AreEqual(_minimum, sshDataStream.ReadUInt32());
+            Assert.AreEqual(_preferred, sshDataStream.ReadUInt32());
+            Assert.AreEqual(_maximum, sshDataStream.ReadUInt32());
+
+            Assert.IsTrue(sshDataStream.IsEndOfData);
+        }
+
+        [TestMethod]
+        public void Load()
+        {
+            var request = new KeyExchangeDhGroupExchangeRequest(_minimum, _preferred, _maximum);
+            var bytes = request.GetBytes();
+            var target = new KeyExchangeDhGroupExchangeRequest(0, 0, 0);
+
+            target.Load(bytes);
 
-            //  Skip first 17 bytes since 16 bytes are randomly generated
-            Assert.IsTrue(input.SequenceEqual(output.Skip(17)));
+            Assert.AreEqual(_minimum, target.Minimum);
+            Assert.AreEqual(_preferred, target.Preferred);
+            Assert.AreEqual(_maximum, target.Maximum);
         }
     }
 }

+ 8 - 8
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest.HttpProxy.cs

@@ -21,7 +21,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Responses.Add(Encoding.ASCII.GetBytes("Whatever\r\n"));
                 proxyStub.Start();
 
-                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon")))
+                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon"), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -48,7 +48,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Responses.Add(Encoding.ASCII.GetBytes("HTTP/1.0 501 Custom\r\n"));
                 proxyStub.Start();
 
-                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon")))
+                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon"), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -78,7 +78,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Responses.Add(Encoding.ASCII.GetBytes("SSH-666-SshStub"));
                 proxyStub.Start();
 
-                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon")))
+                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon"), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -110,7 +110,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Responses.Add(Encoding.ASCII.GetBytes("SSH-666-SshStub"));
                 proxyStub.Start();
 
-                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon")))
+                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon"), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -137,7 +137,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Responses.Add(Encoding.ASCII.GetBytes("HTTP/1.0 501 Custom\r\n"));
                 proxyStub.Start();
 
-                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon")))
+                using (var session = new Session(CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon"), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -165,7 +165,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Start();
 
                 var connectionInfo = CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, "anon");
-                using (var session = new Session(connectionInfo))
+                using (var session = new Session(connectionInfo, _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -194,7 +194,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Start();
 
                 var connectionInfo = CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, string.Empty);
-                using (var session = new Session(connectionInfo))
+                using (var session = new Session(connectionInfo, _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -222,7 +222,7 @@ namespace Renci.SshNet.Tests.Classes
                 proxyStub.Start();
 
                 var connectionInfo = CreateConnectionInfoWithProxy(proxyEndPoint, serverEndPoint, null);
-                using (var session = new Session(connectionInfo))
+                using (var session = new Session(connectionInfo, _serviceFactoryMock.Object))
                 {
                     try
                     {

+ 98 - 52
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest.cs

@@ -4,6 +4,7 @@ using System.Net;
 using System.Net.Sockets;
 using System.Text;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
 using Renci.SshNet.Common;
 using Renci.SshNet.Tests.Common;
 using Renci.SshNet.Tests.Properties;
@@ -16,6 +17,52 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public partial class SessionTest : TestBase
     {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+
+        protected override void OnInit()
+        {
+            base.OnInit();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+        }
+
+        [TestMethod]
+        public void ConstructorShouldThrowArgumentNullExceptionWhenConnectionInfoIsNull()
+        {
+            ConnectionInfo connectionInfo = null;
+            var serviceFactory = new Mock<IServiceFactory>(MockBehavior.Strict).Object;
+
+            try
+            {
+                new Session(connectionInfo, serviceFactory);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("connectionInfo", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void ConstructorShouldThrowArgumentNullExceptionWhenServiceFactoryIsNull()
+        {
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            var connectionInfo = CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5));
+            IServiceFactory serviceFactory = null;
+
+            try
+            {
+                new Session(connectionInfo, serviceFactory);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("serviceFactory", ex.ParamName);
+            }
+        }
+
         [TestMethod]
         public void ConnectShouldSkipLinesBeforeProtocolIdentificationString()
         {
@@ -24,16 +71,16 @@ namespace Renci.SshNet.Tests.Classes
 
             using (var serverStub = new AsyncSocketListener(serverEndPoint))
             {
-                serverStub.Connected += (socket) =>
-                {
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub\r\n"));
-                    socket.Shutdown(SocketShutdown.Send);
-                };
+                serverStub.Connected += socket =>
+                    {
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub\r\n"));
+                        socket.Shutdown(SocketShutdown.Send);
+                    };
                 serverStub.Start();
 
-                using (var session = new Session(connectionInfo))
+                using (var session = new Session(connectionInfo, _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -59,16 +106,16 @@ namespace Renci.SshNet.Tests.Classes
 
             using (var serverStub = new AsyncSocketListener(serverEndPoint))
             {
-                serverStub.Connected += (socket) =>
-                {
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub"));
-                    socket.Shutdown(SocketShutdown.Send);
-                };
+                serverStub.Connected += socket =>
+                    {
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub"));
+                        socket.Shutdown(SocketShutdown.Send);
+                    };
                 serverStub.Start();
 
-                using (var session = new Session(connectionInfo))
+                using (var session = new Session(connectionInfo, _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -95,15 +142,15 @@ namespace Renci.SshNet.Tests.Classes
 
             using (var serverStub = new AsyncSocketListener(serverEndPoint))
             {
-                serverStub.Connected += (socket) =>
-                {
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
-                    clientSocket = socket;
-                };
+                serverStub.Connected += socket =>
+                    {
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                        clientSocket = socket;
+                    };
                 serverStub.Start();
 
-                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromMilliseconds(500))))
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromMilliseconds(500)), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -133,15 +180,15 @@ namespace Renci.SshNet.Tests.Classes
             // response ends with CRLF
             using (var serverStub = new AsyncSocketListener(serverEndPoint))
             {
-                serverStub.Connected += (socket) =>
-                {
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
-                    socket.Shutdown(SocketShutdown.Send);
-                };
+                serverStub.Connected += socket =>
+                    {
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                        socket.Shutdown(SocketShutdown.Send);
+                    };
                 serverStub.Start();
 
-                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5))))
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5)), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -159,15 +206,15 @@ namespace Renci.SshNet.Tests.Classes
             // response does not end with CRLF
             using (var serverStub = new AsyncSocketListener(serverEndPoint))
             {
-                serverStub.Connected += (socket) =>
-                {
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner"));
-                    socket.Shutdown(SocketShutdown.Send);
-                };
+                serverStub.Connected += socket =>
+                    {
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("WELCOME banner"));
+                        socket.Shutdown(SocketShutdown.Send);
+                    };
                 serverStub.Start();
 
-                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5))))
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5)), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -185,16 +232,16 @@ namespace Renci.SshNet.Tests.Classes
             // last line is empty
             using (var serverStub = new AsyncSocketListener(serverEndPoint))
             {
-                serverStub.Connected += (socket) =>
-                {
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
-                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
-                    socket.Shutdown(SocketShutdown.Send);
-                };
+                serverStub.Connected += socket =>
+                    {
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                        socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                        socket.Shutdown(SocketShutdown.Send);
+                    };
                 serverStub.Start();
 
-                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5))))
+                using (var session = new Session(CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5)), _serviceFactoryMock.Object))
                 {
                     try
                     {
@@ -215,11 +262,10 @@ namespace Renci.SshNet.Tests.Classes
         {
             var connectionInfo = new ConnectionInfo("invalid.", 40, "user",
                 new KeyboardInteractiveAuthenticationMethod("user"));
-            var session = new Session(connectionInfo);
+            var session = new Session(connectionInfo, _serviceFactoryMock.Object);
 
             try
             {
-
                 session.Connect();
                 Assert.Fail();
             }
@@ -234,7 +280,7 @@ namespace Renci.SshNet.Tests.Classes
         {
             var connectionInfo = new ConnectionInfo("localhost", 40, "user", ProxyTypes.Http, "invalid.", 80,
                 "proxyUser", "proxyPwd", new KeyboardInteractiveAuthenticationMethod("user"));
-            var session = new Session(connectionInfo);
+            var session = new Session(connectionInfo, _serviceFactoryMock.Object);
 
             try
             {
@@ -252,7 +298,7 @@ namespace Renci.SshNet.Tests.Classes
         {
             var connectionInfo = new ConnectionInfo("localhost", 6767, Resources.USERNAME,
                 new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
-            var session = new Session(connectionInfo);
+            var session = new Session(connectionInfo, _serviceFactoryMock.Object);
 
             try
             {
@@ -270,7 +316,7 @@ namespace Renci.SshNet.Tests.Classes
         {
             var connectionInfo = new ConnectionInfo("localhost", 6767, Resources.USERNAME,
                 new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
-            var session = new Session(connectionInfo);
+            var session = new Session(connectionInfo, _serviceFactoryMock.Object);
 
             session.Disconnect();
         }
@@ -280,7 +326,7 @@ namespace Renci.SshNet.Tests.Classes
         {
             var connectionInfo = new ConnectionInfo("localhost", 6767, Resources.USERNAME,
                 new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
-            var session = new Session(connectionInfo);
+            var session = new Session(connectionInfo, _serviceFactoryMock.Object);
 
             try
             {
@@ -298,7 +344,7 @@ namespace Renci.SshNet.Tests.Classes
         {
             var connectionInfo = new ConnectionInfo("localhost", 6767, Resources.USERNAME,
                 new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
-            var session = new Session(connectionInfo);
+            var session = new Session(connectionInfo, _serviceFactoryMock.Object);
 
             session.Disconnect();
         }

+ 169 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs

@@ -0,0 +1,169 @@
+using System;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected : SessionTest_ConnectedBase
+    {
+        private IgnoreMessage _ignoreMessage;
+
+        protected override void Arrange()
+        {
+            base.Arrange();
+
+            var data = new byte[10];
+            Random.NextBytes(data);
+            _ignoreMessage = new IgnoreMessage(data);
+        }
+
+        protected override void Act()
+        {
+        }
+
+        [TestMethod]
+        public void ClientVersionIsRenciSshNet()
+        {
+            Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", Session.ClientVersion);
+        }
+
+        [TestMethod]
+        public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
+        {
+            Assert.AreSame(ConnectionInfo, Session.ConnectionInfo);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnTrue()
+        {
+            Assert.IsTrue(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldSendPacketToServer()
+        {
+            ServerBytesReceivedRegister.Clear();
+
+            Session.SendMessage(_ignoreMessage);
+
+            // give session time to process message
+            Thread.Sleep(100);
+
+            Assert.AreEqual(1, ServerBytesReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void SessionIdShouldReturnExchangeHashCalculatedFromKeyExchangeInitMessage()
+        {
+            Assert.IsNotNull(Session.SessionId);
+            Assert.AreSame(SessionId, Session.SessionId);
+        }
+
+        [TestMethod]
+        public void ServerVersionShouldNotReturnNull()
+        {
+            Assert.IsNotNull(Session.ServerVersion);
+            Assert.AreEqual("SSH-2.0-SshStub", Session.ServerVersion);
+        }
+
+        [TestMethod]
+        public void WaitOnHandle_WaitHandle_ShouldThrowArgumentNullExceptionWhenWaitHandleIsNull()
+        {
+            WaitHandle waitHandle = null;
+
+            try
+            {
+                Session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("waitHandle", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void WaitOnHandle_WaitHandleAndTimeout_ShouldThrowArgumentNullExceptionWhenWaitHandleIsNull()
+        {
+            WaitHandle waitHandle = null;
+            var timeout = TimeSpan.FromMinutes(5);
+
+            try
+            {
+                Session.WaitOnHandle(waitHandle, timeout);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("waitHandle", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
+        {
+            var session = (ISession) Session;
+            Assert.AreSame(ConnectionInfo, session.ConnectionInfo);
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldNotBeSignaled()
+        {
+            var session = (ISession) Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsFalse(session.MessageListenerCompleted.WaitOne(0));
+        }
+
+        [TestMethod]
+        public void ISession_SendMessageShouldSendPacketToServer()
+        {
+            var session = (ISession) Session;
+            ServerBytesReceivedRegister.Clear();
+
+            session.SendMessage(_ignoreMessage);
+
+            // give session time to process message
+            Thread.Sleep(100);
+
+            Assert.AreEqual(1, ServerBytesReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldSendPacketToServerAndReturnTrue()
+        {
+            var session = (ISession) Session;
+            ServerBytesReceivedRegister.Clear();
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            // give session time to process message
+            Thread.Sleep(100);
+
+            Assert.IsTrue(actual);
+            Assert.AreEqual(1, ServerBytesReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowArgumentNullExceptionWhenWaitHandleIsNull()
+        {
+            WaitHandle waitHandle = null;
+            var session = (ISession) Session;
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("waitHandle", ex.ParamName);
+            }
+        }
+    }
+}

+ 185 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

@@ -0,0 +1,185 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Cryptography;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Compression;
+using Renci.SshNet.Messages;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Security;
+using Renci.SshNet.Security.Cryptography;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public abstract class SessionTest_ConnectedBase
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<IKeyExchange> _keyExchangeMock;
+        private Mock<IClientAuthentication> _clientAuthenticationMock;
+        private IPEndPoint _serverEndPoint;
+        private string _keyExchangeAlgorithm;
+
+        protected Random Random { get; private set; }
+        protected byte[] SessionId { get; private set; }
+        protected ConnectionInfo ConnectionInfo { get; private set; }
+        protected IList<EventArgs> DisconnectedRegister { get; private set; }
+        protected IList<MessageEventArgs<DisconnectMessage>> DisconnectReceivedRegister { get; private set; }
+        protected IList<ExceptionEventArgs> ErrorOccurredRegister { get; private set; }
+        protected AsyncSocketListener ServerListener { get; private set; }
+        protected IList<byte[]> ServerBytesReceivedRegister { get; private set; }
+        protected Session Session { get; private set; }
+        protected Socket ServerSocket { get; private set; }
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void TearDown()
+        {
+            if (ServerListener != null)
+            {
+                ServerListener.Dispose();
+            }
+
+            if (Session != null)
+            {
+                Session.Dispose();
+            }
+        }
+
+        protected virtual void Arrange()
+        {
+            Random = new Random();
+
+            _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            ConnectionInfo = new ConnectionInfo(
+                _serverEndPoint.Address.ToString(),
+                _serverEndPoint.Port,
+                "user",
+                new PasswordAuthenticationMethod("user", "password"));
+            ConnectionInfo.Timeout = TimeSpan.FromSeconds(20);
+            _keyExchangeAlgorithm = Random.Next().ToString(CultureInfo.InvariantCulture);
+            SessionId = new byte[10];
+            Random.NextBytes(SessionId);
+            DisconnectedRegister = new List<EventArgs>();
+            DisconnectReceivedRegister = new List<MessageEventArgs<DisconnectMessage>>();
+            ErrorOccurredRegister = new List<ExceptionEventArgs>();
+            ServerBytesReceivedRegister = new List<byte[]>();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _keyExchangeMock = new Mock<IKeyExchange>(MockBehavior.Strict);
+            _clientAuthenticationMock = new Mock<IClientAuthentication>(MockBehavior.Strict);
+
+            Session = new Session(ConnectionInfo, _serviceFactoryMock.Object);
+            Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args);
+            Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args);
+            Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args);
+            Session.KeyExchangeInitReceived += (sender, args) =>
+                {
+                    var newKeysMessage = new NewKeysMessage();
+                    var newKeys = newKeysMessage.GetPacket(8, null);
+                    ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+                };
+
+            _serviceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
+            _keyExchangeMock.Setup(p => p.Name).Returns(_keyExchangeAlgorithm);
+            _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>()));
+            _keyExchangeMock.Setup(p => p.ExchangeHash).Returns(SessionId);
+            _keyExchangeMock.Setup(p => p.CreateServerCipher()).Returns((Cipher) null);
+            _keyExchangeMock.Setup(p => p.CreateClientCipher()).Returns((Cipher) null);
+            _keyExchangeMock.Setup(p => p.CreateServerHash()).Returns((HashAlgorithm) null);
+            _keyExchangeMock.Setup(p => p.CreateClientHash()).Returns((HashAlgorithm) null);
+            _keyExchangeMock.Setup(p => p.CreateCompressor()).Returns((Compressor) null);
+            _keyExchangeMock.Setup(p => p.CreateDecompressor()).Returns((Compressor) null);
+            _keyExchangeMock.Setup(p => p.Dispose());
+            _serviceFactoryMock.Setup(p => p.CreateClientAuthentication()).Returns(_clientAuthenticationMock.Object);
+            _clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session));
+
+            ServerListener = new AsyncSocketListener(_serverEndPoint);
+            ServerListener.Connected += socket =>
+                {
+                    ServerSocket = socket;
+
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("SSH-2.0-SshStub\r\n"));
+                };
+
+            var counter = 0;
+            ServerListener.BytesReceived += (received, socket) =>
+                {
+                    ServerBytesReceivedRegister.Add(received);
+
+                    switch (counter++)
+                    {
+                        case 0:
+                            var keyExchangeInitMessage = new KeyExchangeInitMessage
+                                {
+                                    CompressionAlgorithmsClientToServer = new string[0],
+                                    CompressionAlgorithmsServerToClient = new string[0],
+                                    EncryptionAlgorithmsClientToServer = new string[0],
+                                    EncryptionAlgorithmsServerToClient = new string[0],
+                                    KeyExchangeAlgorithms = new[] {_keyExchangeAlgorithm},
+                                    LanguagesClientToServer = new string[0],
+                                    LanguagesServerToClient = new string[0],
+                                    MacAlgorithmsClientToServer = new string[0],
+                                    MacAlgorithmsServerToClient = new string[0],
+                                    ServerHostKeyAlgorithms = new string[0]
+                                };
+                            var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
+                            ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+                            break;
+                        case 1:
+                            var serviceAcceptMessage =
+                                ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
+                                    .Build();
+                            ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
+                            break;
+                    }
+                };
+            ServerListener.Start();
+
+            Session.Connect();
+        }
+
+        protected abstract void Act();
+
+        private class ServiceAcceptMessageBuilder
+        {
+            private readonly ServiceName _serviceName;
+
+            private ServiceAcceptMessageBuilder(ServiceName serviceName)
+            {
+                _serviceName = serviceName;
+            }
+
+            public static ServiceAcceptMessageBuilder Create(ServiceName serviceName)
+            {
+                return new ServiceAcceptMessageBuilder(serviceName);
+            }
+
+            public byte[] Build()
+            {
+                var serviceName = _serviceName.ToArray();
+
+                var sshDataStream = new SshDataStream(4 + 1 + 1 + 4 + serviceName.Length);
+                sshDataStream.Write((uint)(sshDataStream.Capacity - 4)); // packet length
+                sshDataStream.WriteByte(0); // padding length
+                sshDataStream.WriteByte(ServiceAcceptMessage.MessageNumber);
+                sshDataStream.WriteBinary(serviceName);
+                return sshDataStream.ToArray();
+            }
+        }
+    }
+}

+ 168 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs

@@ -0,0 +1,168 @@
+using System.Diagnostics;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ConnectionReset : SessionTest_ConnectedBase
+    {
+        protected override void Act()
+        {
+            ServerSocket.Close();
+
+            // give session some time to react to connection reset
+            Thread.Sleep(200);
+        }
+
+        [TestMethodAttribute]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethodAttribute]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsRaisedOnce()
+        {
+            Assert.AreEqual(1, ErrorOccurredRegister.Count);
+
+            var errorOccurred = ErrorOccurredRegister[0];
+            Assert.IsNotNull(errorOccurred);
+
+            var exception = errorOccurred.Exception;
+            Assert.IsNotNull(exception);
+            Assert.AreEqual(typeof(SshConnectionException), exception.GetType());
+
+            var connectionException = (SshConnectionException) exception;
+            Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason);
+
+            var innerException = exception.InnerException;
+            Assert.IsNotNull(innerException);
+            Assert.AreEqual(typeof(SocketException), innerException.GetType());
+
+            var socketException = (SocketException) innerException;
+            Assert.AreEqual(SocketError.ConnectionReset, socketException.SocketErrorCode);
+
+            Assert.AreSame(innerException.Message, connectionException.Message);
+        }
+
+        [TestMethodAttribute]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethodAttribute]
+        public void SendMessageShouldThrowSshConnectionException()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethodAttribute]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession) Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethodAttribute]
+        public void ISession_SendMessageShouldThrowSshConnectionException()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethodAttribute]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession) Session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethodAttribute]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingConnectionReset()
+        {
+            var session = (ISession) Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
+
+                var innerException = ex.InnerException;
+                Assert.IsNotNull(innerException);
+                Assert.AreEqual(typeof(SocketException), innerException.GetType());
+
+                var socketException = (SocketException)ex.InnerException;
+                Assert.AreEqual(SocketError.ConnectionReset, socketException.SocketErrorCode);
+
+                Assert.AreSame(innerException.Message, ex.Message);
+
+            }
+        }
+    }
+}

+ 148 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_Disconnect.cs

@@ -0,0 +1,148 @@
+using System.Diagnostics;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_Disconnect : SessionTest_ConnectedBase
+    {
+        protected override void Act()
+        {
+            Session.Disconnect();
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsNeverRaised()
+        {
+            Assert.AreEqual(0, ErrorOccurredRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void ReceiveOnServerSocketShouldReturnZero()
+        {
+            var buffer = new byte[1];
+
+            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+
+            Assert.AreEqual(0, actual);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldThrowSshConnectionException()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession)Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethodAttribute]
+        public void ISession_SendMessageShouldThrowSshConnectionException()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession) Session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingBadPacket()
+        {
+            var session = (ISession) Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+    }
+}

+ 164 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs

@@ -0,0 +1,164 @@
+using System.Diagnostics;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerSendsBadPacket : SessionTest_ConnectedBase
+    {
+        protected override void Act()
+        {
+            var badPacket = new byte[] { 0x0a, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05 };
+            ServerSocket.Send(badPacket, 0, badPacket.Length, SocketFlags.None);
+
+            // give session some time to react to bad packet
+            Thread.Sleep(200);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsRaisedOnce()
+        {
+            Assert.AreEqual(1, ErrorOccurredRegister.Count);
+
+            var errorOccurred = ErrorOccurredRegister[0];
+            Assert.IsNotNull(errorOccurred);
+
+            var exception = errorOccurred.Exception;
+            Assert.IsNotNull(exception);
+            Assert.AreEqual(typeof (SshConnectionException), exception.GetType());
+
+            var connectionException = (SshConnectionException) exception;
+            Assert.AreEqual(DisconnectReason.ProtocolError, connectionException.DisconnectReason);
+            Assert.IsNull(connectionException.InnerException);
+            Assert.AreEqual("Bad packet length: 168101125.", connectionException.Message);
+        }
+
+        [TestMethod]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void ReceiveOnServerSocketShouldReturnZero()
+        {
+            var buffer = new byte[1];
+
+            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+
+            Assert.AreEqual(0, actual);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldThrowSshConnectionException()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession)Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethodAttribute]
+        public void ISession_SendMessageShouldThrowSshConnectionException()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession) Session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingBadPacket()
+        {
+            var session = (ISession) Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.ProtocolError, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Bad packet length: 168101125.", ex.Message);
+            }
+        }
+    }
+}

+ 168 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsDisconnectMessage.cs

@@ -0,0 +1,168 @@
+using System.Diagnostics;
+using System.Globalization;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerSendsDisconnectMessage : SessionTest_ConnectedBase
+    {
+        private DisconnectMessage _disconnectMessage;
+
+        protected override void Arrange()
+        {
+            _disconnectMessage = new DisconnectMessage(DisconnectReason.ServiceNotAvailable, "Not today!");
+
+            base.Arrange();
+        }
+
+        protected override void Act()
+        {
+            var disconnect = _disconnectMessage.GetPacket(8, null);
+            ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
+
+            // give session some time to process DisconnectMessage
+            Thread.Sleep(200);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsRaisedOnce()
+        {
+            Assert.AreEqual(1, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsRaisedOnce()
+        {
+            Assert.AreEqual(1, DisconnectReceivedRegister.Count);
+
+            var disconnectMessage = DisconnectReceivedRegister[0].Message;
+            Assert.IsNotNull(disconnectMessage);
+            Assert.AreEqual(_disconnectMessage.Description, disconnectMessage.Description);
+            Assert.AreEqual("en", disconnectMessage.Language);
+            Assert.AreEqual(_disconnectMessage.ReasonCode, disconnectMessage.ReasonCode);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsNeverRaised()
+        {
+            Assert.AreEqual(0, ErrorOccurredRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void ReceiveOnServerSocketShouldReturnZero()
+        {
+            var buffer = new byte[1];
+
+            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+
+            Assert.AreEqual(0, actual);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldThrowSshConnectionException()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession)Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethodAttribute]
+        public void ISession_SendMessageShouldThrowSshConnectionException()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession) Session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingDisconnectReason()
+        {
+            var session = (ISession)Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.ServiceNotAvailable, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "The connection was closed by the server: {0} ({1}).", _disconnectMessage.Description, _disconnectMessage.ReasonCode), ex.Message);
+            }
+        }
+    }
+}

+ 172 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs

@@ -0,0 +1,172 @@
+using System.Diagnostics;
+using System.Globalization;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket : SessionTest_ConnectedBase
+    {
+        private DisconnectMessage _disconnectMessage;
+
+        protected override void Arrange()
+        {
+            _disconnectMessage = new DisconnectMessage(DisconnectReason.ServiceNotAvailable, "Not today!");
+
+            base.Arrange();
+        }
+
+        protected override void Act()
+        {
+            // server sends SSH_MSG_DISCONNECT
+            var disconnect = _disconnectMessage.GetPacket(8, null);
+            ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
+
+            // server shuts down the socket
+            ServerSocket.Shutdown(SocketShutdown.Send);
+
+            // give session some time to process DisconnectMessage and socket shutdown
+            Thread.Sleep(200);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsRaisedOnce()
+        {
+            Assert.AreEqual(1, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsRaisedOnce()
+        {
+            Assert.AreEqual(1, DisconnectReceivedRegister.Count);
+
+            var disconnectMessage = DisconnectReceivedRegister[0].Message;
+            Assert.IsNotNull(disconnectMessage);
+            Assert.AreEqual(_disconnectMessage.Description, disconnectMessage.Description);
+            Assert.AreEqual("en", disconnectMessage.Language);
+            Assert.AreEqual(_disconnectMessage.ReasonCode, disconnectMessage.ReasonCode);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsNeverRaised()
+        {
+            Assert.AreEqual(0, ErrorOccurredRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void ReceiveOnServerSocketShouldReturnZero()
+        {
+            var buffer = new byte[1];
+
+            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+
+            Assert.AreEqual(0, actual);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldThrowSshConnectionException()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession)Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethodAttribute]
+        public void ISession_SendMessageShouldThrowSshConnectionException()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession)Session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingDisconnectReason()
+        {
+            var session = (ISession)Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.ServiceNotAvailable, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "The connection was closed by the server: {0} ({1}).", _disconnectMessage.Description, _disconnectMessage.ReasonCode), ex.Message);
+            }
+        }
+    }
+}

+ 161 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerShutsDownSocket.cs

@@ -0,0 +1,161 @@
+using System;
+using System.Diagnostics;
+using System.Net.Sockets;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerShutsDownSocket : SessionTest_ConnectedBase
+    {
+        protected override void Act()
+        {
+            ServerSocket.Shutdown(SocketShutdown.Send);
+
+            // give session some time to process socket shutdown
+            Thread.Sleep(200);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(Session.IsConnected);
+        }
+
+        [TestMethod]
+        public void DisconnectShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Disconnect();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void DisconnectedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectedRegister.Count);
+        }
+
+        [TestMethod]
+        public void DisconnectReceivedIsNeverRaised()
+        {
+            Assert.AreEqual(0, DisconnectReceivedRegister.Count);
+        }
+
+        [TestMethod]
+        public void ErrorOccurredIsRaisedOnce()
+        {
+            Assert.AreEqual(1, ErrorOccurredRegister.Count);
+
+            var errorOccurred = ErrorOccurredRegister[0];
+            Assert.IsNotNull(errorOccurred);
+
+            var exception = errorOccurred.Exception;
+            Assert.IsNotNull(exception);
+            Assert.AreEqual(typeof(SshConnectionException), exception.GetType());
+
+            var connectionException = (SshConnectionException) exception;
+            Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason);
+            Assert.IsNull(connectionException.InnerException);
+            Assert.AreEqual("An established connection was aborted by the server.", connectionException.Message);
+        }
+
+        [TestMethod]
+        public void DisposeShouldFinishImmediately()
+        {
+            var stopwatch = new Stopwatch();
+            stopwatch.Start();
+
+            Session.Dispose();
+
+            stopwatch.Stop();
+            Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
+        }
+
+        [TestMethod]
+        public void ReceiveOnServerSocketShouldReturnZero()
+        {
+            var buffer = new byte[1];
+
+            var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+
+            Assert.AreEqual(0, actual);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldThrowSshConnectionException()
+        {
+            try
+            {
+                Session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession)Session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
+        }
+
+        [TestMethodAttribute]
+        public void ISession_SendMessageShouldThrowSshConnectionException()
+        {
+            var session = (ISession) Session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession)Session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingAbortedConnection()
+        {
+            var session = (ISession)Session;
+            var waitHandle = new ManualResetEvent(false);
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
+            }
+        }
+    }
+}

+ 191 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs

@@ -0,0 +1,191 @@
+using System;
+using System.Net;
+using System.Threading;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_NotConnected
+    {
+        private ConnectionInfo _connectionInfo;
+        private IServiceFactory _serviceFactory;
+        private Session _session;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        protected void Arrange()
+        {
+            var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            _connectionInfo = CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5));
+            _serviceFactory = new Mock<IServiceFactory>(MockBehavior.Strict).Object;
+        }
+
+        protected void Act()
+        {
+            _session = new Session(_connectionInfo, _serviceFactory);
+        }
+
+        [TestMethod]
+        public void ClientVersionIsRenciSshNet()
+        {
+            Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", _session.ClientVersion);
+        }
+
+        [TestMethod]
+        public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
+        {
+            Assert.AreSame(_connectionInfo, _session.ConnectionInfo);
+        }
+
+        [TestMethod]
+        public void IsConnectedShouldReturnFalse()
+        {
+            Assert.IsFalse(_session.IsConnected);
+        }
+
+        [TestMethod]
+        public void SendMessageShouldThrowShhConnectionException()
+        {
+            try
+            {
+                _session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void SessionIdShouldReturnNull()
+        {
+            Assert.IsNull(_session.SessionId);
+        }
+
+        [TestMethod]
+        public void ServerVersionShouldReturnNull()
+        {
+            Assert.IsNull(_session.ServerVersion);
+        }
+
+        [TestMethod]
+        public void WaitOnHandle_WaitHandle_ShouldThrowArgumentNullExceptionWhenWaitHandleIsNull()
+        {
+            WaitHandle waitHandle = null;
+
+            try
+            {
+                _session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("waitHandle", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void WaitOnHandle_WaitHandleAndTimeout_ShouldThrowArgumentNullExceptionWhenWaitHandleIsNull()
+        {
+            WaitHandle waitHandle = null;
+            var timeout = TimeSpan.FromMinutes(5);
+
+            try
+            {
+                _session.WaitOnHandle(waitHandle, timeout);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("waitHandle", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
+        {
+            var session = (ISession)_session;
+            Assert.AreSame(_connectionInfo, session.ConnectionInfo);
+        }
+
+        [TestMethod]
+        public void ISession_MessageListenerCompletedShouldBeSignaled()
+        {
+            var session = (ISession) _session;
+
+            Assert.IsNotNull(session.MessageListenerCompleted);
+            Assert.IsTrue(session.MessageListenerCompleted.WaitOne(0));
+        }
+
+        [TestMethod]
+        public void ISession_SendMessageShouldThrowShhConnectionException()
+        {
+            var session = (ISession)_session;
+
+            try
+            {
+                session.SendMessage(new IgnoreMessage());
+                Assert.Fail();
+            }
+            catch (SshConnectionException ex)
+            {
+                Assert.AreEqual(DisconnectReason.None, ex.DisconnectReason);
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("Client not connected.", ex.Message);
+            }
+        }
+
+        [TestMethod]
+        public void ISession_TrySendMessageShouldReturnFalse()
+        {
+            var session = (ISession)_session;
+
+            var actual = session.TrySendMessage(new IgnoreMessage());
+
+            Assert.IsFalse(actual);
+        }
+
+        [TestMethod]
+        public void ISession_WaitOnHandleShouldThrowArgumentNullExceptionWhenWaitHandleIsNull()
+        {
+            WaitHandle waitHandle = null;
+            var session = (ISession)_session;
+
+            try
+            {
+                session.WaitOnHandle(waitHandle);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("waitHandle", ex.ParamName);
+            }
+        }
+
+        private static ConnectionInfo CreateConnectionInfo(IPEndPoint serverEndPoint, TimeSpan timeout)
+        {
+            var connectionInfo = new ConnectionInfo(
+                serverEndPoint.Address.ToString(),
+                serverEndPoint.Port,
+                "eric",
+                new NoneAuthenticationMethod("eric"));
+            connectionInfo.Timeout = timeout;
+            return connectionInfo;
+        }
+    }
+}

+ 95 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/SessionTest_SocketConnected_BadPacketAndDispose.cs

@@ -0,0 +1,95 @@
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_SocketConnected_BadPacketAndDispose
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private ConnectionInfo _connectionInfo;
+        private Session _session;
+        private AsyncSocketListener _serverListener;
+        private IPEndPoint _serverEndPoint;
+        private Socket _serverSocket;
+        private SshConnectionException _actualException;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void TearDown()
+        {
+            if (_serverListener != null)
+            {
+                _serverListener.Dispose();
+            }
+        }
+
+        protected void Arrange()
+        {
+            _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            _connectionInfo = new ConnectionInfo(
+                _serverEndPoint.Address.ToString(),
+                _serverEndPoint.Port,
+                "user",
+                new PasswordAuthenticationMethod("user", "password"));
+            _connectionInfo.Timeout = TimeSpan.FromMilliseconds(200);
+            _actualException = null;
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+
+            _serverListener = new AsyncSocketListener(_serverEndPoint);
+            _serverListener.Connected += (socket) =>
+                {
+                    _serverSocket = socket;
+
+                    socket.Send(Encoding.ASCII.GetBytes("\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
+                    socket.Send(Encoding.ASCII.GetBytes("SSH-2.0-SshStub\r\n"));
+                };
+            _serverListener.BytesReceived += (received, socket) =>
+                {
+                    var badPacket = new byte[] {0x0a, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05};
+                    _serverSocket.Send(badPacket, 0, badPacket.Length, SocketFlags.None);
+                    _serverSocket.Shutdown(SocketShutdown.Send);
+                };
+            _serverListener.Start();
+        }
+
+        protected virtual void Act()
+        {
+            try
+            {
+                using (_session = new Session(_connectionInfo, _serviceFactoryMock.Object))
+                {
+                    _session.Connect();
+                }
+            }
+            catch (SshConnectionException ex)
+            {
+                _actualException = ex;
+            }
+        }
+
+        [TestMethod]
+        public void ConnectShouldThrowSshConnectionException()
+        {
+            Assert.IsNotNull(_actualException);
+            Assert.IsNull(_actualException.InnerException);
+            Assert.AreEqual(DisconnectReason.ProtocolError, _actualException.DisconnectReason);
+            Assert.AreEqual("Bad packet length: 168101125.", _actualException.Message);
+        }
+    }
+}

+ 1 - 1
Renci.SshClient/Renci.SshNet.Tests/Classes/Sftp/Responses/SftpHandleResponseTest.cs

@@ -1,4 +1,4 @@
-using System;
+ using System;
 using System.Linq;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Renci.SshNet.Common;

+ 15 - 0
Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -112,6 +112,7 @@
     <Compile Include="Classes\Channels\ClientChannelTest_OnSessionChannelOpenConfirmationReceived_OnOpenConfirmation_Exception.cs" />
     <Compile Include="Classes\Channels\ClientChannelTest_OnSessionChannelOpenFailureReceived_OnOpenFailure_Exception.cs" />
     <Compile Include="Classes\CipherInfoTest.cs" />
+    <Compile Include="Classes\ClientAuthenticationTest.cs" />
     <Compile Include="Classes\ClientAuthenticationTestBase.cs" />
     <Compile Include="Classes\ClientAuthenticationTest_Failure_SingleList_AuthenticationMethodFailed.cs" />
     <Compile Include="Classes\ClientAuthenticationTest_Failure_SingleList_AuthenticationMethodNotConfigured.cs" />
@@ -120,6 +121,8 @@
     <Compile Include="Classes\ClientAuthenticationTest_Success_MultiList_SameAllowedAuthenticationsAfterPartialSuccess.cs" />
     <Compile Include="Classes\ClientAuthenticationTest_Success_MultiList_SkipFailedAuthenticationMethod.cs" />
     <Compile Include="Classes\ClientAuthenticationTest_Success_SingleList_SameAllowedAuthenticationAfterPartialSuccess.cs" />
+    <Compile Include="Classes\ConnectionInfoTest_Authenticate_Failure.cs" />
+    <Compile Include="Classes\ConnectionInfoTest_Authenticate_Success.cs" />
     <Compile Include="Classes\ForwardedPortDynamicTest_Dispose_PortStarted_ChannelBound.cs" />
     <Compile Include="Classes\ForwardedPortDynamicTest_Start_SessionNotConnected.cs" />
     <Compile Include="Classes\ForwardedPortDynamicTest_Start_SessionNull.cs" />
@@ -169,6 +172,8 @@
     <Compile Include="Classes\ForwardedPortRemoteTest_Stop_PortNeverStarted.cs" />
     <Compile Include="Classes\ForwardedPortRemoteTest_Stop_PortStarted_ChannelBound.cs" />
     <Compile Include="Classes\ForwardedPortRemoteTest_Stop_PortStopped.cs" />
+    <Compile Include="Classes\Messages\Transport\KeyExchangeDhGroupExchangeGroupBuilder.cs" />
+    <Compile Include="Classes\Messages\Transport\KeyExchangeDhGroupExchangeReplyBuilder.cs" />
     <Compile Include="Classes\NetConfClientTest_Dispose_Connected.cs" />
     <Compile Include="Classes\NetConfClientTest_Dispose_Disconnected.cs" />
     <Compile Include="Classes\NetConfClientTest_Dispose_Disposed.cs" />
@@ -214,6 +219,16 @@
     <Compile Include="Classes\Common\PortForwardEventArgsTest.cs" />
     <Compile Include="Classes\Compression\ZlibTest.cs" />
     <Compile Include="Classes\ConnectionInfoTest.cs" />
+    <Compile Include="Classes\SessionTest_Connected.cs" />
+    <Compile Include="Classes\SessionTest_ConnectedBase.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ConnectionReset.cs" />
+    <Compile Include="Classes\SessionTest_Connected_Disconnect.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessage.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ServerSendsBadPacket.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs" />
+    <Compile Include="Classes\SessionTest_Connected_ServerShutsDownSocket.cs" />
+    <Compile Include="Classes\SessionTest_NotConnected.cs" />
+    <Compile Include="Classes\SessionTest_SocketConnected_BadPacketAndDispose.cs" />
     <Compile Include="Classes\SftpClientTest.SynchronizeDirectories.cs" />
     <Compile Include="Classes\SftpClientTest.cs" />
     <Compile Include="Classes\SftpClientTest.Download.cs" />

+ 2 - 2
Renci.SshClient/Renci.SshNet/BaseClient.cs

@@ -155,8 +155,8 @@ namespace Renci.SshNet
         /// <param name="connectionInfo">The connection info.</param>
         /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
         /// <param name="serviceFactory">The factory to use for creating new services.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is null.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is <c>null</c>.</exception>
         /// <remarks>
         /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
         /// connection info will be disposed when this instance is disposed.

+ 1 - 1
Renci.SshClient/Renci.SshNet/ClientAuthentication.cs

@@ -5,7 +5,7 @@ using Renci.SshNet.Common;
 
 namespace Renci.SshNet
 {
-    internal class ClientAuthentication
+    internal class ClientAuthentication : IClientAuthentication
     {
         public void Authenticate(IConnectionInfoInternal connectionInfo, ISession session)
         {

+ 0 - 10
Renci.SshClient/Renci.SshNet/Common/SshConnectionException.cs

@@ -52,15 +52,5 @@ namespace Renci.SshNet.Common
         {
             DisconnectReason = disconnectReasonCode;
         }
-
-        /// <summary>
-        /// Initializes a new instance of the <see cref="SshConnectionException"/> class.
-        /// </summary>
-        /// <param name="message">The message.</param>
-        /// <param name="innerException">The inner exception.</param>
-        public SshConnectionException(string message, Exception innerException) :
-            base(message, innerException)
-        {
-        }
     }
 }

+ 24 - 87
Renci.SshClient/Renci.SshNet/Common/SshData.cs

@@ -337,16 +337,6 @@ namespace Renci.SshNet.Common
             return ((ulong)data[0] << 56 | (ulong)data[1] << 48 | (ulong)data[2] << 40 | (ulong)data[3] << 32 | (ulong)data[4] << 24 | (ulong)data[5] << 16 | (ulong)data[6] << 8 | data[7]);
         }
 
-        /// <summary>
-        /// Reads next int64 data type from internal buffer.
-        /// </summary>
-        /// <returns>int64 read</returns>
-        protected long ReadInt64()
-        {
-            var data = ReadBytes(8);
-            return (int)(data[0] << 56 | data[1] << 48 | data[2] << 40 | data[3] << 32 | data[4] << 24 | data[5] << 16 | data[6] << 8 | data[7]);
-        }
-
 #if !TUNING
         /// <summary>
         /// Reads next string data type from internal buffer.
@@ -379,6 +369,9 @@ namespace Renci.SshNet.Common
         /// <returns>string read</returns>
         protected string ReadString(Encoding encoding)
         {
+#if TUNING
+            return _stream.ReadString(encoding);
+#else
             var length = ReadUInt32();
 
             if (length > int.MaxValue)
@@ -386,6 +379,7 @@ namespace Renci.SshNet.Common
                 throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Strings longer than {0} is not supported.", int.MaxValue));
             }
             return encoding.GetString(ReadBytes((int)length), 0, (int)length);
+#endif
         }
 
 #if TUNING
@@ -397,14 +391,7 @@ namespace Renci.SshNet.Common
         /// </returns>
         protected byte[] ReadBinary()
         {
-            var length = ReadUInt32();
-
-            if (length > int.MaxValue)
-            {
-                throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue));
-            }
-
-            return ReadBytes((int) length);
+            return _stream.ReadBinary();
         }
 #else
         /// <summary>
@@ -424,23 +411,6 @@ namespace Renci.SshNet.Common
         }
 #endif
 
-        /// <summary>
-        /// Reads next mpint data type from internal buffer.
-        /// </summary>
-        /// <returns>mpint read.</returns>
-        protected BigInteger ReadBigInt()
-        {
-            var length = ReadUInt32();
-
-            var data = ReadBytes((int)length);
-
-#if TUNING
-            return new BigInteger(data.Reverse());
-#else
-            return new BigInteger(data.Reverse().ToArray());
-#endif
-        }
-
         /// <summary>
         /// Reads next name-list data type from internal buffer.
         /// </summary>
@@ -475,7 +445,7 @@ namespace Renci.SshNet.Common
         /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
         protected void Write(byte[] data)
         {
-            _stream.Write(data, 0, data.Length);
+            _stream.Write(data);
         }
 #else
         /// <summary>
@@ -528,22 +498,17 @@ namespace Renci.SshNet.Common
             Write(data ? (byte) 1 : (byte) 0);
         }
 
-        /// <summary>
-        /// Writes uint16 data into internal buffer.
-        /// </summary>
-        /// <param name="data">uint16 data to write.</param>
-        protected void Write(ushort data)
-        {
-            Write(data.GetBytes());
-        }
-
         /// <summary>
         /// Writes uint32 data into internal buffer.
         /// </summary>
         /// <param name="data">uint32 data to write.</param>
         protected void Write(uint data)
         {
+#if TUNING
+            _stream.Write(data);
+#else
             Write(data.GetBytes());
+#endif
         }
 
         /// <summary>
@@ -552,26 +517,11 @@ namespace Renci.SshNet.Common
         /// <param name="data">uint64 data to write.</param>
         protected void Write(ulong data)
         {
+#if TUNING
+            _stream.Write(data);
+#else
             Write(data.GetBytes());
-        }
-
-        /// <summary>
-        /// Writes int64 data into internal buffer.
-        /// </summary>
-        /// <param name="data">int64 data to write.</param>
-        protected void Write(long data)
-        {
-            Write(data.GetBytes());
-        }
-
-
-        /// <summary>
-        /// Writes string data into internal buffer as ASCII.
-        /// </summary>
-        /// <param name="data">string data to write.</param>
-        protected void WriteAscii(string data)
-        {
-            Write(data, Ascii);
+#endif
         }
 
         /// <summary>
@@ -593,17 +543,16 @@ namespace Renci.SshNet.Common
         /// <exception cref="ArgumentNullException"><paramref name="encoding"/> is null.</exception>
         protected void Write(string data, Encoding encoding)
         {
+#if TUNING
+            _stream.Write(data, encoding);
+#else
             if (data == null)
                 throw new ArgumentNullException("data");
             if (encoding == null)
                 throw new ArgumentNullException("encoding");
 
             var bytes = encoding.GetBytes(data);
-#if TUNING
-            var bytesLength = bytes.Length;
-            Write((uint) bytesLength);
-            Write(bytes, 0, bytesLength);
-#else
+
             Write((uint)bytes.Length);
             Write(bytes);
 #endif
@@ -617,12 +566,7 @@ namespace Renci.SshNet.Common
         /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
         protected void WriteBinaryString(byte[] buffer)
         {
-            if (buffer == null)
-                throw new ArgumentNullException("buffer");
-
-            var bufferLength = buffer.Length;
-            Write((uint)bufferLength);
-            Write(buffer, 0, bufferLength);
+            _stream.WriteBinary(buffer);
         }
 
         /// <summary>
@@ -636,11 +580,7 @@ namespace Renci.SshNet.Common
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="offset"/> or <paramref name="count"/> is negative.</exception>
         protected void WriteBinary(byte[] buffer, int offset, int count)
         {
-            if (buffer == null)
-                throw new ArgumentNullException("buffer");
-
-            Write((uint) count);
-            Write(buffer, offset, count);
+            _stream.WriteBinary(buffer, offset, count);
         }
 #else
         /// <summary>
@@ -665,10 +605,7 @@ namespace Renci.SshNet.Common
         protected void Write(BigInteger data)
         {
 #if TUNING
-            var bytes = data.ToByteArray().Reverse();
-            var bytesLength = bytes.Length;
-            Write((uint) bytesLength);
-            Write(bytes, 0, bytesLength);
+            _stream.Write(data);
 #else
             var bytes = data.ToByteArray().Reverse().ToList();
             Write((uint)bytes.Count);
@@ -682,7 +619,7 @@ namespace Renci.SshNet.Common
         /// <param name="data">name-list data to write.</param>
         protected void Write(string[] data)
         {
-            WriteAscii(string.Join(",", data));
+            Write(string.Join(",", data), Ascii);
         }
 
         /// <summary>
@@ -693,8 +630,8 @@ namespace Renci.SshNet.Common
         {
             foreach (var item in data)
             {
-                WriteAscii(item.Key);
-                WriteAscii(item.Value);
+                Write(item.Key, Ascii);
+                Write(item.Value, Ascii);
             }
         }
     }

+ 93 - 13
Renci.SshClient/Renci.SshNet/Common/SshDataStream.cs

@@ -32,7 +32,7 @@ namespace Renci.SshNet.Common
         }
 
         /// <summary>
-        /// Writes <see cref="uint"/> data to the SSH data stream.
+        /// Writes an <see cref="uint"/> to the SSH data stream.
         /// </summary>
         /// <param name="value"><see cref="uint"/> data to write.</param>
         public void Write(uint value)
@@ -42,7 +42,7 @@ namespace Renci.SshNet.Common
         }
 
         /// <summary>
-        /// Writes <see cref="ulong"/> data to the SSH data stream.
+        /// Writes an <see cref="ulong"/> to the SSH data stream.
         /// </summary>
         /// <param name="value"><see cref="ulong"/> data to write.</param>
         public void Write(ulong value)
@@ -51,24 +51,102 @@ namespace Renci.SshNet.Common
             Write(bytes, 0, bytes.Length);
         }
 
+        /// <summary>
+        /// Writes a <see cref="BigInteger"/> into the SSH data stream.
+        /// </summary>
+        /// <param name="data">The <see cref="BigInteger" /> to write.</param>
+        public void Write(BigInteger data)
+        {
+            var bytes = data.ToByteArray().Reverse();
+            WriteBinary(bytes, 0, bytes.Length);
+        }
+
+        /// <summary>
+        /// Writes bytes array data into the SSH data stream.
+        /// </summary>
+        /// <param name="data">Byte array data to write.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
+        public void Write(byte[] data)
+        {
+            if (data == null)
+                throw new ArgumentNullException("data");
+
+            Write(data, 0, data.Length);
+        }
+
+        /// <summary>
+        /// Reads a byte array from the SSH data stream.
+        /// </summary>
+        /// <returns>
+        /// The byte array read from the SSH data stream.
+        /// </returns>
+        public byte[] ReadBinary()
+        {
+            var length = ReadUInt32();
+
+            if (length > int.MaxValue)
+            {
+                throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue));
+            }
+
+            return ReadBytes((int)length);
+        }
+
+        /// <summary>
+        /// Writes a buffer preceded by its length into the SSH data stream.
+        /// </summary>
+        /// <param name="buffer">The data to write.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
+        public void WriteBinary(byte[] buffer)
+        {
+            if (buffer == null)
+                throw new ArgumentNullException("buffer");
+
+            WriteBinary(buffer, 0, buffer.Length);
+        }
+
+        /// <summary>
+        /// Writes a buffer preceded by its length into the SSH data stream.
+        /// </summary>
+        /// <param name="buffer">An array of bytes. This method write <paramref name="count"/> bytes from buffer to the current SSH data stream.</param>
+        /// <param name="offset">The zero-based byte offset in <paramref name="buffer"/> at which to begin writing bytes to the SSH data stream.</param>
+        /// <param name="count">The number of bytes to be written to the current SSH data stream.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
+        /// <exception cref="ArgumentException">The sum of <paramref name="offset"/> and <paramref name="count"/> is greater than the buffer length.</exception>
+        /// <exception cref="ArgumentOutOfRangeException"><paramref name="offset"/> or <paramref name="count"/> is negative.</exception>
+        public void WriteBinary(byte[] buffer, int offset, int count)
+        {
+            Write((uint) count);
+            Write(buffer, offset, count);
+        }
+
         /// <summary>
         /// Writes string data to the SSH data stream using the specified encoding.
         /// </summary>
-        /// <param name="value">The string data to write.</param>
+        /// <param name="s">The string data to write.</param>
         /// <param name="encoding">The character encoding to use.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="s"/> is null.</exception>
         /// <exception cref="ArgumentNullException"><paramref name="encoding"/> is null.</exception>
-        public void Write(string value, Encoding encoding)
+        public void Write(string s, Encoding encoding)
         {
-            if (value == null)
-                throw new ArgumentNullException("value");
             if (encoding == null)
                 throw new ArgumentNullException("encoding");
 
-            var bytes = encoding.GetBytes(value);
-            var bytesLength = bytes.Length;
-            Write((uint) bytesLength);
-            Write(bytes, 0, bytesLength);
+            var bytes = encoding.GetBytes(s);
+            WriteBinary(bytes, 0, bytes.Length);
+        }
+
+        /// <summary>
+        /// Reads a <see cref="BigInteger"/> from the SSH datastream.
+        /// </summary>
+        /// <returns>
+        /// The <see cref="BigInteger"/> read from the SSH data stream.
+        /// </returns>
+        public BigInteger ReadBigInt()
+        {
+            var length = ReadUInt32();
+            var data = ReadBytes((int) length);
+            return new BigInteger(data.Reverse());
         }
 
         /// <summary>
@@ -80,7 +158,7 @@ namespace Renci.SshNet.Common
         public uint ReadUInt32()
         {
             var data = ReadBytes(4);
-            return (uint)(data[0] << 24 | data[1] << 16 | data[2] << 8 | data[3]);
+            return (uint) (data[0] << 24 | data[1] << 16 | data[2] << 8 | data[3]);
         }
 
         /// <summary>
@@ -110,7 +188,9 @@ namespace Renci.SshNet.Common
             {
                 throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Strings longer than {0} is not supported.", int.MaxValue));
             }
-            return encoding.GetString(ReadBytes((int) length), 0, (int) length);
+
+            var bytes = ReadBytes((int) length);
+            return encoding.GetString(bytes, 0, bytes.Length);
         }
 
         /// <summary>

+ 7 - 2
Renci.SshClient/Renci.SshNet/ConnectionInfo.cs

@@ -402,12 +402,17 @@ namespace Renci.SshNet
         /// Authenticates the specified session.
         /// </summary>
         /// <param name="session">The session to be authenticated.</param>
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
         /// <exception cref="ArgumentNullException"><paramref name="session"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is <c>null</c>.</exception>
         /// <exception cref="SshAuthenticationException">No suitable authentication method found to complete authentication, or permission denied.</exception>
-        internal void Authenticate(ISession session)
+        internal void Authenticate(ISession session, IServiceFactory serviceFactory)
         {
+            if (serviceFactory == null)
+                throw new ArgumentNullException("serviceFactory");
+
             IsAuthenticated = false;
-            var clientAuthentication = new ClientAuthentication();
+            var clientAuthentication = serviceFactory.CreateClientAuthentication();
             clientAuthentication.Authenticate(this, session);
             IsAuthenticated = true;
         }

+ 12 - 0
Renci.SshClient/Renci.SshNet/IClientAuthentication.cs

@@ -0,0 +1,12 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Renci.SshNet
+{
+    internal interface IClientAuthentication
+    {
+        void Authenticate(IConnectionInfoInternal connectionInfo, ISession session);
+    }
+}

+ 18 - 0
Renci.SshClient/Renci.SshNet/IServiceFactory.cs

@@ -1,6 +1,8 @@
 using System;
+using System.Collections.Generic;
 using System.Text;
 using Renci.SshNet.Common;
+using Renci.SshNet.Security;
 using Renci.SshNet.Sftp;
 
 namespace Renci.SshNet
@@ -10,6 +12,8 @@ namespace Renci.SshNet
     /// </summary>
     internal partial interface IServiceFactory
     {
+        IClientAuthentication CreateClientAuthentication();
+
         /// <summary>
         /// Creates a new <see cref="ISession"/> with the specified <see cref="ConnectionInfo"/>.
         /// </summary>
@@ -39,5 +43,19 @@ namespace Renci.SshNet
         /// A <see cref="PipeStream"/>.
         /// </returns>
         PipeStream CreatePipeStream();
+
+        /// <summary>
+        /// Negotiates a key exchange algorithm, and creates a <see cref="IKeyExchange" /> for the negotiated
+        /// algorithm.
+        /// </summary>
+        /// <param name="clientAlgorithms">A <see cref="IDictionary{String, Type}"/> of the key exchange algorithms supported by the client where the key is the name of the algorithm, and the value is the type implementing this algorithm.</param>
+        /// <param name="serverAlgorithms">The names of the key exchange algorithms supported by the SSH server.</param>
+        /// <returns>
+        /// A <see cref="IKeyExchange"/> that was negotiated between client and server.
+        /// </returns>
+        /// <exception cref="ArgumentNullException"><paramref name="clientAlgorithms"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serverAlgorithms"/> is <c>null</c>.</exception>
+        /// <exception cref="SshConnectionException">No key exchange algorithm is supported by both client and server.</exception>
+        IKeyExchange CreateKeyExchange(IDictionary<string, Type> clientAlgorithms, string[] serverAlgorithms);
     }
 }

+ 0 - 15
Renci.SshClient/Renci.SshNet/ISession.cs

@@ -141,21 +141,6 @@ namespace Renci.SshNet
         /// </remarks>
         void WaitOnHandle(WaitHandle waitHandle);
 
-        /// <summary>
-        /// Waits for the specified handle or the exception handle for the receive thread
-        /// to signal within the specified timeout.
-        /// </summary>
-        /// <param name="waitHandle">The wait handle.</param>
-        /// <param name="timeout">The time to wait for any of the handles to become signaled.</param>
-        /// <exception cref="SshConnectionException">A received package was invalid or failed the message integrity check.</exception>
-        /// <exception cref="SshOperationTimeoutException">None of the handles are signaled in time and the session is not disconnecting.</exception>
-        /// <exception cref="SocketException">A socket error was signaled while receiving messages from the server.</exception>
-        /// <remarks>
-        /// When neither handles are signaled in time and the session is not closing, then the
-        /// session is disconnected.
-        /// </remarks>
-        void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout);
-
         /// <summary>
         /// Occurs when <see cref="ChannelCloseMessage"/> message received
         /// </summary>

+ 11 - 4
Renci.SshClient/Renci.SshNet/Messages/Transport/IgnoreMessage.cs

@@ -1,22 +1,26 @@
-namespace Renci.SshNet.Messages.Transport
+using System;
+
+namespace Renci.SshNet.Messages.Transport
 {
     /// <summary>
     /// Represents SSH_MSG_IGNORE message.
     /// </summary>
-    [Message("SSH_MSG_IGNORE", 2)]
+    [Message("SSH_MSG_IGNORE", MessageNumber)]
     public class IgnoreMessage : Message
     {
+        internal const byte MessageNumber = 2;
+
         /// <summary>
         /// Gets ignore message data if any.
         /// </summary>
         public byte[] Data { get; private set; }
 
         /// <summary>
-        /// Initializes a new instance of the <see cref="IgnoreMessage"/> class.
+        /// Initializes a new instance of the <see cref="IgnoreMessage"/> class
         /// </summary>
         public IgnoreMessage()
         {
-            Data = new byte[] { };
+            Data = new byte[0];
         }
 
 #if TUNING
@@ -44,6 +48,9 @@
         /// <param name="data">The data.</param>
         public IgnoreMessage(byte[] data)
         {
+            if (data == null)
+                throw new ArgumentNullException("data");
+
             Data = data;
         }
 

+ 3 - 1
Renci.SshClient/Renci.SshNet/Messages/Transport/KeyExchangeDhGroupExchangeGroup.cs

@@ -6,9 +6,11 @@ namespace Renci.SshNet.Messages.Transport
     /// <summary>
     /// Represents SSH_MSG_KEX_DH_GEX_GROUP message.
     /// </summary>
-    [Message("SSH_MSG_KEX_DH_GEX_GROUP", 31)]
+    [Message("SSH_MSG_KEX_DH_GEX_GROUP", MessageNumber)]
     public class KeyExchangeDhGroupExchangeGroup : Message
     {
+        internal const byte MessageNumber = 31;
+
 #if TUNING
         private byte[] _safePrime;
         private byte[] _subGroup;

+ 3 - 1
Renci.SshClient/Renci.SshNet/Messages/Transport/KeyExchangeDhGroupExchangeReply.cs

@@ -6,9 +6,11 @@ namespace Renci.SshNet.Messages.Transport
     /// <summary>
     /// Represents SSH_MSG_KEX_DH_GEX_REPLY message.
     /// </summary>
-    [Message("SSH_MSG_KEX_DH_GEX_REPLY", 33)]
+    [Message("SSH_MSG_KEX_DH_GEX_REPLY", MessageNumber)]
     internal class KeyExchangeDhGroupExchangeReply : Message
     {
+        internal const byte MessageNumber = 33;
+
 #if TUNING
         private byte[] _fBytes;
 #endif

+ 3 - 1
Renci.SshClient/Renci.SshNet/Messages/Transport/KeyExchangeDhGroupExchangeRequest.cs

@@ -5,9 +5,11 @@ namespace Renci.SshNet.Messages.Transport
     /// <summary>
     /// Represents SSH_MSG_KEX_DH_GEX_REQUEST message.
     /// </summary>
-    [Message("SSH_MSG_KEX_DH_GEX_REQUEST", 34)]
+    [Message("SSH_MSG_KEX_DH_GEX_REQUEST", MessageNumber)]
     internal class KeyExchangeDhGroupExchangeRequest : Message, IKeyExchangedAllowed
     {
+        internal const byte MessageNumber = 34;
+
         /// <summary>
         /// Gets or sets the minimal size in bits of an acceptable group.
         /// </summary>

+ 3 - 1
Renci.SshClient/Renci.SshNet/Messages/Transport/ServiceAcceptMessage.cs

@@ -6,9 +6,11 @@ namespace Renci.SshNet.Messages.Transport
     /// <summary>
     /// Represents SSH_MSG_SERVICE_ACCEPT message.
     /// </summary>
-    [Message("SSH_MSG_SERVICE_ACCEPT", 6)]
+    [Message("SSH_MSG_SERVICE_ACCEPT", MessageNumber)]
     public class ServiceAcceptMessage : Message
     {
+        internal const byte MessageNumber = 6;
+
         /// <summary>
         /// Gets the name of the service.
         /// </summary>

+ 2 - 0
Renci.SshClient/Renci.SshNet/Renci.SshNet.csproj

@@ -154,6 +154,7 @@
       <SubType>Code</SubType>
     </Compile>
     <Compile Include="IAuthenticationMethod.cs" />
+    <Compile Include="IClientAuthentication.cs" />
     <Compile Include="IConnectionInfo.cs" />
     <Compile Include="IForwardedPort.cs" />
     <Compile Include="IServiceFactory.cs" />
@@ -174,6 +175,7 @@
       <SubType>Code</SubType>
     </Compile>
     <Compile Include="Security\GroupExchangeHashData.cs" />
+    <Compile Include="Security\IKeyExchange.cs" />
     <Compile Include="Security\KeyExchangeDiffieHellmanGroupExchangeShaBase.cs" />
     <Compile Include="Security\KeyExchangeEllipticCurveDiffieHellman.cs" />
     <Compile Include="ServiceFactory.cs" />

+ 96 - 0
Renci.SshClient/Renci.SshNet/Security/IKeyExchange.cs

@@ -0,0 +1,96 @@
+using System;
+using System.Security.Cryptography;
+using Renci.SshNet.Common;
+using Renci.SshNet.Compression;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Security.Cryptography;
+
+namespace Renci.SshNet.Security
+{
+    /// <summary>
+    /// Represents a key exchange algorithm.
+    /// </summary>
+    public interface IKeyExchange : IDisposable
+    {
+        /// <summary>
+        /// Occurs when the host key is received.
+        /// </summary>
+        event EventHandler<HostKeyEventArgs> HostKeyReceived;
+
+        /// <summary>
+        /// Gets the name of the algorithm.
+        /// </summary>
+        /// <value>
+        /// The name of the algorithm.
+        /// </value>
+        string Name { get; }
+
+        /// <summary>
+        /// Gets the exchange hash.
+        /// </summary>
+        /// <value>
+        /// The exchange hash.
+        /// </value>
+        byte[] ExchangeHash { get; }
+
+        /// <summary>
+        /// Starts the key exchange algorithm.
+        /// </summary>
+        /// <param name="session">The session.</param>
+        /// <param name="message">Key exchange init message.</param>
+        void Start(Session session, KeyExchangeInitMessage message);
+
+        /// <summary>
+        /// Finishes the key exchange algorithm.
+        /// </summary>
+        void Finish();
+
+        /// <summary>
+        /// Creates the client-side cipher to use.
+        /// </summary>
+        /// <returns>
+        /// The client cipher.
+        /// </returns>
+        Cipher CreateClientCipher();
+
+        /// <summary>
+        /// Creates the server-side cipher to use.
+        /// </summary>
+        /// <returns>
+        /// The server cipher.
+        /// </returns>
+        Cipher CreateServerCipher();
+
+        /// <summary>
+        /// Creates the server-side hash algorithm to use.
+        /// </summary>
+        /// <returns>
+        /// The server hash algorithm.
+        /// </returns>
+        HashAlgorithm CreateServerHash();
+
+        /// <summary>
+        /// Creates the client-side hash algorithm to use.
+        /// </summary>
+        /// <returns>
+        /// The client hash algorithm.
+        /// </returns>
+        HashAlgorithm CreateClientHash();
+
+        /// <summary>
+        /// Creates the compression algorithm to use to deflate data.
+        /// </summary>
+        /// <returns>
+        /// The compression method to deflate data.
+        /// </returns>
+        Compressor CreateCompressor();
+
+        /// <summary>
+        /// Creates the compression algorithm to use to inflate data.
+        /// </summary>
+        /// <returns>
+        /// The compression method to inflate data.
+        /// </returns>
+        Compressor CreateDecompressor();
+    }
+}

+ 2 - 1
Renci.SshClient/Renci.SshNet/Security/KeyExchange.cs

@@ -13,7 +13,7 @@ namespace Renci.SshNet.Security
     /// <summary>
     /// Represents base class for different key exchange algorithm implementations
     /// </summary>
-    public abstract class KeyExchange : Algorithm, IDisposable
+    public abstract class KeyExchange : Algorithm, IKeyExchange
     {
         private CipherInfo _clientCipherInfo;
 
@@ -44,6 +44,7 @@ namespace Renci.SshNet.Security
         public BigInteger SharedKey { get; protected set; }
 
         private byte[] _exchangeHash;
+
         /// <summary>
         /// Gets the exchange hash.
         /// </summary>

+ 49 - 1
Renci.SshClient/Renci.SshNet/ServiceFactory.cs

@@ -1,6 +1,10 @@
 using System;
+using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Security;
 using Renci.SshNet.Sftp;
 
 namespace Renci.SshNet
@@ -10,6 +14,17 @@ namespace Renci.SshNet
     /// </summary>
     internal partial class ServiceFactory : IServiceFactory
     {
+        /// <summary>
+        /// Creates a <see cref="IClientAuthentication"/>.
+        /// </summary>
+        /// <returns>
+        /// A <see cref="IClientAuthentication"/>.
+        /// </returns>
+        public IClientAuthentication CreateClientAuthentication()
+        {
+            return new ClientAuthentication();
+        }
+
         /// <summary>
         /// Creates a new <see cref="ISession"/> with the specified <see cref="ConnectionInfo"/>.
         /// </summary>
@@ -20,7 +35,7 @@ namespace Renci.SshNet
         /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
         public ISession CreateSession(ConnectionInfo connectionInfo)
         {
-            return new Session(connectionInfo);
+            return new Session(connectionInfo, this);
         }
 
         /// <summary>
@@ -48,5 +63,38 @@ namespace Renci.SshNet
         {
             return new PipeStream();
         }
+
+        /// <summary>
+        /// Negotiates a key exchange algorithm, and creates a <see cref="IKeyExchange" /> for the negotiated
+        /// algorithm.
+        /// </summary>
+        /// <param name="clientAlgorithms">A <see cref="IDictionary{String, Type}"/> of the key exchange algorithms supported by the client where key is the name of the algorithm, and value is the type implementing this algorithm.</param>
+        /// <param name="serverAlgorithms">The names of the key exchange algorithms supported by the SSH server.</param>
+        /// <returns>
+        /// A <see cref="IKeyExchange"/> that was negotiated between client and server.
+        /// </returns>
+        /// <exception cref="ArgumentNullException"><paramref name="clientAlgorithms"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serverAlgorithms"/> is <c>null</c>.</exception>
+        /// <exception cref="SshConnectionException">No key exchange algorithms are supported by both client and server.</exception>
+        public IKeyExchange CreateKeyExchange(IDictionary<string, Type> clientAlgorithms, string[] serverAlgorithms)
+        {
+            if (clientAlgorithms == null)
+                throw new ArgumentNullException("clientAlgorithms");
+            if (serverAlgorithms == null)
+                throw new ArgumentNullException("serverAlgorithms");
+
+            // find an algorithm that is supported by both client and server
+            var keyExchangeAlgorithmType = (from c in clientAlgorithms
+                                            from s in serverAlgorithms
+                                            where s == c.Key
+                                            select c.Value).FirstOrDefault();
+
+            if (keyExchangeAlgorithmType == null)
+            {
+                throw new SshConnectionException("Failed to negotiate key exchange algorithm.", DisconnectReason.KeyExchangeFailed);
+            }
+
+            return keyExchangeAlgorithmType.CreateInstance<KeyExchange>();
+        }
     }
 }

+ 3 - 8
Renci.SshClient/Renci.SshNet/Session.NET.cs

@@ -217,13 +217,6 @@ namespace Renci.SshNet
                 }
                 catch (SocketException exp)
                 {
-                    if (exp.SocketErrorCode == SocketError.ConnectionAborted)
-                    {
-                        buffer = new byte[length];
-                        Disconnect();
-                        return;
-                    }
-
                     if (exp.SocketErrorCode == SocketError.WouldBlock ||
                         exp.SocketErrorCode == SocketError.IOPending ||
                         exp.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
@@ -232,7 +225,9 @@ namespace Renci.SshNet
                         Thread.Sleep(30);
                     }
                     else
-                        throw;  // any serious error occurred
+                    {
+                        throw new SshConnectionException(exp.Message, DisconnectReason.ConnectionLost, exp);
+                    }
                 }
             } while (receivedTotal < length);
         }

+ 60 - 78
Renci.SshClient/Renci.SshNet/Session.cs

@@ -63,7 +63,9 @@ namespace Renci.SshNet
         /// </value>
         private const int LocalChannelDataPacketSize = 1024*64;
 
+#if !TUNING
         private static readonly RNGCryptoServiceProvider Randomizer = new RNGCryptoServiceProvider();
+#endif
 
 #if SILVERLIGHT
         private static readonly Regex ServerVersionRe = new Regex("^SSH-(?<protoversion>[^-]+)-(?<softwareversion>.+)( SP.+)?$");
@@ -149,7 +151,7 @@ namespace Renci.SshNet
         /// </summary>
         private bool _isDisconnecting;
 
-        private KeyExchange _keyExchange;
+        private IKeyExchange _keyExchange;
 
         private HashAlgorithm _serverMac;
 
@@ -165,6 +167,11 @@ namespace Renci.SshNet
 
         private SemaphoreLight _sessionSemaphore;
 
+        /// <summary>
+        /// Holds the factory to use for creating new services.
+        /// </summary>
+        private readonly IServiceFactory _serviceFactory;
+
         /// <summary>
         /// Gets the session semaphore that controls session channels.
         /// </summary>
@@ -225,6 +232,9 @@ namespace Renci.SshNet
         /// This methods returns true in all but the following cases:
         /// <list type="bullet">
         ///     <item>
+        ///         <description>The <see cref="Session"/> is disposed.</description>
+        ///     </item>
+        ///     <item>
         ///         <description>The SSH_MSG_DISCONNECT message - which is used to disconnect from the server - has been sent.</description>
         ///     </item>
         ///     <item>
@@ -242,7 +252,7 @@ namespace Renci.SshNet
         {
             get
             {
-                if (_isDisconnectMessageSent || !_isAuthenticated)
+                if (_disposed || _isDisconnectMessageSent || !_isAuthenticated)
                     return false;
                 if (_messageListenerCompleted == null || _messageListenerCompleted.WaitOne(0))
                     return false;
@@ -468,15 +478,20 @@ namespace Renci.SshNet
         /// Initializes a new instance of the <see cref="Session"/> class.
         /// </summary>
         /// <param name="connectionInfo">The connection info.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
-        internal Session(ConnectionInfo connectionInfo)
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is <c>null</c>.</exception>
+        internal Session(ConnectionInfo connectionInfo, IServiceFactory serviceFactory)
         {
             if (connectionInfo == null)
                 throw new ArgumentNullException("connectionInfo");
+            if (serviceFactory == null)
+                throw new ArgumentNullException("serviceFactory");
 
             ConnectionInfo = connectionInfo;
             //this.ClientVersion = string.Format(CultureInfo.CurrentCulture, "SSH-2.0-Renci.SshNet.SshClient.{0}", this.GetType().Assembly.GetName().Version);
-            ClientVersion = string.Format(CultureInfo.CurrentCulture, "SSH-2.0-Renci.SshNet.SshClient.0.0.1");
+            ClientVersion = "SSH-2.0-Renci.SshNet.SshClient.0.0.1";
+            _serviceFactory = serviceFactory;
+            _messageListenerCompleted = new ManualResetEvent(true);
         }
 
         /// <summary>
@@ -538,7 +553,7 @@ namespace Renci.SshNet
                         var serverVersion = string.Empty;
                         SocketReadLine(ref serverVersion, ConnectionInfo.Timeout);
                         if (serverVersion == null)
-                            throw new SshConnectionException("Server response does not contain SSH protocol identification.");
+                            throw new SshConnectionException("Server response does not contain SSH protocol identification.", DisconnectReason.ProtocolError);
                         versionMatch = ServerVersionRe.Match(serverVersion);
                         if (versionMatch.Success)
                         {
@@ -577,20 +592,11 @@ namespace Renci.SshNet
                     //  Some server implementations might sent this message first, prior establishing encryption algorithm
                     RegisterMessage("SSH_MSG_USERAUTH_BANNER");
 
-                    //  Start incoming request listener
-                    _messageListenerCompleted = new ManualResetEvent(false);
+                    // mark the message listener threads as started
+                    _messageListenerCompleted.Reset();
 
-                    ExecuteThread(() =>
-                    {
-                        try
-                        {
-                            MessageListener();
-                        }
-                        finally
-                        {
-                            _messageListenerCompleted.Set();
-                        }
-                    });
+                    //  Start incoming request listener
+                    ExecuteThread(MessageListener);
 
                     //  Wait for key exchange to be completed
                     WaitOnHandle(_keyExchangeCompletedWaitHandle);
@@ -613,7 +619,7 @@ namespace Renci.SshNet
                         throw new SshException("Username is not specified.");
                     }
 
-                    ConnectionInfo.Authenticate(this);
+                    ConnectionInfo.Authenticate(this, _serviceFactory);
                     _isAuthenticated = true;
 
                     //  Register Connection messages
@@ -650,6 +656,14 @@ namespace Renci.SshNet
         public void Disconnect()
         {
             Disconnect(DisconnectReason.ByApplication, "Connection terminated by the client.");
+
+            // at this point, we are sure that the listener thread will stop as we've
+            // disconnected the socket, so lets wait until the message listener thread
+            // has completed
+            if (_messageListenerCompleted != null)
+            {
+                _messageListenerCompleted.WaitOne();
+            }
         }
 
         private void Disconnect(DisconnectReason reason, string message)
@@ -661,19 +675,13 @@ namespace Renci.SshNet
             //
             // note that this should also cause the listener thread to be stopped as
             // the server should respond by closing the socket
-            SendDisconnect(reason, message);
+            if (reason == DisconnectReason.ByApplication)
+            {
+                SendDisconnect(reason, message);
+            }
 
             // disconnect socket, and dispose it
             SocketDisconnectAndDispose();
-
-            if (_messageListenerCompleted != null)
-            {
-                // at this point, we are sure that the listener thread will stop
-                // as we've disconnected the socket
-                _messageListenerCompleted.WaitOne();
-                _messageListenerCompleted.Dispose();
-                _messageListenerCompleted = null;
-            }
         }
 
         /// <summary>
@@ -693,24 +701,6 @@ namespace Renci.SshNet
             WaitOnHandle(waitHandle, ConnectionInfo.Timeout);
         }
 
-        /// <summary>
-        /// Waits for the specified handle or the exception handle for the receive thread
-        /// to signal within the specified timeout.
-        /// </summary>
-        /// <param name="waitHandle">The wait handle.</param>
-        /// <param name="timeout">The time to wait for any of the handles to become signaled.</param>
-        /// <exception cref="SshConnectionException">A received package was invalid or failed the message integrity check.</exception>
-        /// <exception cref="SshOperationTimeoutException">None of the handles are signaled in time and the session is not disconnecting.</exception>
-        /// <exception cref="SocketException">A socket error was signaled while receiving messages from the server.</exception>
-        /// <remarks>
-        /// When neither handles are signaled in time and the session is not closing, then the
-        /// session is disconnected.
-        /// </remarks>
-        void ISession.WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout)
-        {
-            WaitOnHandle(waitHandle, timeout);
-        }
-
         /// <summary>
         /// Waits for the specified handle or the exception handle for the receive thread
         /// to signal within the connection timeout.
@@ -754,17 +744,7 @@ namespace Renci.SshNet
                 case 0:
                     throw _exception;
                 case 1:
-                    // when the session is NOT disconnecting, the listener should actually
-                    // never complete without setting the exception wait handle and should
-                    // end up in case 0... 
-                    //
-                    // when the session is disconnecting, the completion of the listener
-                    // should not be considered an error (quite the oppposite actually)
-                    if (!_isDisconnecting)
-                    {
-                        throw new SshConnectionException("Client not connected.");
-                    }
-                    break;
+                    throw new SshConnectionException("Client not connected.");
                 case WaitHandle.WaitTimeout:
                     // when the session is disconnecting, a timeout is likely when no
                     // network connectivity is available; depending on the configured
@@ -1013,7 +993,7 @@ namespace Renci.SshNet
 
             //  Test packet minimum and maximum boundaries
             if (packetLength < Math.Max((byte)16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
-                throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length {0}", packetLength), DisconnectReason.ProtocolError);
+                throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), DisconnectReason.ProtocolError);
 
             //  Read rest of the packet data
             var bytesToRead = (int)(packetLength - (blockSize - 4));
@@ -1145,9 +1125,7 @@ namespace Renci.SshNet
         private void HandleMessage(DisconnectMessage message)
         {
             OnDisconnectReceived(message);
-
-            //  disconnect from the socket, and dispose it
-            SocketDisconnectAndDispose();
+            Disconnect(message.ReasonCode, message.Description);
         }
 
         private void HandleMessage(IgnoreMessage message)
@@ -1298,6 +1276,9 @@ namespace Renci.SshNet
         {
             Log(string.Format("Disconnect received: {0} {1}", message.ReasonCode, message.Description));
 
+            _exception = new SshConnectionException(string.Format(CultureInfo.InvariantCulture, "The connection was closed by the server: {0} ({1}).", message.Description, message.ReasonCode), message.ReasonCode);
+            _exceptionWaitHandle.Set();
+
             var disconnectReceived = DisconnectReceived;
             if (disconnectReceived != null)
                 disconnectReceived(this, new MessageEventArgs<DisconnectMessage>(message));
@@ -1379,20 +1360,10 @@ namespace Renci.SshNet
                     messageMetadata.Enabled = false;
             }
 
-            var keyExchangeAlgorithmName = (from c in ConnectionInfo.KeyExchangeAlgorithms.Keys
-                                            from s in message.KeyExchangeAlgorithms
-                                            where s == c
-                                            select c).FirstOrDefault();
-
-            if (keyExchangeAlgorithmName == null)
-            {
-                throw new SshConnectionException("Failed to negotiate key exchange algorithm.", DisconnectReason.KeyExchangeFailed);
-            }
+            _keyExchange = _serviceFactory.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms,
+                                                             message.KeyExchangeAlgorithms);
 
-            //  Create instance of key exchange algorithm that will be used
-            _keyExchange = ConnectionInfo.KeyExchangeAlgorithms[keyExchangeAlgorithmName].CreateInstance<KeyExchange>();
-
-            ConnectionInfo.CurrentKeyExchangeAlgorithm = keyExchangeAlgorithmName;
+            ConnectionInfo.CurrentKeyExchangeAlgorithm = _keyExchange.Name;
 
             _keyExchange.HostKeyReceived += KeyExchange_HostKeyReceived;
 
@@ -1872,6 +1843,11 @@ namespace Renci.SshNet
             {
                 RaiseError(exp);
             }
+            finally
+            {
+                // signal that the message listener thread has stopped
+                _messageListenerCompleted.Set();
+            }
         }
 
         private byte SocketReadByte()
@@ -2206,7 +2182,7 @@ namespace Renci.SshNet
             if (errorOccured != null)
                 errorOccured(this, new ExceptionEventArgs(exp));
 
-            if (connectionException != null && connectionException.DisconnectReason != DisconnectReason.ConnectionLost)
+            if (connectionException != null)
             {
                 Disconnect(connectionException.DisconnectReason, exp.ToString());
             }
@@ -2223,7 +2199,7 @@ namespace Renci.SshNet
             if (_keyExchangeCompletedWaitHandle != null)
                 _keyExchangeCompletedWaitHandle.Reset();
             if (_messageListenerCompleted != null)
-                _messageListenerCompleted.Reset();
+                _messageListenerCompleted.Set();
 
             SessionId = null;
             _isDisconnectMessageSent = false;
@@ -2303,6 +2279,12 @@ namespace Renci.SshNet
                         _bytesReadFromSocket.Dispose();
                         _bytesReadFromSocket = null;
                     }
+
+                    if (_messageListenerCompleted != null)
+                    {
+                        _messageListenerCompleted.Dispose();
+                        _messageListenerCompleted = null;
+                    }
                 }
 
                 // Note disposing has been done.