Ver código fonte

Adapt InternalUploadFile for async (#1653)

The recently added UploadFileAsync effectively calls stream.CopyToAsync(SftpFileStream).
This is slower than the sync UploadFile (by about 4x in a local test) because the sync
version sends multiple write requests concurrently, without waiting for each response
in turn like the stream-based version does.

This change adapts the sync code for async and uses it to bring the performance of
UploadFileAsync in line with that of UploadFile.
Rob Hague 2 meses atrás
pai
commit
b7c5f1a87d

+ 3 - 0
src/Renci.SshNet/.editorconfig

@@ -194,3 +194,6 @@ dotnet_diagnostic.MA0042.severity = none
 
 # S3236: Caller information arguments should not be provided explicitly
 dotnet_diagnostic.S3236.severity = none
+
+# S3358: Ternary operators should not be nested
+dotnet_diagnostic.S3358.severity = none

+ 28 - 2
src/Renci.SshNet/ISubsystemSession.cs

@@ -1,5 +1,6 @@
 using System;
 using System.Threading;
+using System.Threading.Tasks;
 
 using Microsoft.Extensions.Logging;
 
@@ -49,15 +50,40 @@ namespace Renci.SshNet
         void Disconnect();
 
         /// <summary>
-        /// Waits a specified time for a given <see cref="WaitHandle"/> to get signaled.
+        /// Waits a specified time for a given <see cref="WaitHandle"/> to be signaled.
         /// </summary>
         /// <param name="waitHandle">The handle to wait for.</param>
-        /// <param name="millisecondsTimeout">The number of milliseconds wait for <paramref name="waitHandle"/> to get signaled, or <c>-1</c> to wait indefinitely.</param>
+        /// <param name="millisecondsTimeout">The number of milliseconds to wait for <paramref name="waitHandle"/> to be signaled, or <c>-1</c> to wait indefinitely.</param>
         /// <exception cref="SshException">The connection was closed by the server.</exception>
         /// <exception cref="SshException">The channel was closed.</exception>
         /// <exception cref="SshOperationTimeoutException">The handle did not get signaled within the specified timeout.</exception>
         void WaitOnHandle(WaitHandle waitHandle, int millisecondsTimeout);
 
+        /// <summary>
+        /// Asynchronously waits for a given <see cref="WaitHandle"/> to be signaled.
+        /// </summary>
+        /// <param name="waitHandle">The handle to wait for.</param>
+        /// <param name="millisecondsTimeout">The number of milliseconds to wait for <paramref name="waitHandle"/> to be signaled, or <c>-1</c> to wait indefinitely.</param>
+        /// <param name="cancellationToken">The cancellation token to observe.</param>
+        /// <exception cref="SshException">The connection was closed by the server.</exception>
+        /// <exception cref="SshException">The channel was closed.</exception>
+        /// <exception cref="SshOperationTimeoutException">The handle did not get signaled within the specified timeout.</exception>
+        /// <returns>A <see cref="Task"/> representing the wait.</returns>
+        Task WaitOnHandleAsync(WaitHandle waitHandle, int millisecondsTimeout, CancellationToken cancellationToken);
+
+        /// <summary>
+        /// Asynchronously waits for a given <see cref="TaskCompletionSource{T}"/> to complete.
+        /// </summary>
+        /// <typeparam name="T">The type of the result which is being awaited.</typeparam>
+        /// <param name="tcs">The handle to wait for.</param>
+        /// <param name="millisecondsTimeout">The number of milliseconds to wait for <paramref name="tcs"/> to complete, or <c>-1</c> to wait indefinitely.</param>
+        /// <param name="cancellationToken">The cancellation token to observe.</param>
+        /// <exception cref="SshException">The connection was closed by the server.</exception>
+        /// <exception cref="SshException">The channel was closed.</exception>
+        /// <exception cref="SshOperationTimeoutException">The handle did not get signaled within the specified timeout.</exception>
+        /// <returns>A <see cref="Task"/> representing the wait.</returns>
+        Task<T> WaitOnHandleAsync<T>(TaskCompletionSource<T> tcs, int millisecondsTimeout, CancellationToken cancellationToken);
+
         /// <summary>
         /// Blocks the current thread until the specified <see cref="WaitHandle"/> gets signaled, using a
         /// 32-bit signed integer to specify the time interval in milliseconds.

+ 91 - 47
src/Renci.SshNet/SftpClient.cs

@@ -1027,6 +1027,8 @@ namespace Renci.SshNet
         /// <inheritdoc/>
         public void UploadFile(Stream input, string path, bool canOverride, Action<ulong>? uploadCallback = null)
         {
+            ThrowHelper.ThrowIfNull(input);
+            ThrowHelper.ThrowIfNullOrWhiteSpace(path);
             CheckDisposed();
 
             var flags = Flags.Write | Flags.Truncate;
@@ -1040,15 +1042,31 @@ namespace Renci.SshNet
                 flags |= Flags.CreateNew;
             }
 
-            InternalUploadFile(input, path, flags, asyncResult: null, uploadCallback);
+            InternalUploadFile(
+                input,
+                path,
+                flags,
+                asyncResult: null,
+                uploadCallback,
+                isAsync: false,
+                default).GetAwaiter().GetResult();
         }
 
         /// <inheritdoc />
         public Task UploadFileAsync(Stream input, string path, CancellationToken cancellationToken = default)
         {
+            ThrowHelper.ThrowIfNull(input);
+            ThrowHelper.ThrowIfNullOrWhiteSpace(path);
             CheckDisposed();
 
-            return InternalUploadFileAsync(input, path, cancellationToken);
+            return InternalUploadFile(
+                input,
+                path,
+                Flags.Write | Flags.Truncate | Flags.CreateNewOrOpen,
+                asyncResult: null,
+                uploadCallback: null,
+                isAsync: true,
+                cancellationToken);
         }
 
         /// <summary>
