浏览代码

Fix async SFTP operations and ensure that only valid asyncResult objct can be used with End* operation

olegkap_cp 14 年之前
父节点
当前提交
a1b3ce729b

+ 16 - 0
Renci.SshClient/Renci.SshNet.Tests/SftpClientTests/ListDirectoryTest.cs

@@ -248,5 +248,21 @@ namespace Renci.SshNet.Tests.SftpClientTests
 				sftp.Disconnect();
 			}
 		}
+
+        [TestMethod]
+        [TestCategory("Sftp")]
+        [Description("Test calling EndListDirectory method more then once.")]
+        [ExpectedException(typeof(ArgumentException))]
+        public void Test_Sftp_Call_EndListDirectory_Twice()
+        {
+            using (var sftp = new SftpClient(Resources.HOST, Resources.USERNAME, Resources.PASSWORD))
+            {
+                sftp.Connect();
+                var ar = sftp.BeginListDirectory("/", null, null);
+                var result = sftp.EndListDirectory(ar);
+                var result1 = sftp.EndListDirectory(ar);
+            }
+        }
+
     }
 }

+ 35 - 2
Renci.SshClient/Renci.SshNet.Tests/SftpClientTests/UploadDownloadFileTest.cs

@@ -15,7 +15,6 @@ namespace Renci.SshNet.Tests.SftpClientTests
     [TestClass]
     public class UploadDownloadFileTest
     {
-
         [TestInitialize()]
         public void CleanCurrentFolder()
         {
@@ -149,7 +148,7 @@ namespace Renci.SshNet.Tests.SftpClientTests
                 {
                     using (var ms = new MemoryStream())
                     {
-                        sftp.UploadFile(ms, remoteFileName);
+                        sftp.DownloadFile(remoteFileName, ms);
                     }
                 }
                 catch (SshFileNotFoundException)
@@ -475,6 +474,40 @@ namespace Renci.SshNet.Tests.SftpClientTests
             }
         }
 
+        [TestMethod]
+        [TestCategory("Sftp")]
+        [ExpectedException(typeof(ArgumentException))]
+        public void Test_Sftp_EndUploadFile_Invalid_Async_Handle()
+        {
+            using (var sftp = new SftpClient(Resources.HOST, Resources.USERNAME, Resources.PASSWORD))
+            {
+                sftp.Connect();
+                var async1 = sftp.BeginListDirectory("/", null, null);
+                var filename = Path.GetTempFileName();
+                this.CreateTestFile(filename, 100);
+                var async2 = sftp.BeginUploadFile(File.OpenRead(filename), "test", null, null);
+                sftp.EndUploadFile(async1);
+            }
+        }
+
+        [TestMethod]
+        [TestCategory("Sftp")]
+        [ExpectedException(typeof(ArgumentException))]
+        public void Test_Sftp_EndDownloadFile_Invalid_Async_Handle()
+        {
+            using (var sftp = new SftpClient(Resources.HOST, Resources.USERNAME, Resources.PASSWORD))
+            {
+                sftp.Connect();
+                var filename = Path.GetTempFileName();
+                this.CreateTestFile(filename, 1);
+                sftp.UploadFile(File.OpenRead(filename), "test123");
+                var async1 = sftp.BeginListDirectory("/", null, null);
+                var async2 = sftp.BeginDownloadFile("test123", new MemoryStream(), null, null);
+                sftp.EndDownloadFile(async1);
+            }
+        }
+
+
         /// <summary>
         /// Creates the test file.
         /// </summary>

+ 97 - 75
Renci.SshClient/Renci.SshNet/Sftp/SftpAsyncResult.cs

@@ -1,13 +1,34 @@
 using System;
 using System.Threading;
