Quellcode durchsuchen

Fix CancelAsync Cause Deadlock (#1345)

* Fix CancelAsync Cause Deadlock

* Fix CancelAsync Cause Deadlock

* Support manual cancelling if exit-signal does not cancel

* Fix switch with duplicate case

* Revert wait exit response, use existing OperationCancelledException

* Not executing callback when command is cancelled
zeotuan vor 1 Jahr
Ursprung
Commit
db3d7e8d03

+ 36 - 17
src/Renci.SshNet/SshCommand.cs

@@ -26,11 +26,13 @@ namespace Renci.SshNet
         private CommandAsyncResult _asyncResult;
         private AsyncCallback _callback;
         private EventWaitHandle _sessionErrorOccuredWaitHandle;
+        private EventWaitHandle _commandCancelledWaitHandle;
         private Exception _exception;
         private StringBuilder _result;
         private StringBuilder _error;
         private bool _hasError;
         private bool _isDisposed;
+        private bool _isCancelled;
         private ChannelInputStream _inputStream;
         private TimeSpan _commandTimeout;
 
@@ -84,7 +86,7 @@ namespace Renci.SshNet
         /// <returns>
         /// The stream that can be used to transfer data to the command's input stream.
         /// </returns>
- #pragma warning disable CA1859 // Use concrete types when possible for improved performance
+#pragma warning disable CA1859 // Use concrete types when possible for improved performance
         public Stream CreateInputStream()
 #pragma warning restore CA1859 // Use concrete types when possible for improved performance
         {
@@ -186,7 +188,7 @@ namespace Renci.SshNet
             _encoding = encoding;
             CommandTimeout = Timeout.InfiniteTimeSpan;
             _sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false);
-
+            _commandCancelledWaitHandle = new AutoResetEvent(initialState: false);
             _session.Disconnected += Session_Disconnected;
             _session.ErrorOccured += Session_ErrorOccured;
         }
@@ -249,11 +251,11 @@ namespace Renci.SshNet
 
             // Create new AsyncResult object
             _asyncResult = new CommandAsyncResult
-                {
-                    AsyncWaitHandle = new ManualResetEvent(initialState: false),
-                    IsCompleted = false,
-                    AsyncState = state,
-                };
+            {
+                AsyncWaitHandle = new ManualResetEvent(initialState: false),
+                IsCompleted = false,
+                AsyncState = state,
+            };
 
             if (_channel is not null)
             {
@@ -349,20 +351,25 @@ namespace Renci.SshNet
 
                 commandAsyncResult.EndCalled = true;
 
-                return Result;
+                if (!_isCancelled)
+                {
+                    return Result;
+                }
+
+                SetAsyncComplete();
+                throw new OperationCanceledException();
             }
         }
 
         /// <summary>
         /// Cancels command execution in asynchronous scenarios.
         /// </summary>
-        public void CancelAsync()
+        /// <param name="forceKill">if true send SIGKILL instead of SIGTERM.</param>
+        public void CancelAsync(bool forceKill = false)
         {
-            if (_channel is not null && _channel.IsOpen && _asyncResult is not null)
-            {
-                // TODO: check with Oleg if we shouldn't dispose the channel and uninitialize it ?
-                _channel.Dispose();
-            }
+            var signal = forceKill ? "KILL" : "TERM";
+            _ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en");
+            _ = _commandCancelledWaitHandle?.Set();
         }
 
         /// <summary>
@@ -430,14 +437,14 @@ namespace Renci.SshNet
             _ = _sessionErrorOccuredWaitHandle.Set();
         }
 
