Selaa lähdekoodia

Override WriteAsync in ShellStream (#1711)

ShellStream does not currently override the Read/Write async variants. They fall back to
the base class implementations which run the sync variants on a thread pool thread, only
allowing one call of either at a time in order to protect implementations that would
break if Read/Write were called simultaneously. In ShellStream, reads and writes are
independent so mutually excluding their use is unnecessary and can lead to effective
deadlocks.

We therefore override WriteAsync to get around this restriction. We do not override
ReadAsync because the sync implementation does not lend itself well to async given the
use of Monitor.Wait/Pulse. Note that while reading and writing simultaneously is allowed,
it is not intended that ShellStream is used with multiple simultaneous reads or multiple
simultaneous writes, so it is fine to keep the base one-at-a-time implementation on
ReadAsync.

Another note is that the new WriteAsync will be simple (synchronous) buffer copying in
most cases, with a call to FlushAsync in others. We also do not override FlushAsync, so
that will go onto a thread pool thread and potentially acquire some locks. But given that
the current base implementation of WriteAsync does that unconditionally, it makes the new
WriteAsync slightly better and certainly no worse than the current version.
Rob Hague 2 päivää sitten
vanhempi
sitoutus
58f61a2398

+ 51 - 0
src/Renci.SshNet/ShellStream.cs

@@ -891,6 +891,57 @@ namespace Renci.SshNet
             Write([value]);
         }
 
+        /// <inheritdoc/>
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+#if !NET
+            ThrowHelper.
+#endif
+            ValidateBufferArguments(buffer, offset, count);
+
+            return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+        }
+
+#if NET
+        /// <inheritdoc/>
+        public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+#else
+        private async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
+#endif
+        {
+            ThrowHelper.ThrowObjectDisposedIf(_disposed, this);
+
+            while (!buffer.IsEmpty)
+            {
+                if (_writeBuffer.AvailableLength == 0)
+                {
+                    await FlushAsync(cancellationToken).ConfigureAwait(false);
+                }
+
+                var bytesToCopy = Math.Min(buffer.Length, _writeBuffer.AvailableLength);
+
+                Debug.Assert(bytesToCopy > 0);
+
+                buffer.Slice(0, bytesToCopy).CopyTo(_writeBuffer.AvailableMemory);
+
+                _writeBuffer.Commit(bytesToCopy);
+
+                buffer = buffer.Slice(bytesToCopy);
+            }
+        }
+
+        /// <inheritdoc/>
+        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
+        {
+            return TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state);
+        }
+
+        /// <inheritdoc/>
+        public override void EndWrite(IAsyncResult asyncResult)
+        {
+            TaskToAsyncResult.End(asyncResult);
+        }
+
         /// <summary>
         /// Writes the line to the shell.
         /// </summary>

+ 15 - 0
test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs

@@ -236,6 +236,21 @@ namespace Renci.SshNet.Tests.Classes
             Assert.IsNull(_shellStream.ReadLine());
         }
 
+        [TestMethod]
+        public async Task ReadAsyncDoesNotBlockWriteAsync()
+        {
+            byte[] buffer = new byte[16];
+            Task<int> readTask = _shellStream.ReadAsync(buffer, 0, buffer.Length);
+
+            await _shellStream.WriteAsync("ls\n"u8.ToArray(), 0, 3);
+
+            Assert.IsFalse(readTask.IsCompleted);
+
+            _channelSessionStub.Receive("Directory.Build.props"u8.ToArray());
+
+            await readTask;
+        }
+
         [TestMethod]
         public void Read_MultiByte()
         {