@@ -1163,9 +1181,9 @@ namespace Renci.SshNet
         /// </remarks>
         public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride, AsyncCallback? asyncCallback, object? state, Action<ulong>? uploadCallback = null)
         {
-            CheckDisposed();
             ThrowHelper.ThrowIfNull(input);
             ThrowHelper.ThrowIfNullOrWhiteSpace(path);
+            CheckDisposed();
 
             var flags = Flags.Write | Flags.Truncate;
 
@@ -1180,19 +1198,28 @@ namespace Renci.SshNet
 
             var asyncResult = new SftpUploadAsyncResult(asyncCallback, state);
 
-            ThreadAbstraction.ExecuteThread(() =>
+            _ = DoUploadAndSetResult();
+
+            async Task DoUploadAndSetResult()
             {
                 try
                 {
-                    InternalUploadFile(input, path, flags, asyncResult, uploadCallback);
+                    await InternalUploadFile(
+                        input,
+                        path,
+                        flags,
+                        asyncResult,
+                        uploadCallback,
+                        isAsync: true,
+                        CancellationToken.None).ConfigureAwait(false);
 
                     asyncResult.SetAsCompleted(exception: null, completedSynchronously: false);
                 }
                 catch (Exception exp)
                 {
-                    asyncResult.SetAsCompleted(exception: exp, completedSynchronously: false);
+                    asyncResult.SetAsCompleted(exp, completedSynchronously: false);
                 }
-            });
+            }
 
             return asyncResult;
         }
@@ -2284,11 +2311,16 @@ namespace Renci.SshNet
                         var remoteFileName = string.Format(CultureInfo.InvariantCulture, @"{0}/{1}", destinationPath, localFile.Name);
                         try
                         {
-#pragma warning disable CA2000 // Dispose objects before losing scope; false positive
                             using (var file = File.OpenRead(localFile.FullName))
-#pragma warning restore CA2000 // Dispose objects before losing scope; false positive
                             {
-                                InternalUploadFile(file, remoteFileName, uploadFlag, asyncResult: null, uploadCallback: null);
+                                InternalUploadFile(
+                                    file,
+                                    remoteFileName,
+                                    uploadFlag,
+                                    asyncResult: null,
+                                    uploadCallback: null,
+                                    isAsync: false,
+                                    CancellationToken.None).GetAwaiter().GetResult();
                             }
 
                             uploadedFiles.Add(localFile);
@@ -2455,37 +2487,42 @@ namespace Renci.SshNet
             }
         }
 