+using System.Threading.Tasks;
 
 namespace Renci.SshNet.Sftp
 {
     /// <summary>
     /// Represents the status of an asynchronous SFTP operation.
     /// </summary>
-    public class SftpAsyncResult : IAsyncResult, IDisposable
+    public class SftpAsyncResult : IAsyncResult
     {
+        private const int _statePending = 0;
+
+        private const int _stateCompletedSynchronously = 1;
+
+        private const int _stateCompletedAsynchronously = 2;
+
+        private readonly AsyncCallback _asyncCallback;
+
+        private readonly Object _asyncState;
+
+        private Exception _exception;
+
+        private ManualResetEvent _asyncWaitHandle;
+
+        private int _completedState = _statePending;
+
+        private SftpSession _sftpSession;
+
+        private TimeSpan _commandTimeout;
+
         /// <summary>
         /// Gets or sets the uploaded bytes.
         /// </summary>
@@ -20,35 +41,61 @@ namespace Renci.SshNet.Sftp
         /// <value>The downloaded bytes.</value>
         public ulong DownloadedBytes { get; internal set; }
 
-        private SftpCommand _command;
-
         /// <summary>
         /// Initializes a new instance of the <see cref="SftpAsyncResult"/> class.
         /// </summary>
-        /// <param name="command">The command.</param>
+        /// <param name="sftpSession">The SFTP session.</param>
+        /// <param name="commandTimeout">The command timeout.</param>
+        /// <param name="asyncCallback">The async callback.</param>
         /// <param name="state">The state.</param>
-        internal SftpAsyncResult(SftpCommand command, object state)
+        internal SftpAsyncResult(SftpSession sftpSession, TimeSpan commandTimeout, AsyncCallback asyncCallback, object state)
         {
-            this._command = command;
-            this.AsyncState = state;
-            this.AsyncWaitHandle = new ManualResetEvent(false);
+            this._sftpSession = sftpSession;
+            this._commandTimeout = commandTimeout;
+            this._asyncCallback = asyncCallback;
+            this._asyncState = state;
         }
 
         /// <summary>
-        /// Gets the command.
+        /// Marks result as completed.
         /// </summary>
-        /// <typeparam name="T">Type of the command</typeparam>
-        /// <returns></returns>
-        internal T GetCommand<T>() where T : SftpCommand
+        /// <param name="exception">The exception.</param>
+        /// <param name="completedSynchronously">if set to <c>true</c> [completed synchronously].</param>
+        public void SetAsCompleted(Exception exception, bool completedSynchronously)
         {
-            T cmd = this._command as T;
+            // Passing null for exception means no error occurred. 
+            // This is the common case
+            this._exception = exception;
+
+            // The _completedState field MUST be set prior calling the callback
+            Int32 prevState = Interlocked.Exchange(ref this._completedState, completedSynchronously ? _stateCompletedSynchronously : _stateCompletedAsynchronously);
+            if (prevState != _statePending)
+                throw new InvalidOperationException("You can set a result only once");
+
+            // If the event exists, set it
+            if (this._asyncWaitHandle != null)
+                this._asyncWaitHandle.Set();
+
+            // If a callback method was set, call it on different thread
+            if (this._asyncCallback != null)
+                Task.Factory.StartNew(() => { this._asyncCallback(this); });
+        }
 
-            if (cmd == null)
+        public void EndInvoke()
+        {
+            // This method assumes that only 1 thread calls EndInvoke 
+            // for this object
+            if (!this.IsCompleted)
             {
-                throw new InvalidOperationException("Not valid IAsyncResult object.");
+                // If the operation isn't done, wait for it
+                this._sftpSession.WaitHandle(this.AsyncWaitHandle, this._commandTimeout);
+                this.AsyncWaitHandle.Close();
+                this._asyncWaitHandle = null;  // Allow early GC
             }
 
-            return cmd;
+            // Operation is done: if an exception occurred, throw it
+            if (this._exception != null)
+                throw _exception;
         }
 
         #region IAsyncResult Members
@@ -57,89 +104,64 @@ namespace Renci.SshNet.Sftp
         /// Gets a user-defined object that qualifies or contains information about an asynchronous operation.
         /// </summary>
         /// <returns>A user-defined object that qualifies or contains information about an asynchronous operation.</returns>
-        public object AsyncState { get; private set; }
-
-        /// <summary>
-        /// Gets a <see cref="T:System.Threading.WaitHandle"/> that is used to wait for an asynchronous operation to complete.
-        /// </summary>
-        /// <returns>A <see cref="T:System.Threading.WaitHandle"/> that is used to wait for an asynchronous operation to complete.</returns>
-        public WaitHandle AsyncWaitHandle { get; private set; }
+        public Object AsyncState { get { return this._asyncState; } }
 
         /// <summary>
         /// Gets a value that indicates whether the asynchronous operation completed synchronously.
         /// </summary>
         /// <returns>true if the asynchronous operation completed synchronously; otherwise, false.</returns>
-        public bool CompletedSynchronously { get; private set; }
-
-        /// <summary>
-        /// Gets a value that indicates whether the asynchronous operation has completed.
-        /// </summary>
-        /// <returns>true if the operation is complete; otherwise, false.</returns>
-        public bool IsCompleted { get; private set; }
-
-        #endregion
-
-        #region IDisposable Members
-
-        private bool _disposed = false;
-
-        /// <summary>
-        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
-        /// </summary>
-        public void Dispose()
+        public Boolean CompletedSynchronously
         {
-            Dispose(true);
-
-            GC.SuppressFinalize(this);
+            get
+            {
+                return Thread.VolatileRead(ref this._completedState) == _stateCompletedSynchronously;
+            }
         }
 
         /// <summary>
-        /// Releases unmanaged and - optionally - managed resources
+        /// Gets a <see cref="T:System.Threading.WaitHandle"/> that is used to wait for an asynchronous operation to complete.
         /// </summary>
-        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
-        protected virtual void Dispose(bool disposing)
+        /// <returns>A <see cref="T:System.Threading.WaitHandle"/> that is used to wait for an asynchronous operation to complete.</returns>
+        public WaitHandle AsyncWaitHandle
         {
-            // Check to see if Dispose has already been called.
-            if (!this._disposed)
+            get
             {
-                // If disposing equals true, dispose all managed
-                // and unmanaged resources.
-                if (disposing)
+                if (this._asyncWaitHandle == null)
                 {
-                    // Dispose managed resources.
-                    if (this.AsyncWaitHandle != null)
+                    Boolean done = this.IsCompleted;
+                    ManualResetEvent mre = new ManualResetEvent(done);
+                    if (Interlocked.CompareExchange(ref this._asyncWaitHandle, mre, null) != null)
                     {
-                        this.AsyncWaitHandle.Dispose();
-                        this.AsyncWaitHandle = null;
+                        // Another thread created this object's event; dispose 
+                        // the event we just created
+                        mre.Close();
+                    }
+                    else
+                    {
+                        if (!done && this.IsCompleted)
+                        {
+                            // If the operation wasn't done when we created 
+                            // the event but now it is done, set the event
+                            this._asyncWaitHandle.Set();
+                        }
                     }
                 }
-
-                // Note disposing has been done.
-                this._disposed = true;
+                return this._asyncWaitHandle;
             }
         }
 
         /// <summary>
-        /// Releases unmanaged resources and performs other cleanup operations before the
-        /// <see cref="SftpAsyncResult"/> is reclaimed by garbage collection.
+        /// Gets a value that indicates whether the asynchronous operation has completed.
         /// </summary>
-        ~SftpAsyncResult()
+        /// <returns>true if the operation is complete; otherwise, false.</returns>
+        public Boolean IsCompleted
         {
-            // Do not re-create Dispose clean-up code here.
-            // Calling Dispose(false) is optimal in terms of
-            // readability and maintainability.
-            Dispose(false);
+            get
+            {
+                return Thread.VolatileRead(ref this._completedState) != _statePending;
+            }
         }
 
         #endregion
-
-        /// <summary>
-        /// Completes asynchronous operation.
-        /// </summary>
-        internal void Complete()
-        {
-            this.IsCompleted = true;
-            ((EventWaitHandle)this.AsyncWaitHandle).Set();
-        }
     }
 }

