Browse Source

Tweak semaphore usage in Session (#1304)

- Change _connectAndLazySemaphoreInitLock to a SemaphoreSlim and use it in
  ConnectAsync.
- Rename it to _connectLock and only use it for connecting. Replace its
  other usages (on SessionSemaphore and NextChannelNumber) with Interlocked
  operations.
- Remove AuthenticationConnection semaphore. This static member placed a
  process-wide limit on the number of connections an application can make.
  I agree with the argument in
  https://github.com/sshnet/SSH.NET/issues/409#issuecomment-457415542
  (and in several other issues/PRs) that this should not be something
  that the library attempts to control.

The last change broke a few tests which do things like making 100 connections.
I was tempted to delete these tests as I don't think they have much value, but
instead I just limited their concurrency.

Co-authored-by: Wojciech Nagórski <wojtpl2@gmail.com>
Rob Hague 1 year ago
parent
commit
47eabe7574

+ 178 - 184
src/Renci.SshNet/Session.cs

@@ -1,4 +1,5 @@
 using System;
+using System.Diagnostics;
 using System.Globalization;
 using System.Linq;
 using System.Net.Sockets;
@@ -81,14 +82,6 @@ namespace Renci.SshNet
         /// </remarks>
         internal static readonly TimeSpan InfiniteTimeSpan = new TimeSpan(0, 0, 0, 0, -1);
 
-        /// <summary>
-        /// Controls how many authentication attempts can take place at the same time.
-        /// </summary>
-        /// <remarks>
-        /// Some server may restrict number to prevent authentication attacks.
-        /// </remarks>
-        private static readonly SemaphoreSlim AuthenticationConnection = new SemaphoreSlim(3);
-
         /// <summary>
         /// Holds the factory to use for creating new services.
         /// </summary>
@@ -123,9 +116,9 @@ namespace Renci.SshNet
 
         /// <summary>
         /// Holds an object that is used to ensure only a single thread can connect
-        /// and lazy initialize the <see cref="SessionSemaphore"/> at any given time.
+        /// at any given time.
         /// </summary>
-        private readonly object _connectAndLazySemaphoreInitLock = new object();
+        private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);
 
         /// <summary>
         /// Holds metadata about session messages.
@@ -195,7 +188,7 @@ namespace Renci.SshNet
 
         private bool _isDisconnectMessageSent;
 
-        private uint _nextChannelNumber;
+        private int _nextChannelNumber;
 
         /// <summary>
         /// Holds connection socket.