-        /// <summary>
-        /// Internals the upload file.
-        /// </summary>
-        /// <param name="input">The input.</param>
-        /// <param name="path">The path.</param>
-        /// <param name="flags">The flags.</param>
-        /// <param name="asyncResult">An <see cref="IAsyncResult"/> that references the asynchronous request.</param>
-        /// <param name="uploadCallback">The upload callback.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="input" /> is <see langword="null"/>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="path" /> is <see langword="null"/> or contains whitespace.</exception>
-        /// <exception cref="SshConnectionException">Client not connected.</exception>
-        private void InternalUploadFile(Stream input, string path, Flags flags, SftpUploadAsyncResult? asyncResult, Action<ulong>? uploadCallback)
+#pragma warning disable S6966 // Awaitable method should be used
+        private async Task InternalUploadFile(
+            Stream input,
+            string path,
+            Flags flags,
+            SftpUploadAsyncResult? asyncResult,
+            Action<ulong>? uploadCallback,
+            bool isAsync,
+            CancellationToken cancellationToken)
         {
-            ThrowHelper.ThrowIfNull(input);
-            ThrowHelper.ThrowIfNullOrWhiteSpace(path);
+            Debug.Assert(isAsync || cancellationToken == default);
 
             if (_sftpSession is null)
             {
                 throw new SshConnectionException("Client not connected.");
             }
 
-            var fullPath = _sftpSession.GetCanonicalPath(path);
+            string fullPath;
+            byte[] handle;
 
-            var handle = _sftpSession.RequestOpen(fullPath, flags);
+            if (isAsync)
+            {
+                fullPath = await _sftpSession.GetCanonicalPathAsync(path, cancellationToken).ConfigureAwait(false);
+                handle = await _sftpSession.RequestOpenAsync(fullPath, flags, cancellationToken).ConfigureAwait(false);
+            }
+            else
+            {
+                fullPath = _sftpSession.GetCanonicalPath(path);
+                handle = _sftpSession.RequestOpen(fullPath, flags);
+            }
 
             ulong offset = 0;
 
             // create buffer of optimal length
             var buffer = new byte[_sftpSession.CalculateOptimalWriteLength(_bufferSize, handle)];
 
-            int bytesRead;
             var expectedResponses = 0;
 
             // We will send out all the write requests without waiting for each response.
@@ -2495,8 +2532,21 @@ namespace Renci.SshNet
 
             ExceptionDispatchInfo? exception = null;
 
-            while ((bytesRead = input.Read(buffer, 0, buffer.Length)) != 0)
+            while (true)
             {
+                var bytesRead = isAsync
+#if NET
+                    ? await input.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)
+#else
+                    ? await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)
+#endif
+                    : input.Read(buffer, 0, buffer.Length);
+
+                if (bytesRead == 0)
+                {
+                    break;
+                }
+
                 if (asyncResult is not null && asyncResult.IsUploadCanceled)
                 {
                     break;
@@ -2555,34 +2605,28 @@ namespace Renci.SshNet
 
             if (Volatile.Read(ref expectedResponses) != 0)
             {
-                _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout);
+                if (isAsync)
+                {
+                    await _sftpSession.WaitOnHandleAsync(mres.WaitHandle, _operationTimeout, cancellationToken).ConfigureAwait(false);
+                }
+                else
+                {
+                    _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout);
+                }
             }
 
             exception?.Throw();
 
-            _sftpSession.RequestClose(handle);
-        }
-
-        private async Task InternalUploadFileAsync(Stream input, string path, CancellationToken cancellationToken)
-        {
-            ThrowHelper.ThrowIfNull(input);
-            ThrowHelper.ThrowIfNullOrWhiteSpace(path);
-
-            if (_sftpSession is null)
+            if (isAsync)
             {
-                throw new SshConnectionException("Client not connected.");
+                await _sftpSession.RequestCloseAsync(handle, cancellationToken).ConfigureAwait(false);
             }
-
-            cancellationToken.ThrowIfCancellationRequested();
-
-            var fullPath = await _sftpSession.GetCanonicalPathAsync(path, cancellationToken).ConfigureAwait(false);
-            var openStreamTask = SftpFileStream.OpenAsync(_sftpSession, fullPath, FileMode.Create, FileAccess.Write, (int)_bufferSize, cancellationToken);
-
-            using (var output = await openStreamTask.ConfigureAwait(false))
+            else
             {
-                await input.CopyToAsync(output, 81920, cancellationToken).ConfigureAwait(false);
+                _sftpSession.RequestClose(handle);
             }
         }
+#pragma warning restore S6966 // Awaitable method should be used
 
         /// <summary>
         /// Called when client is connected to the server.

+ 68 - 43
src/Renci.SshNet/SubsystemSession.cs

@@ -257,56 +257,81 @@ namespace Renci.SshNet
             }
         }
 