+ 9 - 38
Renci.SshClient/Renci.SshNet/Sftp/SftpCommand.cs

@@ -13,12 +13,8 @@ namespace Renci.SshNet.Sftp
     /// </summary>
     internal abstract class SftpCommand : IDisposable
     {
-        private AsyncCallback _callback;
-
         private uint _requestId;
 
-        private Exception _error;
-
         private bool _handleCloseMessageSent;
 
         protected SftpAsyncResult AsyncResult { get; private set; }
@@ -32,6 +28,7 @@ namespace Renci.SshNet.Sftp
         public SftpCommand(SftpSession sftpSession)
         {
             this.SftpSession = sftpSession;
+
             this.SftpSession.AttributesMessageReceived += SftpSession_AttributesMessageReceived;
             this.SftpSession.DataMessageReceived += SftpSession_DataMessageReceived;
             this.SftpSession.HandleMessageReceived += SftpSession_HandleMessageReceived;
@@ -42,9 +39,7 @@ namespace Renci.SshNet.Sftp
 
         public SftpAsyncResult BeginExecute(AsyncCallback callback, object state)
         {
-            this._callback = callback;
-
-            this.AsyncResult = new SftpAsyncResult(this, state);
+            this.AsyncResult = new SftpAsyncResult(this.SftpSession, this.CommandTimeout, callback, state);
 
             this.OnExecute();
 
@@ -53,11 +48,7 @@ namespace Renci.SshNet.Sftp
 
         public void EndExecute(SftpAsyncResult result)
         {
-            this.SftpSession.WaitHandle(result.AsyncWaitHandle, this.CommandTimeout);
-
-            if (this._error != null)
-                throw this._error;
-
+            result.EndInvoke();
         }
 
         public void Execute()
@@ -191,23 +182,14 @@ namespace Renci.SshNet.Sftp
 
         protected void CompleteExecution()
         {
-            //  Call callback if exists
-            if (this._callback != null)
-            {
-                //  Execute callback on new pool thread
-                Task.Factory.StartNew(() =>
-                {
-                    this._callback(this.AsyncResult);
-                });
-            }
-
-            this.AsyncResult.Complete();
-
             this.SftpSession.AttributesMessageReceived -= SftpSession_AttributesMessageReceived;
             this.SftpSession.DataMessageReceived -= SftpSession_DataMessageReceived;
             this.SftpSession.HandleMessageReceived -= SftpSession_HandleMessageReceived;
             this.SftpSession.NameMessageReceived -= SftpSession_NameMessageReceived;
             this.SftpSession.StatusMessageReceived -= SftpSession_StatusMessageReceived;
+            this.SftpSession.ErrorOccured -= SftpSession_ErrorOccured;
+
+            this.AsyncResult.SetAsCompleted(null, false);
         }
 
         private void SftpSession_StatusMessageReceived(object sender, MessageEventArgs<StatusMessage> e)
@@ -232,8 +214,7 @@ namespace Renci.SshNet.Sftp
                  e.Message.StatusCode == StatusCodes.BadMessage ||
                  e.Message.StatusCode == StatusCodes.NoConnection ||
                  e.Message.StatusCode == StatusCodes.ConnectionLost ||
-                 e.Message.StatusCode == StatusCodes.OperationUnsupported
-                 )
+                 e.Message.StatusCode == StatusCodes.OperationUnsupported)
                 {
                     //  Throw an exception if it was not handled by the command
                     throw new SshException(e.Message.ErrorMessage);
@@ -275,9 +256,7 @@ namespace Renci.SshNet.Sftp
 
         private void SftpSession_ErrorOccured(object sender, ErrorEventArgs e)
         {
-            this._error = e.GetException();
-
-            this.CompleteExecution();
+            this.AsyncResult.SetAsCompleted(e.GetException(), false);
         }
 
         private void SendMessage(SftpRequestMessage message)
@@ -288,7 +267,6 @@ namespace Renci.SshNet.Sftp
             this.SftpSession.SendMessage(message);
         }
 
-
         #region IDisposable Members
 
         private bool _isDisposed = false;
@@ -323,17 +301,10 @@ namespace Renci.SshNet.Sftp
                     this.SftpSession.NameMessageReceived -= SftpSession_NameMessageReceived;
                     this.SftpSession.StatusMessageReceived -= SftpSession_StatusMessageReceived;
                     this.SftpSession.ErrorOccured -= SftpSession_ErrorOccured;
-
-                    // Dispose managed ResourceMessages.
-                    if (this.AsyncResult != null)
-                    {
-                        this.AsyncResult.Dispose();
-                        this.AsyncResult = null;
-                    }
                 }
 
                 // Note disposing has been done.
