Jelajahi Sumber

Implement OpenSSH strict key exchange extension (#1366)

* Implement OpenSSH strict key exchange extension

* The pseudo-algorithm
is only valid in the initial SSH2_MSG_KEXINIT and MUST be ignored
if they are present in subsequent SSH2_MSG_KEXINIT packets.

* Only send strict kex pseudo algorithm for the first kex.
Strictly disable non-kex massages in strict kex mode.

* Unit tests for strict kex

* More unit tests

* More unit tests

* Correct file name

* Update SessionTest_ConnectingBase.cs

* More unit tests

* Delete SessionTest_Connecting_ServerSendsMaxIgnoreMessagesBeforeKexInit.cs

* Add a comment about throwing exception when inbound sequence number is about to wrap during init kex.

* Delete SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_NoStrictKex.cs

* Fix build

* Update test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs

---------

Co-authored-by: Rob Hague <rob.hague00@gmail.com>
Scott Xu 1 tahun lalu
induk
melakukan
94397d47ed
17 mengubah file dengan 750 tambahan dan 58 penghapusan
  1. 83 34
      src/Renci.SshNet/Session.cs
  2. 28 3
      src/Renci.SshNet/SshMessageFactory.cs
  3. 2 2
      test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs
  4. 26 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
  5. 3 9
      test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
  6. 294 0
      test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs
  7. 1 7
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs
  8. 35 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs
  9. 31 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs
  10. 48 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs
  11. 39 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs
  12. 38 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs
  13. 40 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs
  14. 38 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs
  15. 41 0
      test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs
  16. 2 2
      test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs
  17. 1 1
      test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs

+ 83 - 34
src/Renci.SshNet/Session.cs

@@ -154,6 +154,17 @@ namespace Renci.SshNet
         /// </summary>
         private bool _isDisconnecting;
 
+        /// <summary>
+        /// Indicates whether it is the init kex.
+        /// </summary>
+        private bool _isInitialKex;
+
+        /// <summary>
+        /// Indicates whether server supports strict key exchange.
+        /// <see href="https://github.com/openssh/openssh-portable/blob/master/PROTOCOL"/> 1.10.
+        /// </summary>
+        private bool _isStrictKex;
+
         private IKeyExchange _keyExchange;
 
         private HashAlgorithm _serverMac;
@@ -281,35 +292,11 @@ namespace Renci.SshNet
         /// </value>
         public byte[] SessionId { get; private set; }
 
-        private Message _clientInitMessage;
-
         /// <summary>
         /// Gets the client init message.
         /// </summary>
         /// <value>The client init message.</value>
-        public Message ClientInitMessage
-        {
-            get
-            {
-                _clientInitMessage ??= new KeyExchangeInitMessage
-                    {
-                        KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
-                        ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
-                        EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
-                        EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
-                        MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
-                        MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
-                        CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
-                        CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
-                        LanguagesClientToServer = new[] { string.Empty },
-                        LanguagesServerToClient = new[] { string.Empty },
-                        FirstKexPacketFollows = false,
-                        Reserved = 0
-                    };
-
-                return _clientInitMessage;
-            }
-        }
+        public Message ClientInitMessage { get; private set; }
 
         /// <summary>
         /// Gets the server version string.
@@ -617,6 +604,8 @@ namespace Renci.SshNet
                 // Send our key exchange init.
                 // We need to do this before starting the message listener to avoid the case where we receive the server
                 // key exchange init and we continue the key exchange before having sent our own init.
+                _isInitialKex = true;
+                ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
                 SendMessage(ClientInitMessage);
 
                 // Mark the message listener threads as started
@@ -741,6 +730,8 @@ namespace Renci.SshNet
                 // Send our key exchange init.
                 // We need to do this before starting the message listener to avoid the case where we receive the server
                 // key exchange init and we continue the key exchange before having sent our own init.
+                _isInitialKex = true;
+                ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
                 SendMessage(ClientInitMessage);
 
                 // Mark the message listener threads as started
@@ -1107,13 +1098,20 @@ namespace Renci.SshNet
                     SendPacket(data, 0, data.Length);
                 }
 
