Explorar o código

Send the client key exchange init in Connect (#1274)

* Send the client key exchange init in Connect

* Add a test

---------

Co-authored-by: Wojciech Nagórski <wojtpl2@gmail.com>
Rob Hague hai 1 ano
pai
achega
34b5123f0a

+ 3 - 2
src/Renci.SshNet/Security/IKeyExchange.cs

@@ -38,8 +38,9 @@ namespace Renci.SshNet.Security
         /// Starts the key exchange algorithm.
         /// </summary>
         /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        void Start(Session session, KeyExchangeInitMessage message);
+        /// <param name="message">The key exchange init message received from the server.</param>
+        /// <param name="sendClientInitMessage">Whether to send a key exchange init message in response.</param>
+        void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage);
 
         /// <summary>
         /// Finishes the key exchange algorithm.

+ 6 - 7
src/Renci.SshNet/Security/KeyExchange.cs

@@ -61,16 +61,15 @@ namespace Renci.SshNet.Security
         /// </summary>
         public event EventHandler<HostKeyEventArgs> HostKeyReceived;
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public virtual void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
             Session = session;
 
-            SendMessage(session.ClientInitMessage);
+            if (sendClientInitMessage)
+            {
+                SendMessage(session.ClientInitMessage);
+            }
 
             // Determine encryption algorithm
             var clientEncryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys

+ 3 - 7
src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs

@@ -76,14 +76,10 @@ namespace Renci.SshNet.Security
             return ValidateExchangeHash(_hostKey, _signature);
         }
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public override void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
-            base.Start(session, message);
+            base.Start(session, message, sendClientInitMessage);
 
             _serverPayload = message.GetBytes();
             _clientPayload = Session.ClientInitMessage.GetBytes();

+ 3 - 7
src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs

@@ -39,14 +39,10 @@ namespace Renci.SshNet.Security
             return Hash(groupExchangeHashData.GetBytes());
         }
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public override void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
-            base.Start(session, message);
+            base.Start(session, message, sendClientInitMessage);
 
             // Register SSH_MSG_KEX_DH_GEX_GROUP message
             Session.RegisterMessage("SSH_MSG_KEX_DH_GEX_GROUP");

+ 3 - 7
src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs

@@ -13,14 +13,10 @@ namespace Renci.SshNet.Security
         /// </value>
         public abstract BigInteger GroupPrime { get; }
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public override void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
-            base.Start(session, message);
+            base.Start(session, message, sendClientInitMessage);
 
             Session.RegisterMessage("SSH_MSG_KEXDH_REPLY");
 

+ 3 - 7
src/Renci.SshNet/Security/KeyExchangeEC.cs

@@ -78,14 +78,10 @@ namespace Renci.SshNet.Security
             return ValidateExchangeHash(_hostKey, _signature);
         }
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public override void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
-            base.Start(session, message);
+            base.Start(session, message, sendClientInitMessage);
 
             _serverPayload = message.GetBytes();
             _clientPayload = Session.ClientInitMessage.GetBytes();

+ 3 - 7
src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs

@@ -29,14 +29,10 @@ namespace Renci.SshNet.Security
             get { return 256; }
         }
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public override void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
-            base.Start(session, message);
+            base.Start(session, message, sendClientInitMessage);
 
             Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");
 

+ 3 - 7
src/Renci.SshNet/Security/KeyExchangeECDH.cs

@@ -24,14 +24,10 @@ namespace Renci.SshNet.Security
         private ECDHCBasicAgreement _keyAgreement;
         private ECDomainParameters _domainParameters;
 
