Просмотр исходного кода

Plumbing for more async/await work (#1281)

* Changes _socketDisposeLock to a SemaphoreSlim so it can play nice with async/await

* Adds a SendAsync method for .NET6+

* Fix false positive analyzer error

* Formatting

---------

Co-authored-by: Robert Hague <rh@johnstreetcapital.com>
Co-authored-by: Wojciech Nagórski <wojtpl2@gmail.com>
Co-authored-by: Rob Hague <rob.hague00@gmail.com>
Jacob Slusser 1 год назад
Родитель
Сommit
48a8fe6502

+ 50 - 0
src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs

@@ -0,0 +1,50 @@
+#if NET6_0_OR_GREATER
+
+using System;
+using System.Diagnostics;
+using System.Net.Sockets;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Renci.SshNet.Abstractions
+{
+    internal static partial class SocketAbstraction
+    {
+        public static ValueTask<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
+        {
+            return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken);
+        }
+
+        public static ValueTask SendAsync(Socket socket, ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default)
+        {
+            Debug.Assert(socket != null);
+            Debug.Assert(data.Length > 0);
+
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return ValueTask.FromCanceled(cancellationToken);
+            }
+
+            return SendAsyncCore(socket, data, cancellationToken);
+
+            static async ValueTask SendAsyncCore(Socket socket, ReadOnlyMemory<byte> data, CancellationToken cancellationToken)
+            {
+                do
+                {
+                    try
+                    {
+                        var bytesSent = await socket.SendAsync(data, SocketFlags.None, cancellationToken).ConfigureAwait(false);
+                        data = data.Slice(bytesSent);
+                    }
+                    catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
+                    {
+                        // Buffer may be full; attempt a short delay and retry
+                        await Task.Delay(30, cancellationToken).ConfigureAwait(false);
+                    }
+                }
+                while (data.Length > 0);
+            }
+        }
+    }
+}
+#endif // NET6_0_OR_GREATER

+ 2 - 7
src/Renci.SshNet/Abstractions/SocketAbstraction.cs

@@ -10,7 +10,7 @@ using Renci.SshNet.Messages.Transport;
 
 namespace Renci.SshNet.Abstractions
 {
-    internal static class SocketAbstraction
+    internal static partial class SocketAbstraction
     {
         public static bool CanRead(Socket socket)
         {
@@ -311,12 +311,7 @@ namespace Renci.SshNet.Abstractions
             return totalBytesRead;
         }
 
-#if NET6_0_OR_GREATER
-        public static async Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
-        {
-            return await socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false);
-        }
-#else
+#if NET6_0_OR_GREATER == false
         public static Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
         {
             return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken);

+ 4 - 0
src/Renci.SshNet/Connection/ProtocolVersionExchange.cs

@@ -81,7 +81,11 @@ namespace Renci.SshNet.Connection
         {
             // Immediately send the identification string since the spec states both sides MUST send an identification string
             // when the connection has been established
+#if NET6_0_OR_GREATER
+            await SocketAbstraction.SendAsync(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), cancellationToken).ConfigureAwait(false);
+#else
             SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"));
+#endif // NET6_0_OR_GREATER
 
             var bytesReceived = new List<byte>();
 

+ 25 - 7
src/Renci.SshNet/Session.cs

@@ -119,7 +119,7 @@ namespace Renci.SshNet
         /// This is also used to ensure that <see cref="_socket"/> will not be disposed
         /// while performing a given operation or set of operations on <see cref="_socket"/>.
         /// </remarks>
-        private readonly object _socketDisposeLock = new object();
+        private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1);
 
         /// <summary>
         /// Holds an object that is used to ensure only a single thread can connect
@@ -1127,12 +1127,14 @@ namespace Renci.SshNet
         /// </para>
         /// <para>
         /// This method is only to be used when the connection is established, as the locking
-        /// overhead is not required while establising the connection.
+        /// overhead is not required while establishing the connection.
         /// </para>
         /// </remarks>
         private void SendPacket(byte[] packet, int offset, int length)
         {
-            lock (_socketDisposeLock)
+            _socketDisposeLock.Wait();
+
+            try
             {
                 if (!_socket.IsConnected())
                 {
@@ -1141,6 +1143,10 @@ namespace Renci.SshNet
 
                 SocketAbstraction.Send(_socket, packet, offset, length);
             }
+            finally
+            {
+                _ = _socketDisposeLock.Release();
+            }
         }
 
         /// <summary>
@@ -1798,8 +1804,9 @@ namespace Renci.SshNet
         /// </remarks>
         private bool IsSocketConnected()
         {
-#pragma warning disable S2222 // Locks should be released on all paths
-            lock (_socketDisposeLock)
+            _socketDisposeLock.Wait();
+
+            try
             {
                 if (!_socket.IsConnected())
                 {
@@ -1821,7 +1828,10 @@ namespace Renci.SshNet
                     Monitor.Exit(_socketReadLock);
                 }
             }
-#pragma warning restore S2222 // Locks should be released on all paths
+            finally
+            {
+                _ = _socketDisposeLock.Release();
+            }
         }
 
         /// <summary>
@@ -1848,9 +1858,13 @@ namespace Renci.SshNet
         {
             if (_socket != null)
             {
-                lock (_socketDisposeLock)
+                _socketDisposeLock.Wait();
+
+                try
                 {
+#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread.
                     if (_socket != null)
+#pragma warning restore CA1508 // Avoid dead conditional code
                     {
                         if (_socket.Connected)
                         {
@@ -1879,6 +1893,10 @@ namespace Renci.SshNet
                         _socket = null;
                     }
                 }
+                finally
+                {
+                    _ = _socketDisposeLock.Release();
+                }
             }
         }