-                // increment the packet sequence number only after we're sure the packet has
-                // been sent; even though it's only used for the MAC, it needs to be incremented
-                // for each package sent.
-                //
-                // the server will use it to verify the data integrity, and as such the order in
-                // which messages are sent must follow the outbound packet sequence number
-                _outboundPacketSequence++;
+                if (_isStrictKex && message is NewKeysMessage)
+                {
+                    _outboundPacketSequence = 0;
+                }
+                else
+                {
+                    // increment the packet sequence number only after we're sure the packet has
+                    // been sent; even though it's only used for the MAC, it needs to be incremented
+                    // for each package sent.
+                    //
+                    // the server will use it to verify the data integrity, and as such the order in
+                    // which messages are sent must follow the outbound packet sequence number
+                    _outboundPacketSequence++;
+                }
             }
         }
 
@@ -1344,6 +1342,13 @@ namespace Renci.SshNet
 
             _inboundPacketSequence++;
 
+            // The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
+            // It ensures the integrity of key exchange process.
+            if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
+            {
+                throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
+            }
+
             return LoadMessage(data, messagePayloadOffset, messagePayloadLength);
         }
 
@@ -1455,8 +1460,20 @@ namespace Renci.SshNet
 
             _keyExchangeCompletedWaitHandle.Reset();
 
+            if (_isInitialKex && message.KeyExchangeAlgorithms.Contains("kex-strict-s-v00@openssh.com"))
+            {
+                _isStrictKex = true;
+
+                DiagnosticAbstraction.Log(string.Format("[{0}] Enabling strict key exchange extension.", ToHex(SessionId)));
+
+                if (_inboundPacketSequence != 1)
+                {
+                    throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
+                }
+            }
+
             // Disable messages that are not key exchange related
-            _sshMessageFactory.DisableNonKeyExchangeMessages();
+            _sshMessageFactory.DisableNonKeyExchangeMessages(_isStrictKex);
 
             _keyExchange = _serviceFactory.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms,
                                                              message.KeyExchangeAlgorithms);
@@ -1533,6 +1550,17 @@ namespace Renci.SshNet
             // Enable activated messages that are not key exchange related
             _sshMessageFactory.EnableActivatedMessages();
 
+            if (_isInitialKex)
+            {
+                _isInitialKex = false;
+                ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: false);
+            }
+
+            if (_isStrictKex)
+            {
+                _inboundPacketSequence = 0;
+            }
+
             NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));
 
             // Signal that key exchange completed
@@ -2067,7 +2095,28 @@ namespace Renci.SshNet
         private static SshConnectionException CreateConnectionAbortedByServerException()
         {
             return new SshConnectionException("An established connection was aborted by the server.",
-                                              DisconnectReason.ConnectionLost);
+            DisconnectReason.ConnectionLost);
+        }
+
+        private KeyExchangeInitMessage BuildClientInitMessage(bool includeStrictKexPseudoAlgorithm)
+        {
+            return new KeyExchangeInitMessage
+            {
+                KeyExchangeAlgorithms = includeStrictKexPseudoAlgorithm ?
+                                        ConnectionInfo.KeyExchangeAlgorithms.Keys.Concat(["kex-strict-c-v00@openssh.com"]).ToArray() :
+                                        ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
+                ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
+                EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
+                EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
+                MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
+                MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
+                CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
+                CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
+                LanguagesClientToServer = new[] { string.Empty },
+                LanguagesServerToClient = new[] { string.Empty },
+                FirstKexPacketFollows = false,
+                Reserved = 0,
+            };
         }
 
         private bool _disposed;

+ 28 - 3
src/Renci.SshNet/SshMessageFactory.cs

@@ -115,16 +115,41 @@ namespace Renci.SshNet
             return enabledMessageMetadata.Create();
         }
 