-                _isDisposed = true;
+                this._isDisposed = true;
             }
         }
 

+ 120 - 33
Renci.SshClient/Renci.SshNet/SftpClient.cs

@@ -16,6 +16,11 @@ namespace Renci.SshNet
         /// </summary>
         private SftpSession _sftpSession;
 
+        /// <summary>
+        /// Keeps track of all async command execution
+        /// </summary>
+        private Dictionary<SftpAsyncResult, SftpCommand> _asyncCommands = new Dictionary<SftpAsyncResult, SftpCommand>();
+
         /// <summary>
         /// Gets or sets the operation timeout.
         /// </summary>
@@ -219,7 +224,7 @@ namespace Renci.SshNet
 
             if (newPath == null)
                 throw new ArgumentNullException("newPath");
-            
+
             //  Ensure that connection is established.
             this.EnsureConnection();
 
@@ -294,7 +299,14 @@ namespace Renci.SshNet
 
             cmd.CommandTimeout = this.OperationTimeout;
 
-            return cmd.BeginExecute(asyncCallback, state);
+            var async = cmd.BeginExecute(asyncCallback, state);
+
+            lock (this._asyncCommands)
+            {
+                this._asyncCommands.Add(async, cmd);
+            }
+
+            return async;
         }
 
         /// <summary>