@@ -212,12 +205,18 @@ namespace Renci.SshNet
         {
             get
             {
-                if (_sessionSemaphore is null)
+                if (_sessionSemaphore is SemaphoreSlim sessionSemaphore)
                 {
-                    lock (_connectAndLazySemaphoreInitLock)
-                    {
-                        _sessionSemaphore ??= new SemaphoreSlim(ConnectionInfo.MaxSessions);
-                    }
+                    return sessionSemaphore;
+                }
+
+                sessionSemaphore = new SemaphoreSlim(ConnectionInfo.MaxSessions);
+
+                if (Interlocked.CompareExchange(ref _sessionSemaphore, sessionSemaphore, comparand: null) is not null)
+                {
+                    // Another thread has set _sessionSemaphore. Dispose our one.
+                    Debug.Assert(_sessionSemaphore != sessionSemaphore);
+                    sessionSemaphore.Dispose();
                 }
 
                 return _sessionSemaphore;
@@ -234,14 +233,7 @@ namespace Renci.SshNet
         {
             get
             {
-                uint result;
-
-                lock (_connectAndLazySemaphoreInitLock)
-                {
-                    result = _nextChannelNumber++;
-                }
-
-                return result;
+                return (uint)Interlocked.Increment(ref _nextChannelNumber);
             }
         }
 
@@ -583,128 +575,116 @@ namespace Renci.SshNet
                 return;
             }
 
+            _connectLock.Wait();
+
             try
             {
-                AuthenticationConnection.Wait();
-
                 if (IsConnected)
                 {
                     return;
                 }
 
-                lock (_connectAndLazySemaphoreInitLock)
-                {
-                    // If connected don't connect again
-                    if (IsConnected)
-                    {
-                        return;
-                    }
-
-                    // Reset connection specific information
-                    Reset();
+                // Reset connection specific information
+                Reset();
 
-                    // Build list of available messages while connecting
-                    _sshMessageFactory = new SshMessageFactory();
+                // Build list of available messages while connecting
+                _sshMessageFactory = new SshMessageFactory();
 
-                    _socket = _serviceFactory.CreateConnector(ConnectionInfo, _socketFactory)
-                                             .Connect(ConnectionInfo);
+                _socket = _serviceFactory.CreateConnector(ConnectionInfo, _socketFactory)
+                                            .Connect(ConnectionInfo);
 
-                    var serverIdentification = _serviceFactory.CreateProtocolVersionExchange()
-                                                              .Start(ClientVersion, _socket, ConnectionInfo.Timeout);
+                var serverIdentification = _serviceFactory.CreateProtocolVersionExchange()
+                                                            .Start(ClientVersion, _socket, ConnectionInfo.Timeout);
 
-                    // Set connection versions
-                    ServerVersion = ConnectionInfo.ServerVersion = serverIdentification.ToString();
-                    ConnectionInfo.ClientVersion = ClientVersion;
+                // Set connection versions
+                ServerVersion = ConnectionInfo.ServerVersion = serverIdentification.ToString();
+                ConnectionInfo.ClientVersion = ClientVersion;
 
-                    DiagnosticAbstraction.Log(string.Format("Server version '{0}'.", serverIdentification));
+                DiagnosticAbstraction.Log(string.Format("Server version '{0}'.", serverIdentification));
 
-                    if (!(serverIdentification.ProtocolVersion.Equals("2.0") || serverIdentification.ProtocolVersion.Equals("1.99")))
-                    {
-                        throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Server version '{0}' is not supported.", serverIdentification.ProtocolVersion),
-                                                         DisconnectReason.ProtocolVersionNotSupported);
-                    }
+                if (!(serverIdentification.ProtocolVersion.Equals("2.0") || serverIdentification.ProtocolVersion.Equals("1.99")))
+                {
+                    throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Server version '{0}' is not supported.", serverIdentification.ProtocolVersion),
+                                                        DisconnectReason.ProtocolVersionNotSupported);
+                }
 
-                    ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
+                ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
 
-                    // Register Transport response messages
-                    RegisterMessage("SSH_MSG_DISCONNECT");
-                    RegisterMessage("SSH_MSG_IGNORE");
-                    RegisterMessage("SSH_MSG_UNIMPLEMENTED");
-                    RegisterMessage("SSH_MSG_DEBUG");
-                    RegisterMessage("SSH_MSG_SERVICE_ACCEPT");
-                    RegisterMessage("SSH_MSG_KEXINIT");
-                    RegisterMessage("SSH_MSG_NEWKEYS");
+                // Register Transport response messages
+                RegisterMessage("SSH_MSG_DISCONNECT");
+                RegisterMessage("SSH_MSG_IGNORE");
+                RegisterMessage("SSH_MSG_UNIMPLEMENTED");
+                RegisterMessage("SSH_MSG_DEBUG");
+                RegisterMessage("SSH_MSG_SERVICE_ACCEPT");
+                RegisterMessage("SSH_MSG_KEXINIT");
+                RegisterMessage("SSH_MSG_NEWKEYS");
 
-                    // Some server implementations might sent this message first, prior to establishing encryption algorithm
-                    RegisterMessage("SSH_MSG_USERAUTH_BANNER");
+                // Some server implementations might sent this message first, prior to establishing encryption algorithm
+                RegisterMessage("SSH_MSG_USERAUTH_BANNER");
 
-                    // Send our key exchange init.
-                    // We need to do this before starting the message listener to avoid the case where we receive the server
-                    // key exchange init and we continue the key exchange before having sent our own init.
-                    SendMessage(ClientInitMessage);
+                // Send our key exchange init.
+                // We need to do this before starting the message listener to avoid the case where we receive the server
+                // key exchange init and we continue the key exchange before having sent our own init.
+                SendMessage(ClientInitMessage);
 
-                    // Mark the message listener threads as started
-                    _ = _messageListenerCompleted.Reset();
+                // Mark the message listener threads as started
+                _ = _messageListenerCompleted.Reset();
 
-                    // Start incoming request listener
-                    // ToDo: Make message pump async, to not consume a thread for every session
-                    _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
+                // Start incoming request listener
+                // ToDo: Make message pump async, to not consume a thread for every session
+                _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
 
-                    // Wait for key exchange to be completed
-                    WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
+                // Wait for key exchange to be completed
+                WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
 
-                    // If sessionId is not set then its not connected
-                    if (SessionId is null)
-                    {
-                        Disconnect();
-                        return;
-                    }
+                // If sessionId is not set then its not connected
+                if (SessionId is null)
+                {
+                    Disconnect();
+                    return;
+                }
 
-                    // Request user authorization service
-                    SendMessage(new ServiceRequestMessage(ServiceName.UserAuthentication));
+                // Request user authorization service
+                SendMessage(new ServiceRequestMessage(ServiceName.UserAuthentication));
 
-                    // Wait for service to be accepted
-                    WaitOnHandle(_serviceAccepted);
+                // Wait for service to be accepted
+                WaitOnHandle(_serviceAccepted);
 
-                    if (string.IsNullOrEmpty(ConnectionInfo.Username))
-                    {
-                        throw new SshException("Username is not specified.");
-                    }
-
-                    // Some servers send a global request immediately after successful authentication
-                    // Avoid race condition by already enabling SSH_MSG_GLOBAL_REQUEST before authentication
-                    RegisterMessage("SSH_MSG_GLOBAL_REQUEST");
-
-                    ConnectionInfo.Authenticate(this, _serviceFactory);
-                    _isAuthenticated = true;
-
-                    // Register Connection messages
-                    RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
-                    RegisterMessage("SSH_MSG_REQUEST_FAILURE");
-                    RegisterMessage("SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
-                    RegisterMessage("SSH_MSG_CHANNEL_OPEN_FAILURE");
-                    RegisterMessage("SSH_MSG_CHANNEL_WINDOW_ADJUST");
-                    RegisterMessage("SSH_MSG_CHANNEL_EXTENDED_DATA");
-                    RegisterMessage("SSH_MSG_CHANNEL_REQUEST");
-                    RegisterMessage("SSH_MSG_CHANNEL_SUCCESS");
-                    RegisterMessage("SSH_MSG_CHANNEL_FAILURE");
-                    RegisterMessage("SSH_MSG_CHANNEL_DATA");
-                    RegisterMessage("SSH_MSG_CHANNEL_EOF");
-                    RegisterMessage("SSH_MSG_CHANNEL_CLOSE");
+                if (string.IsNullOrEmpty(ConnectionInfo.Username))
+                {
+                    throw new SshException("Username is not specified.");
                 }
+
+                // Some servers send a global request immediately after successful authentication
+                // Avoid race condition by already enabling SSH_MSG_GLOBAL_REQUEST before authentication
+                RegisterMessage("SSH_MSG_GLOBAL_REQUEST");
+
+                ConnectionInfo.Authenticate(this, _serviceFactory);
+                _isAuthenticated = true;
+
+                // Register Connection messages
+                RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
+                RegisterMessage("SSH_MSG_REQUEST_FAILURE");
+                RegisterMessage("SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
+                RegisterMessage("SSH_MSG_CHANNEL_OPEN_FAILURE");
+                RegisterMessage("SSH_MSG_CHANNEL_WINDOW_ADJUST");
+                RegisterMessage("SSH_MSG_CHANNEL_EXTENDED_DATA");
+                RegisterMessage("SSH_MSG_CHANNEL_REQUEST");
+                RegisterMessage("SSH_MSG_CHANNEL_SUCCESS");
+                RegisterMessage("SSH_MSG_CHANNEL_FAILURE");
+                RegisterMessage("SSH_MSG_CHANNEL_DATA");
+                RegisterMessage("SSH_MSG_CHANNEL_EOF");
+                RegisterMessage("SSH_MSG_CHANNEL_CLOSE");
             }
             finally
             {
-                _ = AuthenticationConnection.Release();
+                _ = _connectLock.Release();
             }
         }
 
         /// <summary>
         /// Asynchronously connects to the server.
         /// </summary>
-        /// <remarks>
-        /// Please note this function is NOT thread safe.<br/>
-        /// The caller SHOULD limit the number of simultaneous connection attempts to a server to a single connection attempt.</remarks>
         /// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
         /// <returns>A <see cref="Task"/> that represents the asynchronous connect operation.</returns>
         /// <exception cref="SocketException">Socket connection to the SSH server or proxy server could not be established, or an error occurred while resolving the hostname.</exception>
@@ -719,97 +699,111 @@ namespace Renci.SshNet
                 return;
             }
 
