Browse Source

Don't dispose channel when completing SshCommand (#1596)

The new(-ish) implementation of SshCommand has a race condition for short-lived
commands where SSH_MSG_CHANNEL_CLOSE may be processed on the message loop thread
before SSH_MSG_CHANNEL_SUCCESS is waited upon on the Execute (main) thread. This
manifests in an ArgumentNull/NullReference exception on the wait handle because
the channel has already been closed and disposed.

Fix this by only delaying the channel dispose until the command dispose.
Rob Hague 8 months ago
parent
commit
99ef23cd87

+ 21 - 20
src/Renci.SshNet/SshCommand.cs

@@ -21,7 +21,7 @@ namespace Renci.SshNet
         private readonly ISession _session;
         private readonly Encoding _encoding;
 
-        private IChannelSession? _channel;
+        private IChannelSession _channel;
         private TaskCompletionSource<object>? _tcs;
         private CancellationTokenSource? _cts;
         private CancellationTokenRegistration _tokenRegistration;
@@ -142,14 +142,14 @@ namespace Renci.SshNet
         /// </example>
         public Stream CreateInputStream()
         {
-            if (_channel == null)
+            if (!_channel.IsOpen)
             {
-                throw new InvalidOperationException($"The input stream can be used only after calling BeginExecute and before calling EndExecute.");
+                throw new InvalidOperationException("The input stream can be used only during execution.");
             }
 
             if (_inputStream != null)
             {
-                throw new InvalidOperationException($"The input stream already exists.");
+                throw new InvalidOperationException("The input stream already exists.");
             }
 
             _inputStream = new ChannelInputStream(_channel);
@@ -226,6 +226,7 @@ namespace Renci.SshNet
             ExtendedOutputStream = new PipeStream();
             _session.Disconnected += Session_Disconnected;
             _session.ErrorOccured += Session_ErrorOccured;
+            _channel = _session.CreateChannelSession();
         }
 
         /// <summary>
@@ -257,6 +258,8 @@ namespace Renci.SshNet
                     throw new InvalidOperationException("Asynchronous operation is already in progress.");
                 }
 
+                UnsubscribeFromChannelEvents(dispose: true);
+
                 OutputStream.Dispose();
                 ExtendedOutputStream.Dispose();
 
@@ -265,6 +268,7 @@ namespace Renci.SshNet
                 // so we just need to reinitialise them for subsequent executions.
                 OutputStream = new PipeStream();
                 ExtendedOutputStream = new PipeStream();
+                _channel = _session.CreateChannelSession();
             }
 
             _exitStatus = default;
@@ -282,7 +286,6 @@ namespace Renci.SshNet
             _tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
             _userToken = cancellationToken;
 
-            _channel = _session.CreateChannelSession();
             _channel.DataReceived += Channel_DataReceived;
             _channel.ExtendedDataReceived += Channel_ExtendedDataReceived;
             _channel.RequestReceived += Channel_RequestReceived;
@@ -542,7 +545,10 @@ namespace Renci.SshNet
                 }
             }
 
-            UnsubscribeFromEventsAndDisposeChannel();
+            // We don't dispose the channel here to avoid a race condition
+            // where SSH_MSG_CHANNEL_CLOSE arrives before _channel starts
+            // waiting for a response in _channel.SendExecRequest().
+            UnsubscribeFromChannelEvents(dispose: false);
 
             OutputStream.Dispose();
             ExtendedOutputStream.Dispose();
@@ -568,7 +574,7 @@ namespace Renci.SshNet
 
                 Debug.Assert(!exitSignalInfo.WantReply, "exit-signal is want_reply := false by definition.");
             }
-            else if (e.Info.WantReply && _channel?.RemoteChannelNumber is uint remoteChannelNumber)
+            else if (e.Info.WantReply && sender is IChannel { RemoteChannelNumber: uint remoteChannelNumber })
             {
                 var replyMessage = new ChannelFailureMessage(remoteChannelNumber);
                 _session.SendMessage(replyMessage);
@@ -591,20 +597,13 @@ namespace Renci.SshNet
         }
 
         /// <summary>
-        /// Unsubscribes the current <see cref="SshCommand"/> from channel events, and disposes
-        /// the <see cref="_channel"/>.
+        /// Unsubscribes the current <see cref="SshCommand"/> from channel events, and optionally,
+        /// disposes <see cref="_channel"/>.
         /// </summary>