-        /// <summary>
-        /// Starts key exchange algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        /// <param name="message">Key exchange init message.</param>
-        public override void Start(Session session, KeyExchangeInitMessage message)
+        /// <inheritdoc/>
+        public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
         {
-            base.Start(session, message);
+            base.Start(session, message, sendClientInitMessage);
 
             Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");
 

+ 26 - 18
src/Renci.SshNet/Session.cs

@@ -160,12 +160,7 @@ namespace Renci.SshNet
         /// <summary>
         /// WaitHandle to signal that key exchange was completed.
         /// </summary>
-        private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(initialState: false);
-
-        /// <summary>
-        /// WaitHandle to signal that key exchange is in progress.
-        /// </summary>
-        private bool _keyExchangeInProgress;
+        private ManualResetEventSlim _keyExchangeCompletedWaitHandle = new ManualResetEventSlim(initialState: false);
 
         /// <summary>
         /// Exception that need to be thrown by waiting thread.
@@ -643,6 +638,11 @@ namespace Renci.SshNet
                     // Some server implementations might sent this message first, prior to establishing encryption algorithm
                     RegisterMessage("SSH_MSG_USERAUTH_BANNER");
 
+                    // Send our key exchange init.
+                    // We need to do this before starting the message listener to avoid the case where we receive the server
+                    // key exchange init and we continue the key exchange before having sent our own init.
+                    SendMessage(ClientInitMessage);
+
                     // Mark the message listener threads as started
                     _ = _messageListenerCompleted.Reset();
 
@@ -651,7 +651,7 @@ namespace Renci.SshNet
                     _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
 
                     // Wait for key exchange to be completed
-                    WaitOnHandle(_keyExchangeCompletedWaitHandle);
+                    WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
 
                     // If sessionId is not set then its not connected
                     if (SessionId is null)
@@ -757,6 +757,11 @@ namespace Renci.SshNet
             // Some server implementations might sent this message first, prior to establishing encryption algorithm
             RegisterMessage("SSH_MSG_USERAUTH_BANNER");
 
+            // Send our key exchange init.
+            // We need to do this before starting the message listener to avoid the case where we receive the server
+            // key exchange init and we continue the key exchange before having sent our own init.
+            SendMessage(ClientInitMessage);
+
             // Mark the message listener threads as started
             _ = _messageListenerCompleted.Reset();
 
@@ -765,7 +770,7 @@ namespace Renci.SshNet
             _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
 
             // Wait for key exchange to be completed
-            WaitOnHandle(_keyExchangeCompletedWaitHandle);
+            WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
 
             // If sessionId is not set then its not connected
             if (SessionId is null)
@@ -1046,10 +1051,10 @@ namespace Renci.SshNet
                 throw new SshConnectionException("Client not connected.");
             }
 
-            if (_keyExchangeInProgress && message is not IKeyExchangedAllowed)
+            if (!_keyExchangeCompletedWaitHandle.IsSet && message is not IKeyExchangedAllowed)
             {
                 // Wait for key exchange to be completed
-                WaitOnHandle(_keyExchangeCompletedWaitHandle);
+                WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
             }
 
             DiagnosticAbstraction.Log(string.Format("[{0}] Sending message '{1}' to server: '{2}'.", ToHex(SessionId), message.GetType().Name, message));
@@ -1394,9 +1399,15 @@ namespace Renci.SshNet
         /// <param name="message"><see cref="KeyExchangeInitMessage"/> message.</param>
         internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
         {
-            _keyExchangeInProgress = true;
+            // If _keyExchangeCompletedWaitHandle is already set, then this is a key
+            // re-exchange initiated by the server, and we need to send our own init
+            // message.
+            // Otherwise, the wait handle is not set and this received init is part of the
+            // initial connection for which we have already sent our init, so we shouldn't
+            // send another one.
+            var sendClientInitMessage = _keyExchangeCompletedWaitHandle.IsSet;
 
-            _ = _keyExchangeCompletedWaitHandle.Reset();
+            _keyExchangeCompletedWaitHandle.Reset();
 
             // Disable messages that are not key exchange related
             _sshMessageFactory.DisableNonKeyExchangeMessages();
@@ -1411,7 +1422,7 @@ namespace Renci.SshNet
             _keyExchange.HostKeyReceived += KeyExchange_HostKeyReceived;
 
             // Start the algorithm implementation
-            _keyExchange.Start(this, message);
+            _keyExchange.Start(this, message, sendClientInitMessage);
 
             KeyExchangeInitReceived?.Invoke(this, new MessageEventArgs<KeyExchangeInitMessage>(message));
         }