-            // Reset connection specific information
-            Reset();
+            await _connectLock.WaitAsync(cancellationToken).ConfigureAwait(false);
 
-            // Build list of available messages while connecting
-            _sshMessageFactory = new SshMessageFactory();
+            try
+            {
+                if (IsConnected)
+                {
+                    return;
+                }
 
-            _socket = await _serviceFactory.CreateConnector(ConnectionInfo, _socketFactory)
-                                        .ConnectAsync(ConnectionInfo, cancellationToken).ConfigureAwait(false);
+                // Reset connection specific information
+                Reset();
 
-            var serverIdentification = await _serviceFactory.CreateProtocolVersionExchange()
-                                                        .StartAsync(ClientVersion, _socket, cancellationToken).ConfigureAwait(false);
+                // Build list of available messages while connecting
+                _sshMessageFactory = new SshMessageFactory();
 
-            // Set connection versions
-            ServerVersion = ConnectionInfo.ServerVersion = serverIdentification.ToString();
-            ConnectionInfo.ClientVersion = ClientVersion;
+                _socket = await _serviceFactory.CreateConnector(ConnectionInfo, _socketFactory)
+                                            .ConnectAsync(ConnectionInfo, cancellationToken).ConfigureAwait(false);
 
-            DiagnosticAbstraction.Log(string.Format("Server version '{0}'.", serverIdentification));
+                var serverIdentification = await _serviceFactory.CreateProtocolVersionExchange()
+                                                            .StartAsync(ClientVersion, _socket, cancellationToken).ConfigureAwait(false);
 
-            if (!(serverIdentification.ProtocolVersion.Equals("2.0") || serverIdentification.ProtocolVersion.Equals("1.99")))
-            {
-                throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Server version '{0}' is not supported.", serverIdentification.ProtocolVersion),
-                                                    DisconnectReason.ProtocolVersionNotSupported);
-            }
+                // Set connection versions
+                ServerVersion = ConnectionInfo.ServerVersion = serverIdentification.ToString();
+                ConnectionInfo.ClientVersion = ClientVersion;
 