-        public void DisableNonKeyExchangeMessages()
+        /// <summary>
+        /// Disables non-KeyExchange messages.
+        /// </summary>
+        /// <param name="strict">
+        /// <see langword="true"/> to indicate the strict key exchange mode; otherwise <see langword="false"/>.
+        /// <para>In strict key exchange mode, only below messages are allowed:</para>
+        /// <list type="bullet">
+        /// <item>SSH_MSG_KEXINIT -> 20</item>
+        /// <item>SSH_MSG_NEWKEYS -> 21</item>
+        /// <item>SSH_MSG_DISCONNECT -> 1</item>
+        /// </list>
+        /// <para>Note:</para>
+        /// <para>  The relevant KEX Reply MSG will be allowed from a sub class of KeyExchange class.</para>
+        /// <para>  For example, it calls <c>Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");</c> if the curve25519-sha256 KEX algorithm is selected per negotiation.</para>
+        /// </param>
+        public void DisableNonKeyExchangeMessages(bool strict)
         {
             for (var i = 0; i < AllMessages.Length; i++)
             {
                 var messageMetadata = AllMessages[i];
 
                 var messageNumber = messageMetadata.Number;
-                if (messageNumber is (> 2 and < 20) or > 30)
+                if (strict)
+                {
+                    if (messageNumber is not 20 and not 21 and not 1)
+                    {
+                        _enabledMessagesByNumber[messageNumber] = null;
+                    }
+                }
+                else
                 {
-                    _enabledMessagesByNumber[messageNumber] = null;
+                    if (messageNumber is (> 2 and < 20) or > 30)
+                    {
+                        _enabledMessagesByNumber[messageNumber] = null;
+                    }
                 }
             }
         }

+ 2 - 2
test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs

@@ -87,7 +87,7 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void SendMessageShouldThrowShhConnectionException()
+        public void SendMessageShouldThrowSshConnectionException()
         {
             try
             {
@@ -189,7 +189,7 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void ISession_SendMessageShouldThrowShhConnectionException()
+        public void ISession_SendMessageShouldThrowSshConnectionException()
         {
             var session = (ISession)_session;
 

+ 26 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs

@@ -1,4 +1,5 @@
 using System;
+using System.Linq;
 using System.Threading;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Moq;
@@ -30,6 +31,31 @@ namespace Renci.SshNet.Tests.Classes
             Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", Session.ClientVersion);
         }
 
+        [TestMethod]
+        public void IncludeStrictKexPseudoAlgorithmInInitKex()
+        {
+            Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);
+
+            var kexInitMessage = new KeyExchangeInitMessage();
+            kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
+            Assert.IsTrue(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
+        }
+
+        [TestMethod]
+        public void ShouldNotIncludeStrictKexPseudoAlgorithmInSubsequentKex()
+        {
+            ServerBytesReceivedRegister.Clear();
+            Session.SendMessage(Session.ClientInitMessage);
+
+            Thread.Sleep(100);
+
+            Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);
+
+            var kexInitMessage = new KeyExchangeInitMessage();
+            kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
+            Assert.IsFalse(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
+        }
+
         [TestMethod]
         public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
         {

+ 3 - 9
test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

@@ -46,8 +46,7 @@ namespace Renci.SshNet.Tests.Classes
         protected Session Session { get; private set; }
         protected Socket ClientSocket { get; private set; }
         protected Socket ServerSocket { get; private set; }
-        internal SshIdentification ServerIdentification { get; set; }
-        protected bool CallSessionConnectWhenArrange { get; set; }
+        protected SshIdentification ServerIdentification { get; private set; }
 
         /// <summary>
         /// Should the "server" wait for the client kexinit before sending its own.
@@ -163,8 +162,6 @@ namespace Renci.SshNet.Tests.Classes
 
             ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
 
-            CallSessionConnectWhenArrange = true;
-
             void SendKeyExchangeInit()
             {
                 var keyExchangeInitMessage = new KeyExchangeInitMessage
@@ -204,7 +201,7 @@ namespace Renci.SshNet.Tests.Classes
             _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
                                   .Returns(_protocolVersionExchangeMock.Object);
             _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
-                                            .Returns(() => ServerIdentification);
+                                            .Returns(ServerIdentification);
             _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
             _ = _keyExchangeMock.Setup(p => p.Name)
                                 .Returns(_keyExchangeAlgorithm);
@@ -252,10 +249,7 @@ namespace Renci.SshNet.Tests.Classes
             SetupData();
             SetupMocks();
 
-            if (CallSessionConnectWhenArrange)
-            {
-                Session.Connect();
-            }
+            Session.Connect();
         }
 
         protected virtual void ClientAuthentication_Callback()

+ 294 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs

@@ -0,0 +1,294 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Cryptography;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Moq;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Compression;
+using Renci.SshNet.Connection;
+using Renci.SshNet.Messages;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Security;
+using Renci.SshNet.Security.Cryptography;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public abstract class SessionTest_ConnectingBase
+    {
+        internal Mock<IServiceFactory> ServiceFactoryMock { get; private set; }
+        internal Mock<ISocketFactory> SocketFactoryMock { get; private set; }
+        internal Mock<IConnector> ConnectorMock { get; private set; }
+
+        private Mock<IProtocolVersionExchange> _protocolVersionExchangeMock;
+        private Mock<IKeyExchange> _keyExchangeMock;
+        private Mock<IClientAuthentication> _clientAuthenticationMock;
+        private IPEndPoint _serverEndPoint;
+        private string[] _keyExchangeAlgorithms;
+        private bool _authenticationStarted;
+        private SocketFactory _socketFactory;
+
+        protected Random Random { get; private set; }
+        protected byte[] SessionId { get; private set; }
+        protected ConnectionInfo ConnectionInfo { get; private set; }
+        protected IList<EventArgs> DisconnectedRegister { get; private set; }
+        protected IList<MessageEventArgs<DisconnectMessage>> DisconnectReceivedRegister { get; private set; }
+        protected IList<ExceptionEventArgs> ErrorOccurredRegister { get; private set; }
+        protected AsyncSocketListener ServerListener { get; private set; }
+        protected IList<byte[]> ServerBytesReceivedRegister { get; private set; }
+        protected Session Session { get; private set; }
+        protected Socket ClientSocket { get; private set; }
+        protected Socket ServerSocket { get; private set; }
+        protected SshIdentification ServerIdentification { get; set; }
+        protected virtual bool ServerSupportsStrictKex { get; }
+
+        protected virtual bool ServerResetsSequenceAfterSendingNewKeys
+        {
+            get
+            {
+                return ServerSupportsStrictKex;
+            }
+        }
+
+        protected uint ServerOutboundPacketSequence { get; set; }
+
+        [TestInitialize]
+        public void Setup()
+        {
+            CreateMocks();
+            SetupData();
+            SetupMocks();
+        }
+
+        protected virtual void ActionBeforeKexInit()
+        {
+        }
+
+        protected virtual void ActionAfterKexInit()
+        {
+        }
+
+        [TestCleanup]
+        public void TearDown()
+        {
+            if (ServerListener != null)
+            {
+                ServerListener.Dispose();
+                ServerListener = null;
+            }
+
+            if (ServerSocket != null)
+            {
+                ServerSocket.Dispose();
+                ServerSocket = null;
+            }
+
+            if (Session != null)
+            {
+                Session.Dispose();
+                Session = null;
+            }
+
+            if (ClientSocket != null && ClientSocket.Connected)
+            {
+                ClientSocket.Shutdown(SocketShutdown.Both);
+                ClientSocket.Dispose();
+            }
+        }
+
+        protected virtual void SetupData()
+        {
+            Random = new Random();
+
+            _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+            ConnectionInfo = new ConnectionInfo(
+                _serverEndPoint.Address.ToString(),
+                _serverEndPoint.Port,
+                "user",
+                new PasswordAuthenticationMethod("user", "password"))
+            { Timeout = TimeSpan.FromSeconds(20) };
+            _keyExchangeAlgorithms = ServerSupportsStrictKex ?
+                                     [Random.Next().ToString(CultureInfo.InvariantCulture), "kex-strict-s-v00@openssh.com"] :
+                                     [Random.Next().ToString(CultureInfo.InvariantCulture)];
+            SessionId = new byte[10];
+            Random.NextBytes(SessionId);
+            DisconnectedRegister = new List<EventArgs>();
+            DisconnectReceivedRegister = new List<MessageEventArgs<DisconnectMessage>>();
+            ErrorOccurredRegister = new List<ExceptionEventArgs>();
+            ServerBytesReceivedRegister = new List<byte[]>();
+            ServerIdentification = new SshIdentification("2.0", "OurServerStub");
+            _authenticationStarted = false;
+            _socketFactory = new SocketFactory();
+
+            Session = new Session(ConnectionInfo, ServiceFactoryMock.Object, SocketFactoryMock.Object);
+            Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args);
+            Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args);
+            Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args);
+
+            ServerListener = new AsyncSocketListener(_serverEndPoint)
+            {
+                ShutdownRemoteCommunicationSocket = false
+            };
+            ServerListener.Connected += socket =>
+            {
+                ServerSocket = socket;
+                ActionBeforeKexInit();
+                var keyExchangeInitMessage = new KeyExchangeInitMessage
+                {
+                    CompressionAlgorithmsClientToServer = new string[0],
+                    CompressionAlgorithmsServerToClient = new string[0],
+                    EncryptionAlgorithmsClientToServer = new string[0],
+                    EncryptionAlgorithmsServerToClient = new string[0],
+                    KeyExchangeAlgorithms = _keyExchangeAlgorithms,
+                    LanguagesClientToServer = new string[0],
+                    LanguagesServerToClient = new string[0],
+                    MacAlgorithmsClientToServer = new string[0],
+                    MacAlgorithmsServerToClient = new string[0],
+                    ServerHostKeyAlgorithms = new string[0]
+                };
+                var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
+                _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+                ServerOutboundPacketSequence++;
+            };
+            ServerListener.BytesReceived += (received, socket) =>
+            {
+                ServerBytesReceivedRegister.Add(received);
+
+                if (received.Length > 5 && received[5] == 20)
+                {
+                    ActionAfterKexInit();
+                    var newKeysMessage = new NewKeysMessage();
+                    var newKeys = newKeysMessage.GetPacket(8, null);
+                    _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+
+                    if (ServerResetsSequenceAfterSendingNewKeys)
+                    {
+                        ServerOutboundPacketSequence = 0;
+                    }
+                    else
+                    {
+                        ServerOutboundPacketSequence++;
+                    }
+
+                    if (!_authenticationStarted)
+                    {
+                        var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
+                                                                              .Build(ServerOutboundPacketSequence);
+                        var hash = Abstractions.CryptoAbstraction.CreateSHA256().ComputeHash(serviceAcceptMessage);
+
+                        var packet = new byte[serviceAcceptMessage.Length - 4 + hash.Length];
+
+                        Array.Copy(serviceAcceptMessage, 4, packet, 0, serviceAcceptMessage.Length - 4);
+                        Array.Copy(hash, 0, packet, serviceAcceptMessage.Length - 4, hash.Length);
+
+                        _ = ServerSocket.Send(packet, 0, packet.Length, SocketFlags.None);
+
+                        ServerOutboundPacketSequence++;
+
+                        _authenticationStarted = true;
+                    }
+                }
+            };
+            ServerListener.Start();
+
+            ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
+        }
+
+        private void CreateMocks()
+        {
+            ServiceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            SocketFactoryMock = new Mock<ISocketFactory>(MockBehavior.Strict);
+            ConnectorMock = new Mock<IConnector>(MockBehavior.Strict);
+            _protocolVersionExchangeMock = new Mock<IProtocolVersionExchange>(MockBehavior.Strict);
+            _keyExchangeMock = new Mock<IKeyExchange>(MockBehavior.Strict);
+            _clientAuthenticationMock = new Mock<IClientAuthentication>(MockBehavior.Strict);
+        }
+
+        private void SetupMocks()
+        {
+            _ = ServiceFactoryMock.Setup(p => p.CreateConnector(ConnectionInfo, SocketFactoryMock.Object))
+                                  .Returns(ConnectorMock.Object);
+            _ = ConnectorMock.Setup(p => p.Connect(ConnectionInfo))
+                             .Returns(ClientSocket);
+            _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
+                                  .Returns(_protocolVersionExchangeMock.Object);
+            _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
+                                            .Returns(() => ServerIdentification);
+            _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, _keyExchangeAlgorithms)).Returns(_keyExchangeMock.Object);
+
+            _ = _keyExchangeMock.Setup(p => p.Name)
+                                .Returns(_keyExchangeAlgorithms[0]);
+            _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>(), false));
+            _ = _keyExchangeMock.Setup(p => p.ExchangeHash)
+                                .Returns(SessionId);
+            _ = _keyExchangeMock.Setup(p => p.CreateServerCipher(out It.Ref<bool>.IsAny))
+                                .Returns((ref bool serverAead) =>
+                                {
+                                    serverAead = false;
+                                    return (Cipher) null;
+                                });
+            _ = _keyExchangeMock.Setup(p => p.CreateClientCipher(out It.Ref<bool>.IsAny))
+                                .Returns((ref bool clientAead) =>
+                                {
+                                    clientAead = false;
+                                    return (Cipher) null;
+                                });
+            _ = _keyExchangeMock.Setup(p => p.CreateServerHash(out It.Ref<bool>.IsAny))
+                                .Returns((ref bool serverEtm) =>
+                                {
+                                    serverEtm = false;
+                                    return SHA256.Create();
+                                });
+            _ = _keyExchangeMock.Setup(p => p.CreateClientHash(out It.Ref<bool>.IsAny))
+                                .Returns((ref bool clientEtm) =>
+                                {
+                                    clientEtm = false;
+                                    return (HashAlgorithm) null;
+                                });
+            _ = _keyExchangeMock.Setup(p => p.CreateCompressor())
+                                .Returns((Compressor) null);
+            _ = _keyExchangeMock.Setup(p => p.CreateDecompressor())
+                                .Returns((Compressor) null);
+            _ = _keyExchangeMock.Setup(p => p.Dispose());
+            _ = ServiceFactoryMock.Setup(p => p.CreateClientAuthentication())
+                                  .Returns(_clientAuthenticationMock.Object);
+            _ = _clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session));
+        }
+
+        private class ServiceAcceptMessageBuilder
+        {
+            private readonly ServiceName _serviceName;
+
+            private ServiceAcceptMessageBuilder(ServiceName serviceName)
+            {
+                _serviceName = serviceName;
+            }
+
+            public static ServiceAcceptMessageBuilder Create(ServiceName serviceName)
+            {
+                return new ServiceAcceptMessageBuilder(serviceName);
+            }
+
+            public byte[] Build(uint sequence)
+            {
+                var serviceName = _serviceName.ToArray();
+                var target = new ServiceAcceptMessage();
+
+                var sshDataStream = new SshDataStream(4 + 4 + 1 + 1 + 4 + serviceName.Length);
+                sshDataStream.Write(sequence);
+                sshDataStream.Write((uint) (sshDataStream.Capacity - 8)); //sequence and packet length
+                sshDataStream.WriteByte(0); // padding length
+                sshDataStream.WriteByte(target.MessageNumber);
+                sshDataStream.WriteBinary(serviceName);
+                return sshDataStream.ToArray();
+            }
+        }
+    }
+}