@@ -1477,9 +1488,7 @@ namespace Renci.SshNet
             NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));
 
             // Signal that key exchange completed
-            _ = _keyExchangeCompletedWaitHandle.Set();
-
-            _keyExchangeInProgress = false;
+            _keyExchangeCompletedWaitHandle.Set();
         }
 
         /// <summary>
@@ -1967,7 +1976,7 @@ namespace Renci.SshNet
         private void Reset()
         {
             _ = _exceptionWaitHandle?.Reset();
-            _ = _keyExchangeCompletedWaitHandle?.Reset();
+            _keyExchangeCompletedWaitHandle?.Reset();
             _ = _messageListenerCompleted?.Set();
 
             SessionId = null;
@@ -1975,7 +1984,6 @@ namespace Renci.SshNet
             _isDisconnecting = false;
             _isAuthenticated = false;
             _exception = null;
-            _keyExchangeInProgress = false;
         }
 
         private static SshConnectionException CreateConnectionAbortedByServerException()

+ 50 - 29
test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

@@ -49,6 +49,12 @@ namespace Renci.SshNet.Tests.Classes
         internal SshIdentification ServerIdentification { get; set; }
         protected bool CallSessionConnectWhenArrange { get; set; }
 
+        /// <summary>
+        /// Should the "server" wait for the client kexinit before sending its own.
+        /// A regression test simulating e.g. cisco devices.
+        /// </summary>
+        protected bool WaitForClientKeyExchangeInit { get; set; }
+
         [TestInitialize]
         public void Setup()
         {
@@ -59,18 +65,18 @@ namespace Renci.SshNet.Tests.Classes
         [TestCleanup]
         public void TearDown()
         {
-            if (ServerSocket != null)
-            {
-                ServerSocket.Dispose();
-                ServerSocket = null;
-            }
-
             if (ServerListener != null)
             {
                 ServerListener.Dispose();
                 ServerListener = null;
             }
 
+            if (ServerSocket != null)
+            {
+                ServerSocket.Dispose();
+                ServerSocket = null;
+            }
+
             if (Session != null)
             {
                 Session.Dispose();
@@ -115,6 +121,15 @@ namespace Renci.SshNet.Tests.Classes
                     var newKeysMessage = new NewKeysMessage();
                     var newKeys = newKeysMessage.GetPacket(8, null);
                     _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+
+                    if (!_authenticationStarted)
+                    {
+                        var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
+                                                                              .Build();
+                        _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
+
+                        _authenticationStarted = true;
+                    }
                 };
 
             ServerListener = new AsyncSocketListener(_serverEndPoint)
@@ -125,36 +140,23 @@ namespace Renci.SshNet.Tests.Classes
                 {
                     ServerSocket = socket;
 
-                    // Since we're mocking the protocol version exchange, we'll immediately stat KEX upon
+                    // Since we're mocking the protocol version exchange, we'll immediately start KEX upon
                     // having established the connection instead of when the client has been identified
 
-                    var keyExchangeInitMessage = new KeyExchangeInitMessage
-                        {
-                            CompressionAlgorithmsClientToServer = new string[0],
-                            CompressionAlgorithmsServerToClient = new string[0],
-                            EncryptionAlgorithmsClientToServer = new string[0],
-                            EncryptionAlgorithmsServerToClient = new string[0],
-                            KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm },
-                            LanguagesClientToServer = new string[0],
-                            LanguagesServerToClient = new string[0],
-                            MacAlgorithmsClientToServer = new string[0],
-                            MacAlgorithmsServerToClient = new string[0],
-                            ServerHostKeyAlgorithms = new string[0]
-                        };
-                    var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
-                    _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+                    if (!WaitForClientKeyExchangeInit)
+                    {
+                        SendKeyExchangeInit();
+                    }
                 };
             ServerListener.BytesReceived += (received, socket) =>
                 {
                     ServerBytesReceivedRegister.Add(received);
 
-                    if (!_authenticationStarted)
+                    if (WaitForClientKeyExchangeInit && received.Length > 5 && received[5] == 20)
                     {
-                        var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
-                                                                              .Build();
-                        _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
-
-                        _authenticationStarted = true;
+                        // This is the KEXINIT. Send one back.
+                        SendKeyExchangeInit();
+                        WaitForClientKeyExchangeInit = false;
                     }
                 };
             ServerListener.Start();
@@ -162,6 +164,25 @@ namespace Renci.SshNet.Tests.Classes
             ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
 
             CallSessionConnectWhenArrange = true;
+
+            void SendKeyExchangeInit()
+            {
+                var keyExchangeInitMessage = new KeyExchangeInitMessage
+                {
+                    CompressionAlgorithmsClientToServer = new string[0],
+                    CompressionAlgorithmsServerToClient = new string[0],
+                    EncryptionAlgorithmsClientToServer = new string[0],
+                    EncryptionAlgorithmsServerToClient = new string[0],
+                    KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm },
+                    LanguagesClientToServer = new string[0],
+                    LanguagesServerToClient = new string[0],
+                    MacAlgorithmsClientToServer = new string[0],
+                    MacAlgorithmsServerToClient = new string[0],
+                    ServerHostKeyAlgorithms = new string[0]
+                };
+                var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
+                _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+            }
         }
 
         private void CreateMocks()