-            ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
+                DiagnosticAbstraction.Log(string.Format("Server version '{0}'.", serverIdentification));
 
-            // Register Transport response messages
-            RegisterMessage("SSH_MSG_DISCONNECT");
-            RegisterMessage("SSH_MSG_IGNORE");
-            RegisterMessage("SSH_MSG_UNIMPLEMENTED");
-            RegisterMessage("SSH_MSG_DEBUG");
-            RegisterMessage("SSH_MSG_SERVICE_ACCEPT");
-            RegisterMessage("SSH_MSG_KEXINIT");
-            RegisterMessage("SSH_MSG_NEWKEYS");
+                if (!(serverIdentification.ProtocolVersion.Equals("2.0") || serverIdentification.ProtocolVersion.Equals("1.99")))
+                {
+                    throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Server version '{0}' is not supported.", serverIdentification.ProtocolVersion),
+                                                        DisconnectReason.ProtocolVersionNotSupported);
+                }
 
-            // Some server implementations might sent this message first, prior to establishing encryption algorithm
-            RegisterMessage("SSH_MSG_USERAUTH_BANNER");
+                ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
 
-            // Send our key exchange init.
-            // We need to do this before starting the message listener to avoid the case where we receive the server
-            // key exchange init and we continue the key exchange before having sent our own init.
-            SendMessage(ClientInitMessage);
+                // Register Transport response messages
+                RegisterMessage("SSH_MSG_DISCONNECT");
+                RegisterMessage("SSH_MSG_IGNORE");
+                RegisterMessage("SSH_MSG_UNIMPLEMENTED");
+                RegisterMessage("SSH_MSG_DEBUG");
+                RegisterMessage("SSH_MSG_SERVICE_ACCEPT");
+                RegisterMessage("SSH_MSG_KEXINIT");
+                RegisterMessage("SSH_MSG_NEWKEYS");
 