-        private void Channel_Closed(object sender, ChannelEventArgs e)
+        private void SetAsyncComplete()
         {
             OutputStream?.Flush();
             ExtendedOutputStream?.Flush();
 
             _asyncResult.IsCompleted = true;
 
-            if (_callback is not null)
+            if (_callback is not null && !_isCancelled)
             {
                 // Execute callback on different thread
                 ThreadAbstraction.ExecuteThread(() => _callback(_asyncResult));
@@ -446,6 +453,11 @@ namespace Renci.SshNet
             _ = ((EventWaitHandle) _asyncResult.AsyncWaitHandle).Set();
         }
 
+        private void Channel_Closed(object sender, ChannelEventArgs e)
+        {
+            SetAsyncComplete();
+        }
+
         private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e)
         {
             if (e.Info is ExitStatusRequestInfo exitStatusInfo)
@@ -506,7 +518,8 @@ namespace Renci.SshNet
             var waitHandles = new[]
                 {
                     _sessionErrorOccuredWaitHandle,
-                    waitHandle
+                    waitHandle,
+                    _commandCancelledWaitHandle
                 };
 
             var signaledElement = WaitHandle.WaitAny(waitHandles, CommandTimeout);
@@ -518,6 +531,9 @@ namespace Renci.SshNet
                 case 1:
                     // Specified waithandle was signaled
                     break;
+                case 2:
+                    _isCancelled = true;
+                    break;
                 case WaitHandle.WaitTimeout:
                     throw new SshOperationTimeoutException(string.Format(CultureInfo.CurrentCulture, "Command '{0}' has timed out.", CommandText));
                 default:
@@ -620,6 +636,9 @@ namespace Renci.SshNet
                     _sessionErrorOccuredWaitHandle = null;
                 }
 
+                _commandCancelledWaitHandle?.Dispose();
+                _commandCancelledWaitHandle = null;
+
                 _isDisposed = true;
             }
         }

+ 45 - 2
test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs

@@ -51,6 +51,49 @@ namespace Renci.SshNet.IntegrationTests.OldIntegrationTests
             }
         }
 
+        [TestMethod]
+        [Timeout(5000)]
+        public void Test_CancelAsync_Unfinished_Command()
+        {
+            using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
+            #region Example SshCommand CancelAsync Unfinished Command Without Sending exit-signal
+            client.Connect();
+            var testValue = Guid.NewGuid().ToString();
+            var command = $"sleep 15s; echo {testValue}";
+            using var cmd = client.CreateCommand(command);
+            var asyncResult = cmd.BeginExecute();
+            cmd.CancelAsync();
+            Assert.ThrowsException<OperationCanceledException>(() => cmd.EndExecute(asyncResult));
+            Assert.IsTrue(asyncResult.IsCompleted);
+            client.Disconnect();
+            Assert.AreEqual(string.Empty, cmd.Result.Trim());
+            #endregion
+        }
+
+        [TestMethod]
+        public async Task Test_CancelAsync_Finished_Command()
+        {
+            using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
+            #region Example SshCommand CancelAsync Finished Command
+            client.Connect();
+            var testValue = Guid.NewGuid().ToString();
+            var command = $"echo {testValue}";
+            using var cmd = client.CreateCommand(command);
+            var asyncResult = cmd.BeginExecute();
+            while (!asyncResult.IsCompleted)
+            {
+                await Task.Delay(200);
+            }
+
+            cmd.CancelAsync();
+            cmd.EndExecute(asyncResult);
+            client.Disconnect();
+
+            Assert.IsTrue(asyncResult.IsCompleted);
+            Assert.AreEqual(testValue, cmd.Result.Trim());
+            #endregion
+        }
+
         [TestMethod]
         public void Test_Execute_OutputStream()
         {
@@ -222,7 +265,7 @@ namespace Renci.SshNet.IntegrationTests.OldIntegrationTests
                 client.Connect();
 
                 var cmd = client.RunCommand("exit 128");
-                
+
                 Console.WriteLine(cmd.ExitStatus);
 
                 client.Disconnect();
@@ -443,7 +486,7 @@ namespace Renci.SshNet.IntegrationTests.OldIntegrationTests
         }
 
         [TestMethod]
-        
+
         public void Test_MultipleThread_100_MultipleConnections()
         {
             try