-        protected async Task<T> WaitOnHandleAsync<T>(TaskCompletionSource<T> tcs, int millisecondsTimeout, CancellationToken cancellationToken)
+        public async Task WaitOnHandleAsync(WaitHandle waitHandle, int millisecondsTimeout, CancellationToken cancellationToken)
         {
-            cancellationToken.ThrowIfCancellationRequested();
-
-            var errorOccurredReg = ThreadPool.RegisterWaitForSingleObject(
-                _errorOccurredWaitHandle,
-                (tcs, _) => ((TaskCompletionSource<T>)tcs).TrySetException(_exception),
-                state: tcs,
-                millisecondsTimeOutInterval: -1,
-                executeOnlyOnce: true);
-
-            var sessionDisconnectedReg = ThreadPool.RegisterWaitForSingleObject(
-                _sessionDisconnectedWaitHandle,
-                static (tcs, _) => ((TaskCompletionSource<T>)tcs).TrySetException(new SshException("Connection was closed by the server.")),
-                state: tcs,
-                millisecondsTimeOutInterval: -1,
-                executeOnlyOnce: true);
-
-            var channelClosedReg = ThreadPool.RegisterWaitForSingleObject(
-                _channelClosedWaitHandle,
-                static (tcs, _) => ((TaskCompletionSource<T>)tcs).TrySetException(new SshException("Channel was closed.")),
-                state: tcs,
-                millisecondsTimeOutInterval: -1,
-                executeOnlyOnce: true);
-
-            using var timeoutCts = new CancellationTokenSource(millisecondsTimeout);
-            using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);
-
-            using var tokenReg = linkedCts.Token.Register(
-                static s =>
-                {
-                    (var tcs, var cancellationToken) = ((TaskCompletionSource<T>, CancellationToken))s;
-                    _ = tcs.TrySetCanceled(cancellationToken);
-                },
-                state: (tcs, cancellationToken),
-                useSynchronizationContext: false);
+            var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
 
-            try
+            using RegisteredWait reg = new(
+                waitHandle,
+                (tcs, _) => ((TaskCompletionSource<object>)tcs).TrySetResult(null),
+                state: tcs);
+
+            _ = await WaitOnHandleAsync(tcs, millisecondsTimeout, cancellationToken).ConfigureAwait(false);
+        }
+
+        public Task<T> WaitOnHandleAsync<T>(TaskCompletionSource<T> tcs, int millisecondsTimeout, CancellationToken cancellationToken)
+        {
+            return tcs.Task.IsCompleted ? tcs.Task
+                : cancellationToken.IsCancellationRequested ? Task.FromCanceled<T>(cancellationToken)
+                : DoWaitAsync(tcs, millisecondsTimeout, cancellationToken);
+
+            async Task<T> DoWaitAsync(TaskCompletionSource<T> tcs, int millisecondsTimeout, CancellationToken cancellationToken)
             {
-                return await tcs.Task.ConfigureAwait(false);
+                using RegisteredWait errorOccurredReg = new(
+                    _errorOccurredWaitHandle,
+                    (tcs, _) => ((TaskCompletionSource<T>)tcs).TrySetException(_exception),
+                    state: tcs);
+
+                using RegisteredWait sessionDisconnectedReg = new(
+                    _sessionDisconnectedWaitHandle,
+                    static (tcs, _) => ((TaskCompletionSource<T>)tcs).TrySetException(new SshException("Connection was closed by the server.")),
+                    state: tcs);
+
+                using RegisteredWait channelClosedReg = new(
+                    _channelClosedWaitHandle,
+                    static (tcs, _) => ((TaskCompletionSource<T>)tcs).TrySetException(new SshException("Channel was closed.")),
+                    state: tcs);
+
+                using var timeoutCts = new CancellationTokenSource(millisecondsTimeout);
+                using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);
+
+                using var tokenReg = linkedCts.Token.Register(
+                    static s =>
+                    {
+                        (var tcs, var cancellationToken) = ((TaskCompletionSource<T>, CancellationToken))s;
+                        _ = tcs.TrySetCanceled(cancellationToken);
+                    },
+                    state: (tcs, cancellationToken),
+                    useSynchronizationContext: false);
+
+                try
+                {
+                    return await tcs.Task.ConfigureAwait(false);
+                }
+                catch (OperationCanceledException oce) when (timeoutCts.IsCancellationRequested)
+                {
+                    throw new SshOperationTimeoutException("Operation has timed out.", oce);
+                }
             }
-            catch (OperationCanceledException oce) when (timeoutCts.IsCancellationRequested)
+        }
+
+        private readonly struct RegisteredWait : IDisposable
+        {
+            private readonly RegisteredWaitHandle _handle;
+
+            public RegisteredWait(WaitHandle waitObject, WaitOrTimerCallback callback, object state)
             {
-                throw new SshOperationTimeoutException("Operation has timed out.", oce);
+                _handle = ThreadPool.RegisterWaitForSingleObject(
+                    waitObject,
+                    callback,
+                    state,
+                    millisecondsTimeOutInterval: -1,
+                    executeOnlyOnce: true);
             }
-            finally
+
+            public void Dispose()
             {
-                _ = errorOccurredReg.Unregister(waitObject: null);
-                _ = sessionDisconnectedReg.Unregister(waitObject: null);
-                _ = channelClosedReg.Unregister(waitObject: null);
+                _ = _handle.Unregister(waitObject: null);
             }
         }