-            // Mark the message listener threads as started
-            _ = _messageListenerCompleted.Reset();
+                // Some server implementations might sent this message first, prior to establishing encryption algorithm
+                RegisterMessage("SSH_MSG_USERAUTH_BANNER");
 
-            // Start incoming request listener
-            // ToDo: Make message pump async, to not consume a thread for every session
-            _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
+                // Send our key exchange init.
+                // We need to do this before starting the message listener to avoid the case where we receive the server
+                // key exchange init and we continue the key exchange before having sent our own init.
+                SendMessage(ClientInitMessage);
 
-            // Wait for key exchange to be completed
-            WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
+                // Mark the message listener threads as started
+                _ = _messageListenerCompleted.Reset();
 
-            // If sessionId is not set then its not connected
-            if (SessionId is null)
-            {
-                Disconnect();
-                return;
-            }
+                // Start incoming request listener
+                // ToDo: Make message pump async, to not consume a thread for every session
+                _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
+
+                // Wait for key exchange to be completed
+                WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
+
+                // If sessionId is not set then its not connected
+                if (SessionId is null)
+                {
+                    Disconnect();
+                    return;
+                }
 
-            // Request user authorization service
-            SendMessage(new ServiceRequestMessage(ServiceName.UserAuthentication));
+                // Request user authorization service
+                SendMessage(new ServiceRequestMessage(ServiceName.UserAuthentication));
 
-            // Wait for service to be accepted
-            WaitOnHandle(_serviceAccepted);
+                // Wait for service to be accepted
+                WaitOnHandle(_serviceAccepted);
+
+                if (string.IsNullOrEmpty(ConnectionInfo.Username))
+                {
+                    throw new SshException("Username is not specified.");
+                }
 