+ 1 - 7
test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs → test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs

@@ -5,14 +5,12 @@ using Renci.SshNet.Connection;
 namespace Renci.SshNet.Tests.Classes
 {
     [TestClass]
-    public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase
+    public class SessionTest_Connecting_ServerIdentificationReceived : SessionTest_ConnectingBase
     {
         protected override void SetupData()
         {
             base.SetupData();
 
-            CallSessionConnectWhenArrange = false;
-
             Session.ServerIdentificationReceived += (s, e) =>
             {
                 if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal))
@@ -24,10 +22,6 @@ namespace Renci.SshNet.Tests.Classes
             };
         }
 
-        protected override void Act()
-        {
-        }
-
         [TestMethod]
         [DataRow("OpenSSH_6.5")]
         [DataRow("OpenSSH_6.5p1")]

+ 35 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs

@@ -0,0 +1,35 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+        protected override bool ServerResetsSequenceAfterSendingNewKeys
+        {
+            get
+            {
+                return false;
+            }
+        }
+
+
+        [TestMethod]
+        public void ShouldThrowSshConnectionException()
+        {
+            var reason = Assert.ThrowsException<SshConnectionException>(Session.Connect).DisconnectReason;
+            Assert.AreEqual(DisconnectReason.MacError, reason);
+        }
+    }
+}