-        private void UnsubscribeFromEventsAndDisposeChannel()
+        private void UnsubscribeFromChannelEvents(bool dispose)
         {
             var channel = _channel;
 
-            if (channel is null)
-            {
-                return;
-            }
-
-            _channel = null;
-
             // unsubscribe from events as we do not want to be signaled should these get fired
             // during the dispose of the channel
             channel.DataReceived -= Channel_DataReceived;
@@ -612,8 +611,10 @@ namespace Renci.SshNet
             channel.RequestReceived -= Channel_RequestReceived;
             channel.Closed -= Channel_Closed;
 
-            // actually dispose the channel
-            channel.Dispose();
+            if (dispose)
+            {
+                channel.Dispose();
+            }
         }
 
         /// <summary>
@@ -645,7 +646,7 @@ namespace Renci.SshNet
 
                 // unsubscribe from channel events to ensure other objects that we're going to dispose
                 // are not accessed while disposing
-                UnsubscribeFromEventsAndDisposeChannel();
+                UnsubscribeFromChannelEvents(dispose: true);
 
                 _inputStream?.Dispose();
                 _inputStream = null;

+ 0 - 72
test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute.cs

@@ -1,72 +0,0 @@
-using System;
-using System.Globalization;
-using System.Text;
-
-using Microsoft.VisualStudio.TestTools.UnitTesting;
-
-using Moq;
-
-using Renci.SshNet.Channels;
-using Renci.SshNet.Common;
-using Renci.SshNet.Tests.Common;
-
-namespace Renci.SshNet.Tests.Classes
-{
-    [TestClass]
-    public class SshCommand_EndExecute : TestBase
-    {
-        private Mock<ISession> _sessionMock;
-        private Mock<IChannelSession> _channelSessionMock;
-        private string _commandText;
-        private Encoding _encoding;
-        private SshCommand _sshCommand;
-
-        protected override void OnInit()
-        {
-            base.OnInit();
-
-            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
-            _commandText = new Random().Next().ToString(CultureInfo.InvariantCulture);
-            _encoding = Encoding.UTF8;
-            _channelSessionMock = new Mock<IChannelSession>(MockBehavior.Strict);
-
-            _sshCommand = new SshCommand(_sessionMock.Object, _commandText, _encoding);
-        }
-
-        [TestMethod]
-        public void EndExecute_ChannelClosed_ShouldDisposeChannelSession()
-        {
-            var seq = new MockSequence();
-
-            _sessionMock.InSequence(seq).Setup(p => p.CreateChannelSession()).Returns(_channelSessionMock.Object);
-            _channelSessionMock.InSequence(seq).Setup(p => p.Open());
-            _channelSessionMock.InSequence(seq).Setup(p => p.SendExecRequest(_commandText))
-                .Returns(true)
-                .Raises(c => c.Closed += null, new ChannelEventArgs(5));
-            _channelSessionMock.InSequence(seq).Setup(p => p.Dispose());
-
-            var asyncResult = _sshCommand.BeginExecute();
-            _sshCommand.EndExecute(asyncResult);
-
-            _channelSessionMock.Verify(p => p.Dispose(), Times.Once);
-        }
-
-        [TestMethod]
-        public void EndExecute_ChannelOpen_ShouldSendEofAndCloseAndDisposeChannelSession()
-        {
-            var seq = new MockSequence();
-
-            _sessionMock.InSequence(seq).Setup(p => p.CreateChannelSession()).Returns(_channelSessionMock.Object);
-            _channelSessionMock.InSequence(seq).Setup(p => p.Open());
-            _channelSessionMock.InSequence(seq).Setup(p => p.SendExecRequest(_commandText))
-                .Returns(true)
-                .Raises(c => c.Closed += null, new ChannelEventArgs(5));
-            _channelSessionMock.InSequence(seq).Setup(p => p.Dispose());
-
-            var asyncResult = _sshCommand.BeginExecute();
-            _sshCommand.EndExecute(asyncResult);
-
-            _channelSessionMock.Verify(p => p.Dispose(), Times.Once);
-        }
-    }
-}

+ 1 - 1
test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultIsNull.cs

@@ -30,7 +30,7 @@ namespace Renci.SshNet.Tests.Classes
 
         private void Arrange()
         {
-            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>();
             _commandText = new Random().Next().ToString(CultureInfo.InvariantCulture);
             _encoding = Encoding.UTF8;
             _asyncResult = null;

+ 0 - 6
test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs

@@ -81,12 +81,6 @@ namespace Renci.SshNet.Tests.Classes
             _actual = _sshCommand.EndExecute(_asyncResult);
         }
 
-        [TestMethod]
-        public void ChannelSessionShouldBeDisposedOnce()
-        {
-            _channelSessionMock.Verify(p => p.Dispose(), Times.Once);
-        }
-
         [TestMethod]
         public void EndExecuteShouldReturnAllDataReceivedInSpecifiedEncoding()
         {