@@ -187,7 +208,7 @@ namespace Renci.SshNet.Tests.Classes
             _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
             _ = _keyExchangeMock.Setup(p => p.Name)
                                 .Returns(_keyExchangeAlgorithm);
-            _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>()));
+            _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>(), false));
             _ = _keyExchangeMock.Setup(p => p.ExchangeHash)
                                 .Returns(SessionId);
             _ = _keyExchangeMock.Setup(p => p.CreateServerCipher())

+ 8 - 8
test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs

@@ -89,6 +89,13 @@ namespace Renci.SshNet.Tests.Classes
                     var newKeysMessage = new NewKeysMessage();
                     var newKeys = newKeysMessage.GetPacket(8, null);
                     _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+
+                    if (!_authenticationStarted)
+                    {
+                        var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build();
+                        _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
+                        _authenticationStarted = true;
+                    }
                 };
 
             ServerListener = new AsyncSocketListener(_serverEndPoint);
@@ -118,13 +125,6 @@ namespace Renci.SshNet.Tests.Classes
             ServerListener.BytesReceived += (received, socket) =>
                 {
                     ServerBytesReceivedRegister.Add(received);
-
-                    if (!_authenticationStarted)
-                    {
-                        var serviceAcceptMessage =ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build();
-                        _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
-                        _authenticationStarted = true;
-                    }
                 };
 
             ServerListener.Start();
@@ -156,7 +156,7 @@ namespace Renci.SshNet.Tests.Classes
                                    .Returns(_keyExchangeMock.Object);
             _ = _keyExchangeMock.Setup(p => p.Name)
                                 .Returns(_keyExchangeAlgorithm);
-            _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>()));
+            _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>(), false));
             _ = _keyExchangeMock.Setup(p => p.ExchangeHash)
                                 .Returns(SessionId);
             _ = _keyExchangeMock.Setup(p => p.CreateServerCipher())

+ 24 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs

@@ -0,0 +1,24 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerDoesNotSendKexInit : SessionTest_ConnectedBase
+    {
+        protected override void SetupData()
+        {
+            WaitForClientKeyExchangeInit = true;
+
+            base.SetupData();
+        }
+
+        protected override void Act()
+        {
+        }
+
+        [TestMethod]
+        public void ConnectShouldSucceed()
+        {
+        }
+    }
+}