+ 31 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs

@@ -0,0 +1,31 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+        protected override bool ServerResetsSequenceAfterSendingNewKeys
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+
+        [TestMethod]
+        public void ShouldNotThrowException()
+        {
+            Session.Connect();
+        }
+    }
+}

+ 48 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs

@@ -0,0 +1,48 @@
+using System.Globalization;
+using System.Net.Sockets;
+using System.Text;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+        protected override void ActionAfterKexInit()
+        {
+            using var stream = new SshDataStream(0);
+            stream.WriteByte(1);
+            stream.Write("This is a debug message", Encoding.UTF8);
+            stream.Write(CultureInfo.CurrentCulture.Name, Encoding.UTF8);
+
+            var debugMessage = new DebugMessage();
+            debugMessage.Load(stream.ToArray());
+            var debug = debugMessage.GetPacket(8, null);
+
+            // MitM sends debug message to client
+            _ = ServerSocket.Send(debug, 4, debug.Length - 4, SocketFlags.None);
+
+            // MitM drops server message
+            ServerOutboundPacketSequence++;
+        }
+
+        [TestMethod]
+        public void ShouldThrowSshException()
+        {
+            var message = Assert.ThrowsException<SshException>(Session.Connect).Message;
+            Assert.AreEqual("Message type 4 is not valid in the current context.", message);
+        }
+    }
+}