-            if (string.IsNullOrEmpty(ConnectionInfo.Username))
+                // Some servers send a global request immediately after successful authentication
+                // Avoid race condition by already enabling SSH_MSG_GLOBAL_REQUEST before authentication
+                RegisterMessage("SSH_MSG_GLOBAL_REQUEST");
+
+                ConnectionInfo.Authenticate(this, _serviceFactory);
+                _isAuthenticated = true;
+
+                // Register Connection messages
+                RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
+                RegisterMessage("SSH_MSG_REQUEST_FAILURE");
+                RegisterMessage("SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
+                RegisterMessage("SSH_MSG_CHANNEL_OPEN_FAILURE");
+                RegisterMessage("SSH_MSG_CHANNEL_WINDOW_ADJUST");
+                RegisterMessage("SSH_MSG_CHANNEL_EXTENDED_DATA");
+                RegisterMessage("SSH_MSG_CHANNEL_REQUEST");
+                RegisterMessage("SSH_MSG_CHANNEL_SUCCESS");
+                RegisterMessage("SSH_MSG_CHANNEL_FAILURE");
+                RegisterMessage("SSH_MSG_CHANNEL_DATA");
+                RegisterMessage("SSH_MSG_CHANNEL_EOF");
+                RegisterMessage("SSH_MSG_CHANNEL_CLOSE");
+            }
+            finally
             {
-                throw new SshException("Username is not specified.");
+                _ = _connectLock.Release();
             }
-
-            // Some servers send a global request immediately after successful authentication
-            // Avoid race condition by already enabling SSH_MSG_GLOBAL_REQUEST before authentication
-            RegisterMessage("SSH_MSG_GLOBAL_REQUEST");
-
-            ConnectionInfo.Authenticate(this, _serviceFactory);
-            _isAuthenticated = true;
-
-            // Register Connection messages
-            RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
-            RegisterMessage("SSH_MSG_REQUEST_FAILURE");
-            RegisterMessage("SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
-            RegisterMessage("SSH_MSG_CHANNEL_OPEN_FAILURE");
-            RegisterMessage("SSH_MSG_CHANNEL_WINDOW_ADJUST");
-            RegisterMessage("SSH_MSG_CHANNEL_EXTENDED_DATA");
-            RegisterMessage("SSH_MSG_CHANNEL_REQUEST");
-            RegisterMessage("SSH_MSG_CHANNEL_SUCCESS");
-            RegisterMessage("SSH_MSG_CHANNEL_FAILURE");
-            RegisterMessage("SSH_MSG_CHANNEL_DATA");
-            RegisterMessage("SSH_MSG_CHANNEL_EOF");
-            RegisterMessage("SSH_MSG_CHANNEL_CLOSE");
         }
 
         /// <summary>

+ 6 - 35
test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs

@@ -442,47 +442,18 @@ namespace Renci.SshNet.IntegrationTests.OldIntegrationTests
             }
         }
 
-        [TestMethod]
-        public void Test_MultipleThread_Example_MultipleConnections()
-        {
-            try
-            {
-#region Example SshCommand RunCommand Parallel
-                Parallel.For(0, 100,
-                    () =>
-                    {
-                        var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
-                        client.Connect();
-                        return client;
-                    },
-                    (int counter, ParallelLoopState pls, SshClient client) =>
-                    {
-                        var result = client.RunCommand("echo 123");
-                        Debug.WriteLine(string.Format("TestMultipleThreadMultipleConnections #{0}", counter));
-                        return client;
-                    },
-                    (SshClient client) =>
-                    {
-                        client.Disconnect();
-                        client.Dispose();
-                    }
-                );
-#endregion
-
-            }
-            catch (Exception exp)
-            {
-                Assert.Fail(exp.ToString());
-            }
-        }
-
         [TestMethod]
         
         public void Test_MultipleThread_100_MultipleConnections()
         {
             try
             {
-                Parallel.For(0, 100,
+                var options = new ParallelOptions()
+                {
+                    MaxDegreeOfParallelism = 8
+                };
+
+                Parallel.For(0, 100, options,
                     () =>
                     {
                         var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);

+ 1 - 1
test/Renci.SshNet.IntegrationTests/SftpTests.cs

@@ -83,7 +83,7 @@ namespace Renci.SshNet.IntegrationTests
         public void Sftp_ConnectDisconnect_Parallel()
         {
             const int iterations = 10;
-            const int threads = 20;
+            const int threads = 5;
 
             var startEvent = new ManualResetEvent(false);