Ver Fonte

Expose SshIdentificationReceived event (#1195)

* Fix https://github.com/sshnet/SSH.NET/issues/1191

* Expose `SshIdentificationReceived` event so that lib consumer can adjust based on server identification

* revert unrelated code style change

* revert OpenSSH 6.6 related tests

* revert ConnectionBase

* Add unit tests

* Rename to `ServerIdentificationReceived`

* rename
Scott Xu há 1 ano atrás
pai
commit
f9f2b0e5f4

+ 13 - 0
src/Renci.SshNet/BaseClient.cs

@@ -153,6 +153,11 @@ namespace Renci.SshNet
         /// </example>
         public event EventHandler<HostKeyEventArgs> HostKeyReceived;
 
+        /// <summary>
+        /// Occurs when server identification received.
+        /// </summary>
+        public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
+
         /// <summary>
         /// Initializes a new instance of the <see cref="BaseClient"/> class.
         /// </summary>
@@ -390,6 +395,11 @@ namespace Renci.SshNet
             HostKeyReceived?.Invoke(this, e);
         }
 
+        private void Session_ServerIdentificationReceived(object sender, SshIdentificationEventArgs e)
+        {
+            ServerIdentificationReceived?.Invoke(this, e);
+        }
+
         /// <summary>
         /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
         /// </summary>
@@ -532,6 +542,7 @@ namespace Renci.SshNet
         private ISession CreateAndConnectSession()
         {
             var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
+            session.ServerIdentificationReceived += Session_ServerIdentificationReceived;
             session.HostKeyReceived += Session_HostKeyReceived;
             session.ErrorOccured += Session_ErrorOccured;
 
@@ -550,6 +561,7 @@ namespace Renci.SshNet
         private async Task<ISession> CreateAndConnectSessionAsync(CancellationToken cancellationToken)
         {
             var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
+            session.ServerIdentificationReceived += Session_ServerIdentificationReceived;
             session.HostKeyReceived += Session_HostKeyReceived;
             session.ErrorOccured += Session_ErrorOccured;
 
@@ -569,6 +581,7 @@ namespace Renci.SshNet
         {
             session.ErrorOccured -= Session_ErrorOccured;
             session.HostKeyReceived -= Session_HostKeyReceived;
+            session.ServerIdentificationReceived -= Session_ServerIdentificationReceived;
             session.Dispose();
         }
 

+ 26 - 0
src/Renci.SshNet/Common/SshIdentificationEventArgs.cs

@@ -0,0 +1,26 @@
+using System;
+
+using Renci.SshNet.Connection;
+
+namespace Renci.SshNet.Common
+{
+    /// <summary>
+    /// Provides data for the ServerIdentificationReceived events.
+    /// </summary>
+    public class SshIdentificationEventArgs : EventArgs
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="SshIdentificationEventArgs"/> class.
+        /// </summary>
+        /// <param name="sshIdentification">The SSH identification.</param>
+        public SshIdentificationEventArgs(SshIdentification sshIdentification)
+        {
+            SshIdentification = sshIdentification;
+        }
+
+        /// <summary>
+        /// Gets the SSH identification.
+        /// </summary>
+        public SshIdentification SshIdentification { get; private set; }
+    }
+}

+ 1 - 1
src/Renci.SshNet/Connection/SshIdentification.cs

@@ -5,7 +5,7 @@ namespace Renci.SshNet.Connection
     /// <summary>
     /// Represents an SSH identification.
     /// </summary>
-    internal sealed class SshIdentification
+    public sealed class SshIdentification
     {
         /// <summary>
         /// Initializes a new instance of the <see cref="SshIdentification"/> class with the specified protocol version

+ 5 - 0
src/Renci.SshNet/ISession.cs

@@ -260,6 +260,11 @@ namespace Renci.SshNet
         /// </summary>
         event EventHandler<ExceptionEventArgs> ErrorOccured;
 
+        /// <summary>
+        /// Occurs when server identification received.
+        /// </summary>
+        event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
+
         /// <summary>
         /// Occurs when host key received.
         /// </summary>

+ 9 - 0
src/Renci.SshNet/Session.cs

@@ -366,6 +366,11 @@ namespace Renci.SshNet
         /// </summary>
         public event EventHandler<EventArgs> Disconnected;
 
+        /// <summary>
+        /// Occurs when server identification received.
+        /// </summary>
+        public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
+
         /// <summary>
         /// Occurs when host key received.
         /// </summary>
@@ -624,6 +629,8 @@ namespace Renci.SshNet
                                                          DisconnectReason.ProtocolVersionNotSupported);
                     }
 
