Răsfoiți Sursa

Fix hang in SftpClient.UploadFile upon error (#1643)

* Fix deadlock in SftpClient.UploadFile upon error

* Make RequestWrite deterministic wrt. exception handling

* add regression test; fix race

* x
Rob Hague 5 luni în urmă
părinte
comite
a024b83def

+ 8 - 1
src/Renci.SshNet/Sftp/Responses/SftpStatusResponse.cs

@@ -12,7 +12,7 @@
         {
         }
 
-        public StatusCodes StatusCode { get; private set; }
+        public StatusCodes StatusCode { get; set; }
 
         public string ErrorMessage { get; private set; }
 
@@ -39,5 +39,12 @@
                 Language = ReadString(Ascii);
             }
         }
+
+        protected override void SaveData()
+        {
+            base.SaveData();
+
+            Write((uint)StatusCode);
+        }
     }
 }

+ 17 - 9
src/Renci.SshNet/Sftp/SftpSession.cs

@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Globalization;
 using System.Text;
 using System.Threading;
@@ -914,6 +915,8 @@ namespace Renci.SshNet.Sftp
                                  AutoResetEvent wait,
                                  Action<SftpStatusResponse> writeCompleted = null)
         {
+            Debug.Assert((wait is null) != (writeCompleted is null), "Should have one parameter or the other.");
+
             SshException exception = null;
 
             var request = new SftpWriteRequest(ProtocolVersion,
@@ -925,10 +928,15 @@ namespace Renci.SshNet.Sftp
                                                length,
                                                response =>
                                                {
-                                                   writeCompleted?.Invoke(response);
-
-                                                   exception = GetSftpException(response);
-                                                   wait?.SetIgnoringObjectDisposed();
+                                                   if (writeCompleted is not null)
+                                                   {
+                                                       writeCompleted.Invoke(response);
+                                                   }
+                                                   else
+                                                   {
+                                                       exception = GetSftpException(response);
+                                                       wait.SetIgnoringObjectDisposed();
+                                                   }
                                                });
 
             SendRequest(request);
@@ -936,11 +944,11 @@ namespace Renci.SshNet.Sftp
             if (wait is not null)
             {
                 WaitOnHandle(wait, OperationTimeout);
-            }
 
-            if (exception is not null)
-            {
-                throw exception;
+                if (exception is not null)
+                {
+                    throw exception;
+                }
             }
         }
 
@@ -2272,7 +2280,7 @@ namespace Renci.SshNet.Sftp
             return Math.Min(bufferSize, maximumPacketSize) - lengthOfNonDataProtocolFields;
         }
 