+ 39 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs

@@ -0,0 +1,39 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+        protected override void ActionAfterKexInit()
+        {
+            var disconnectMessage = new DisconnectMessage(DisconnectReason.TooManyConnections, "too many connections");
+            var disconnect = disconnectMessage.GetPacket(8, null);
+
+            // Server sends disconnect message to client
+            _ = ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
+
+            ServerOutboundPacketSequence++;
+        }
+
+        [TestMethod]
+        public void DisconnectIsAllowedDuringStrictKex()
+        {
+            var exception = Assert.ThrowsException<SshConnectionException>(Session.Connect);
+            Assert.AreEqual(DisconnectReason.TooManyConnections, exception.DisconnectReason);
+        }
+    }
+}

+ 38 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs

@@ -0,0 +1,38 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return false;
+            }
+        }
+
+        protected override void ActionAfterKexInit()
+        {
+            var ignoreMessage = new IgnoreMessage();
+            var ignore = ignoreMessage.GetPacket(8, null);
+
+            // MitM sends ignore message to client
+            _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+            // MitM drops server message
+            ServerOutboundPacketSequence++;
+        }
+
+        [TestMethod]
+        public void DoesNotThrowException()
+        {
+            Session.Connect();
+        }
+    }
+}

+ 40 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs

@@ -0,0 +1,40 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+        protected override void ActionAfterKexInit()
+        {
+            var ignoreMessage = new IgnoreMessage();
+            var ignore = ignoreMessage.GetPacket(8, null);
+
+            // MitM sends ignore message to client
+            _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+            // MitM drops server message
+            ServerOutboundPacketSequence++;
+        }
+
+        [TestMethod]
+        public void ShouldThrowSshException()
+        {
+            var message = Assert.ThrowsException<SshException>(Session.Connect).Message;
+            Assert.AreEqual("Message type 2 is not valid in the current context.", message);
+        }
+    }
+}

+ 38 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs

@@ -0,0 +1,38 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return false;
+            }
+        }
+
+        protected override void ActionBeforeKexInit()
+        {
+            var ignoreMessage = new IgnoreMessage();
+            var ignore = ignoreMessage.GetPacket(8, null);
+
+            // MitM sends ignore message to client
+            _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+            // MitM drops server message
+            ServerOutboundPacketSequence++;
+        }
+
+        [TestMethod]
+        public void DoesNotThrowException()
+        {
+            Session.Connect();
+        }
+    }
+}

+ 41 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs

@@ -0,0 +1,41 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex : SessionTest_ConnectingBase
+    {
+        protected override bool ServerSupportsStrictKex
+        {
+            get
+            {
+                return true;
+            }
+        }
+
+        protected override void ActionBeforeKexInit()
+        {
+            var ignoreMessage = new IgnoreMessage();
+            var ignore = ignoreMessage.GetPacket(8, null);
+
+            // MitM sends ignore message to client
+            _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+            // MitM drops server message
+            ServerOutboundPacketSequence++;
+        }
+
+        [TestMethod]
+        public void ShouldThrowSshConnectionException()
+        {
+            var exception = Assert.ThrowsException<SshConnectionException>(Session.Connect);
+            Assert.AreEqual(DisconnectReason.KeyExchangeFailed, exception.DisconnectReason);
+            Assert.AreEqual("KEXINIT was not the first packet during strict key exchange.", exception.Message);
+        }
+    }
+}

+ 2 - 2
test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs

@@ -57,7 +57,7 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void SendMessageShouldThrowShhConnectionException()
+        public void SendMessageShouldThrowSshConnectionException()
         {
             try
             {
@@ -159,7 +159,7 @@ namespace Renci.SshNet.Tests.Classes
         }
 
         [TestMethod]
-        public void ISession_SendMessageShouldThrowShhConnectionException()
+        public void ISession_SendMessageShouldThrowSshConnectionException()
         {
             var session = (ISession) _session;
 

+ 1 - 1
test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs

@@ -385,7 +385,7 @@ namespace Renci.SshNet.Tests.Common
         {
             public Socket Socket { get; private set; }
 
-            public readonly byte[] Buffer = new byte[1024];
+            public readonly byte[] Buffer = new byte[2048];
 
             public SocketStateObject(Socket handler)
             {