+                    ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
+
                     // Register Transport response messages
                     RegisterMessage("SSH_MSG_DISCONNECT");
                     RegisterMessage("SSH_MSG_IGNORE");
@@ -736,6 +743,8 @@ namespace Renci.SshNet
                                                     DisconnectReason.ProtocolVersionNotSupported);
             }
 
+            ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
+
             // Register Transport response messages
             RegisterMessage("SSH_MSG_DISCONNECT");
             RegisterMessage("SSH_MSG_IGNORE");

+ 9 - 3
test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

@@ -46,7 +46,8 @@ namespace Renci.SshNet.Tests.Classes
         protected Session Session { get; private set; }
         protected Socket ClientSocket { get; private set; }
         protected Socket ServerSocket { get; private set; }
-        internal SshIdentification ServerIdentification { get; private set; }
+        internal SshIdentification ServerIdentification { get; set; }
+        protected bool CallSessionConnectWhenArrange { get; set; }
 
         [TestInitialize]
         public void Setup()
@@ -159,6 +160,8 @@ namespace Renci.SshNet.Tests.Classes
             ServerListener.Start();
 
             ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
+
+            CallSessionConnectWhenArrange = true;
         }
 
         private void CreateMocks()
@@ -180,7 +183,7 @@ namespace Renci.SshNet.Tests.Classes
             _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
                                   .Returns(_protocolVersionExchangeMock.Object);
             _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
-                                            .Returns(ServerIdentification);
+                                            .Returns(() => ServerIdentification);
             _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
             _ = _keyExchangeMock.Setup(p => p.Name)
                                 .Returns(_keyExchangeAlgorithm);
@@ -212,7 +215,10 @@ namespace Renci.SshNet.Tests.Classes
             SetupData();
             SetupMocks();
 
-            Session.Connect();
+            if (CallSessionConnectWhenArrange)
+            {
+                Session.Connect();
+            }
         }
 
         protected virtual void ClientAuthentication_Callback()

+ 65 - 0
test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs

@@ -0,0 +1,65 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Connection;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase
+    {
+        protected override void SetupData()
+        {
+            base.SetupData();
+
+            CallSessionConnectWhenArrange = false;
+
+            Session.ServerIdentificationReceived += (s, e) =>
+            {
+                if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal))
+                       && !e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6.1", System.StringComparison.Ordinal))
+                {
+                    _ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256");
+                    _ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256@libssh.org");
+                }
+            };
+        }
+
+        protected override void Act()
+        {
+        }
+
+        [TestMethod]
+        [DataRow("OpenSSH_6.5")]
+        [DataRow("OpenSSH_6.5p1")]
+        [DataRow("OpenSSH_6.5 PKIX")]
+        [DataRow("OpenSSH_6.6")]
+        [DataRow("OpenSSH_6.6p1")]
+        [DataRow("OpenSSH_6.6 PKIX")]
+        public void ShouldExcludeCurve25519KexWhenServerIs(string softwareVersion)
+        {
+            ServerIdentification = new SshIdentification("2.0", softwareVersion);
+
+            Session.Connect();
+
+            Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256"));
+            Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256@libssh.org"));
+        }
+
+        [TestMethod]
+        [DataRow("OpenSSH_6.6.1")]
+        [DataRow("OpenSSH_6.6.1p1")]
+        [DataRow("OpenSSH_6.6.1 PKIX")]
+        [DataRow("OpenSSH_6.7")]
+        [DataRow("OpenSSH_6.7p1")]
+        [DataRow("OpenSSH_6.7 PKIX")]
+        public void ShouldIncludeCurve25519KexWhenServerIs(string softwareVersion)
+        {
+            ServerIdentification = new SshIdentification("2.0", softwareVersion);
+
+            Session.Connect();
+
+            Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256"));
+            Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256@libssh.org"));
+        }
+    }
+}