-        private static SshException GetSftpException(SftpStatusResponse response)
+        internal static SshException GetSftpException(SftpStatusResponse response)
         {
 #pragma warning disable IDE0010 // Add missing cases
             switch (response.StatusCode)

+ 59 - 31
src/Renci.SshNet/SftpClient.cs

@@ -1,11 +1,13 @@
 #nullable enable
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
 using System.Globalization;
 using System.IO;
 using System.Net;
 using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
@@ -2456,56 +2458,82 @@ namespace Renci.SshNet
             // create buffer of optimal length
             var buffer = new byte[_sftpSession.CalculateOptimalWriteLength(_bufferSize, handle)];
 
-            var bytesRead = input.Read(buffer, 0, buffer.Length);
+            int bytesRead;
             var expectedResponses = 0;
-            var responseReceivedWaitHandle = new AutoResetEvent(initialState: false);
 
-            do
+            // We will send out all the write requests without waiting for each response.
+            // Afterwards, we may wait on this handle until all responses are received
+            // or an error has occured.
+            using var mres = new ManualResetEventSlim(initialState: false);
+
+            ExceptionDispatchInfo? exception = null;
+
+            while ((bytesRead = input.Read(buffer, 0, buffer.Length)) != 0)
             {
-                // Cancel upload
                 if (asyncResult is not null && asyncResult.IsUploadCanceled)
                 {
                     break;
                 }
 
-                if (bytesRead > 0)
+                exception?.Throw();
+
+                var writtenBytes = offset + (ulong)bytesRead;
+
+                _ = Interlocked.Increment(ref expectedResponses);
+                mres.Reset();
+
+                _sftpSession.RequestWrite(handle, offset, buffer, offset: 0, bytesRead, wait: null, s =>
                 {
-                    var writtenBytes = offset + (ulong)bytesRead;
+                    var setHandle = false;
+
+                    try
+                    {
+                        if (Sftp.SftpSession.GetSftpException(s) is Exception ex)
+                        {
+                            exception = ExceptionDispatchInfo.Capture(ex);
+                        }
 
-                    _sftpSession.RequestWrite(handle, offset, buffer, offset: 0, bytesRead, wait: null, s =>
+                        if (exception is not null)
                         {
-                            if (s.StatusCode == StatusCodes.Ok)
-                            {
-                                _ = Interlocked.Decrement(ref expectedResponses);
-                                _ = responseReceivedWaitHandle.Set();
+                            setHandle = true;
+                            return;
+                        }
 
-                                asyncResult?.Update(writtenBytes);
+                        Debug.Assert(s.StatusCode == StatusCodes.Ok);
 
-                                // Call callback to report number of bytes written
-                                if (uploadCallback is not null)
-                                {
-                                    // Execute callback on different thread
-                                    ThreadAbstraction.ExecuteThread(() => uploadCallback(writtenBytes));
-                                }
-                            }
-                        });
+                        asyncResult?.Update(writtenBytes);
+
+                        // Call callback to report number of bytes written
+                        if (uploadCallback is not null)
+                        {
+                            // Execute callback on different thread
+                            ThreadAbstraction.ExecuteThread(() => uploadCallback(writtenBytes));
+                        }
+                    }
+                    finally
+                    {
+                        if (Interlocked.Decrement(ref expectedResponses) == 0 || setHandle)
+                        {
+                            mres.Set();
+                        }
+                    }
+                });
 
-                    _ = Interlocked.Increment(ref expectedResponses);
+                offset += (ulong)bytesRead;
+            }
 
-                    offset += (ulong)bytesRead;
+            // Make sure the read of exception cannot be executed ahead of
+            // the read of expectedResponses so that we do not miss an
+            // exception.
 
-                    bytesRead = input.Read(buffer, 0, buffer.Length);
-                }
-                else if (expectedResponses > 0)
-                {
-                    // Wait for expectedResponses to change
-                    _sftpSession.WaitOnHandle(responseReceivedWaitHandle, _operationTimeout);
-                }
+            if (Volatile.Read(ref expectedResponses) != 0)
+            {
+                _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout);
             }
-            while (expectedResponses > 0 || bytesRead > 0);
+
+            exception?.Throw();
 
             _sftpSession.RequestClose(handle);
-            responseReceivedWaitHandle.Dispose();
         }
 
         private async Task InternalUploadFileAsync(Stream input, string path, CancellationToken cancellationToken)

+ 251 - 0
test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs

@@ -0,0 +1,251 @@
+using System;
+using System.IO;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Moq;
+
+using Renci.SshNet.Channels;
+using Renci.SshNet.Common;
+using Renci.SshNet.Connection;
+using Renci.SshNet.Messages;
+using Renci.SshNet.Messages.Authentication;
+using Renci.SshNet.Messages.Connection;
+using Renci.SshNet.Sftp;
+using Renci.SshNet.Sftp.Responses;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    public partial class SftpClientTest
+    {
+        [TestMethod]
+        public void UploadFile_ObservesErrorResponses()
+        {
+            // A regression test for UploadFile hanging instead of observing
+            // error responses from the server.
+            // https://github.com/sshnet/SSH.NET/issues/957
+
+            var serviceFactoryMock = new Mock<IServiceFactory>();
+
+            var connInfo = new PasswordConnectionInfo("host", "user", "pwd");
+
+            var session = new MySession(connInfo);
+
+            var concreteServiceFactory = new ServiceFactory();
+
+            serviceFactoryMock
+                .Setup(p => p.CreateSession(It.IsAny<ConnectionInfo>(), It.IsAny<ISocketFactory>()))
+                .Returns(session);
+
+            serviceFactoryMock
+                .Setup(p => p.CreateSftpResponseFactory())
+                .Returns(concreteServiceFactory.CreateSftpResponseFactory);
+
+            serviceFactoryMock
+                .Setup(p => p.CreateSftpSession(session, It.IsAny<int>(), It.IsAny<Encoding>(), It.IsAny<ISftpResponseFactory>()))
+                .Returns(concreteServiceFactory.CreateSftpSession);
+
+            using var client = new SftpClient(connInfo, false, serviceFactoryMock.Object);
+            client.Connect();
+
+            Assert.Throws<SftpPermissionDeniedException>(() => client.UploadFile(
+                new OneByteStream(new MemoryStream("Hello World"u8.ToArray())),
+                "path.txt"));
+        }
+
+#pragma warning disable IDE0022 // Use block body for method
+#pragma warning disable IDE0025 // Use block body for property
+#pragma warning disable IDE0027 // Use block body for accessor
+#pragma warning disable CS0067 // event is unused
+
+        private class MySession(ConnectionInfo connectionInfo) : ISession
+        {
+            public IConnectionInfo ConnectionInfo => connectionInfo;
+
+            public event EventHandler<MessageEventArgs<ChannelCloseMessage>> ChannelCloseReceived;
+            public event EventHandler<MessageEventArgs<ChannelDataMessage>> ChannelDataReceived;
+            public event EventHandler<MessageEventArgs<ChannelEofMessage>> ChannelEofReceived;
+            public event EventHandler<MessageEventArgs<ChannelExtendedDataMessage>> ChannelExtendedDataReceived;
+            public event EventHandler<MessageEventArgs<ChannelFailureMessage>> ChannelFailureReceived;
+            public event EventHandler<MessageEventArgs<ChannelOpenConfirmationMessage>> ChannelOpenConfirmationReceived;
+            public event EventHandler<MessageEventArgs<ChannelOpenFailureMessage>> ChannelOpenFailureReceived;
+            public event EventHandler<MessageEventArgs<ChannelOpenMessage>> ChannelOpenReceived;
+            public event EventHandler<MessageEventArgs<ChannelRequestMessage>> ChannelRequestReceived;
+            public event EventHandler<MessageEventArgs<ChannelSuccessMessage>> ChannelSuccessReceived;
+            public event EventHandler<MessageEventArgs<ChannelWindowAdjustMessage>> ChannelWindowAdjustReceived;
+            public event EventHandler<EventArgs> Disconnected;
+            public event EventHandler<ExceptionEventArgs> ErrorOccured;
+            public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
+            public event EventHandler<HostKeyEventArgs> HostKeyReceived;
+            public event EventHandler<MessageEventArgs<RequestSuccessMessage>> RequestSuccessReceived;
+            public event EventHandler<MessageEventArgs<RequestFailureMessage>> RequestFailureReceived;
+            public event EventHandler<MessageEventArgs<BannerMessage>> UserAuthenticationBannerReceived;
+
+            private uint _numRequests;
+            private int _numWriteRequests;
+
+            public void SendMessage(Message message)
+            {
+                // Initialisation sequence for SFTP session
+
+                if (message is ChannelOpenMessage)
+                {
+                    ChannelOpenConfirmationReceived?.Invoke(
+                        this,
+                        new MessageEventArgs<ChannelOpenConfirmationMessage>(
+                            new ChannelOpenConfirmationMessage(0, int.MaxValue, int.MaxValue, 0)));
+                }
+                else if (message is ChannelRequestMessage)
+                {
+                    ChannelSuccessReceived?.Invoke(
+                        this,
+                        new MessageEventArgs<ChannelSuccessMessage>(new ChannelSuccessMessage(0)));
+                }
+                else if (message is ChannelDataMessage dataMsg)
+                {
+                    if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Init)
+                    {
+                        ChannelDataReceived?.Invoke(
+                            this,
+                            new MessageEventArgs<ChannelDataMessage>(
+                                new ChannelDataMessage(0, new SftpVersionResponse() { Version = 3 }.GetBytes())));
+                    }
+                    else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.RealPath)
+                    {
+                        ChannelDataReceived?.Invoke(
+                            this,
+                            new MessageEventArgs<ChannelDataMessage>(
+                                new ChannelDataMessage(0,
+                                    new SftpNameResponse(3, Encoding.UTF8)
+                                    {
+                                        ResponseId = ++_numRequests,
+                                        Files = [new("thepath", new SftpFileAttributes(default, default, default, default, default, default, default))]
+                                    }.GetBytes())));
+                    }
+                    else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Open)
+                    {
+                        ChannelDataReceived?.Invoke(
+                            this,
+                            new MessageEventArgs<ChannelDataMessage>(
+                                new ChannelDataMessage(0,
+                                    new SftpHandleResponse(3)
+                                    {
+                                        ResponseId = ++_numRequests,
+                                        Handle = "file"u8.ToArray()
+                                    }.GetBytes())));
+                    }
+
+                    // --------- The actual interesting part of all of this ---------
+                    //
+                    else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Write)
+                    {
+                        // Fail the 5th write request
+                        var statusCode = ++_numWriteRequests == 5 ? StatusCodes.PermissionDenied : StatusCodes.Ok;
+                        var responseId = ++_numRequests;
+
+                        // Dispatch the responses on a different thread to simulate reality.
+                        _ = Task.Run(() =>
+                        {
+                            ChannelDataReceived?.Invoke(
+                                this,
+                                new MessageEventArgs<ChannelDataMessage>(
+                                    new ChannelDataMessage(0,
+                                        new SftpStatusResponse(3)
+                                        {
+                                            ResponseId = responseId,
+                                            StatusCode = statusCode
+                                        }.GetBytes())));
+                        });
+                    }
+                    //
+                    // --------------------------------------------------------------
+                }
+            }
+
+            public bool IsConnected => false;
+
+            public SemaphoreSlim SessionSemaphore { get; } = new(1);
+
+            public IChannelSession CreateChannelSession() => new ChannelSession(this, 0, int.MaxValue, int.MaxValue);
+
+            public WaitHandle MessageListenerCompleted => throw new NotImplementedException();
+
+            public void Connect()
+            {
+            }
+
+            public Task ConnectAsync(CancellationToken cancellationToken) => throw new NotImplementedException();
+
+            public IChannelDirectTcpip CreateChannelDirectTcpip() => throw new NotImplementedException();
+
+            public IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize)
+                => throw new NotImplementedException();
+
+            public void Dispose()
+            {
+            }
+
+            public void OnDisconnecting()
+            {
+            }
+
+            public void Disconnect() => throw new NotImplementedException();
+
+            public void RegisterMessage(string messageName) => throw new NotImplementedException();
+
+            public bool TrySendMessage(Message message) => throw new NotImplementedException();
+
+            public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout, out Exception exception) => throw new NotImplementedException();
+
+            public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException();
+
+            public void UnRegisterMessage(string messageName) => throw new NotImplementedException();
+
+            public void WaitOnHandle(WaitHandle waitHandle)
+            {
+            }
+
+            public void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException();
+        }
+
+        private class OneByteStream : Stream
+        {
+            private readonly Stream _stream;
+
+            public OneByteStream(Stream stream)
+            {
+                _stream = stream;
+            }
+
+            public override bool CanRead => _stream.CanRead;
+
+            public override bool CanSeek => throw new NotImplementedException();
+
+            public override bool CanWrite => throw new NotImplementedException();
+
+            public override long Length => _stream.Length;
+
+            public override long Position
+            {
+                get => throw new NotImplementedException();
+                set => throw new NotImplementedException();
+            }
+
+            public override void Flush() => throw new NotImplementedException();
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                return _stream.Read(buffer, offset, Math.Min(1, count));
+            }
+
+            public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
+
+            public override void SetLength(long value) => throw new NotImplementedException();
+
+            public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
+        }
+    }
+}