@@ -304,22 +316,38 @@ namespace Renci.SshNet
         /// <returns>List of files</returns>
         public IEnumerable<SftpFile> EndListDirectory(IAsyncResult asyncResult)
         {
-            var sftpAsyncResult = asyncResult as SftpAsyncResult;
+            var sftpAsync = asyncResult as SftpAsyncResult;
 
-            if (sftpAsyncResult == null)
+            if (this._asyncCommands.ContainsKey(sftpAsync))
             {
-                throw new InvalidOperationException("Not valid IAsyncResult object.");
+                lock (this._asyncCommands)
+                {
+                    if (this._asyncCommands.ContainsKey(sftpAsync))
+                    {
+                        var cmd = this._asyncCommands[sftpAsync] as ListDirectoryCommand;
+
+                        if (cmd != null)
+                        {
+                            try
+                            {
+                                this._asyncCommands.Remove(sftpAsync);
+
+                                cmd.EndExecute(sftpAsync);
+
+                                var files = cmd.Files;
+
+                                return files;
+                            }
+                            finally
+                            {
+                                cmd.Dispose();
+                            }
+                        }
+                    }
+                }
             }
 
-            var cmd = sftpAsyncResult.GetCommand<ListDirectoryCommand>();
-
-            cmd.EndExecute(sftpAsyncResult);
-
-            var files = cmd.Files;
-
-            cmd.Dispose();
-
-            return files;
+            throw new ArgumentException("Either the IAsyncResult object did not come from the corresponding async method on this type, or EndListDirectory was called multiple times with the same IAsyncResult.");
         }
 
         /// <summary>
@@ -381,7 +409,14 @@ namespace Renci.SshNet
 
             cmd.CommandTimeout = this.OperationTimeout;
 
-            return cmd.BeginExecute(asyncCallback, state);
+            var async = cmd.BeginExecute(asyncCallback, state);
+
+            lock (this._asyncCommands)
+            {
+                this._asyncCommands.Add(async, cmd);
+            }
+
+            return async;
         }
 
         /// <summary>
@@ -390,19 +425,36 @@ namespace Renci.SshNet
         /// <param name="asyncResult">The pending asynchronous SFTP request.</param>
         public void EndDownloadFile(IAsyncResult asyncResult)
         {
-            var sftpAsyncResult = asyncResult as SftpAsyncResult;
+            var sftpAsync = asyncResult as SftpAsyncResult;
 
-            if (sftpAsyncResult == null)
+            if (this._asyncCommands.ContainsKey(sftpAsync))
             {
-                throw new InvalidOperationException("Not valid IAsyncResult object.");
+                lock (this._asyncCommands)
+                {
+                    if (this._asyncCommands.ContainsKey(sftpAsync))
+                    {
+                        var cmd = this._asyncCommands[sftpAsync] as DownloadFileCommand;
+
+                        if (cmd != null)
+                        {
+                            try
+                            {
+                                this._asyncCommands.Remove(sftpAsync);
+
+                                cmd.EndExecute(sftpAsync);
+
+                                return;
+                            }
+                            finally
+                            {
+                                cmd.Dispose();
+                            }
+                        }
+                    }
+                }
             }
 
-            var cmd = sftpAsyncResult.GetCommand<DownloadFileCommand>();
-
-            cmd.EndExecute(sftpAsyncResult);
-
-            cmd.Dispose();
-
+            throw new ArgumentException("Either the IAsyncResult object did not come from the corresponding async method on this type, or EndDownloadFile was called multiple times with the same IAsyncResult.");
         }
 
         /// <summary>
@@ -440,7 +492,14 @@ namespace Renci.SshNet
 
             cmd.CommandTimeout = this.OperationTimeout;
 
-            return cmd.BeginExecute(asyncCallback, state);
+            var async = cmd.BeginExecute(asyncCallback, state);
+
+            lock (this._asyncCommands)
+            {
+                this._asyncCommands.Add(async, cmd);
+            }
+
+            return async;
         }
 
         /// <summary>
@@ -449,18 +508,36 @@ namespace Renci.SshNet
         /// <param name="asyncResult">The pending asynchronous SFTP request.</param>
         public void EndUploadFile(IAsyncResult asyncResult)
         {
-            var sftpAsyncResult = asyncResult as SftpAsyncResult;
+            var sftpAsync = asyncResult as SftpAsyncResult;
 
-            if (sftpAsyncResult == null)
+            if (this._asyncCommands.ContainsKey(sftpAsync))
             {
-                throw new InvalidOperationException("Not valid IAsyncResult object.");
+                lock (this._asyncCommands)
+                {
+                    if (this._asyncCommands.ContainsKey(sftpAsync))
+                    {
+                        var cmd = this._asyncCommands[sftpAsync] as UploadFileCommand;
+
+                        if (cmd != null)
+                        {
+                            try
+                            {
+                                this._asyncCommands.Remove(sftpAsync);
+
+                                cmd.EndExecute(sftpAsync);
+
+                                return;
+                            }
+                            finally
+                            {
+                                cmd.Dispose();
+                            }
+                        }
+                    }
+                }
             }
 
-            var cmd = sftpAsyncResult.GetCommand<UploadFileCommand>();
-
-            cmd.EndExecute(sftpAsyncResult);
-
-            cmd.Dispose();
+            throw new ArgumentException("Either the IAsyncResult object did not come from the corresponding async method on this type, or EndUploadFile was called multiple times with the same IAsyncResult.");
         }
 
         /// <summary>
@@ -500,6 +577,16 @@ namespace Renci.SshNet
                 this._sftpSession = null;
             }
 
+            if (this._asyncCommands != null)
+            {
+                foreach (var command in this._asyncCommands.Values)
+                {
+                    command.Dispose();
+                }
+
+                this._asyncCommands = null;
+            }
+
             base.Dispose(disposing);
         }
     }

+ 4 - 9
Renci.SshClient/Renci.SshNet/SshCommand.cs

@@ -220,18 +220,13 @@ namespace Renci.SshNet
                         this._channel = null;
 
                         this._asyncResult = null;
-                    }
-                    else
-                    {
-                        throw new ArgumentException("Either the IAsyncResult object did not come from the corresponding async method on this type, or EndExecute was called multiple times with the same IAsyncResult.");
+
+                        return this.Result;
                     }
                 }
             }
-            else
-            {
-                throw new ArgumentException("Either the IAsyncResult object did not come from the corresponding async method on this type, or EndExecute was called multiple times with the same IAsyncResult.");
-            }
-            return this.Result;
+
+            throw new ArgumentException("Either the IAsyncResult object did not come from the corresponding async method on this type, or EndExecute was called multiple times with the same IAsyncResult.");
         }
 
         /// <summary>