Browse Source

Override more methods in PipeStream, ShellStream (#1637)

Where beneficial, add additional overrides from the base Stream class. Namely the Span
variants and for PipeStream, the WriteAsync variants (see comments).

The change also adds an internal type borrowed from the runtime repo for easier buffer
management, which could also be used elsewhere.
Rob Hague 5 months ago
parent
commit
6039e121a8

+ 1 - 0
Directory.Build.props

@@ -9,6 +9,7 @@
     <SignAssembly>true</SignAssembly>
     <AssemblyOriginatorKeyFile>$(MSBuildThisFileDirectory)Renci.SshNet.snk</AssemblyOriginatorKeyFile>
     <GenerateDocumentationFile>true</GenerateDocumentationFile>
+    <EnablePackageValidation>true</EnablePackageValidation>
     <LangVersion>latest</LangVersion>
     <WarningLevel>9999</WarningLevel>
   </PropertyGroup>

+ 10 - 2
src/Renci.SshNet/Abstractions/StreamExtensions.cs

@@ -1,4 +1,5 @@
 #if NETFRAMEWORK || NETSTANDARD2_0
+using System;
 using System.IO;
 using System.Threading.Tasks;
 
@@ -8,8 +9,15 @@ namespace Renci.SshNet.Abstractions
     {
         public static ValueTask DisposeAsync(this Stream stream)
         {
-            stream.Dispose();
-            return default;
+            try
+            {
+                stream.Dispose();
+                return default;
+            }
+            catch (Exception exc)
+            {
+                return new ValueTask(Task.FromException(exc));
+            }
         }
     }
 }

+ 200 - 0
src/Renci.SshNet/Common/ArrayBuffer.cs

@@ -0,0 +1,200 @@
+#pragma warning disable
+// Copied verbatim from https://github.com/dotnet/runtime/blob/d2650b6ae7023a2d9d2c74c56116f1f18472ab04/src/libraries/Common/src/System/Net/ArrayBuffer.cs
+
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace System.Net
+{
+    // Warning: Mutable struct!
+    // The purpose of this struct is to simplify buffer management.
+    // It manages a sliding buffer where bytes can be added at the end and removed at the beginning.
+    // [ActiveSpan/Memory] contains the current buffer contents; these bytes will be preserved
+    // (copied, if necessary) on any call to EnsureAvailableBytes.
+    // [AvailableSpan/Memory] contains the available bytes past the end of the current content,
+    // and can be written to in order to add data to the end of the buffer.
+    // Commit(byteCount) will extend the ActiveSpan by [byteCount] bytes into the AvailableSpan.
+    // Discard(byteCount) will discard [byteCount] bytes as the beginning of the ActiveSpan.
+
+    [StructLayout(LayoutKind.Auto)]
+    internal struct ArrayBuffer : IDisposable
+    {
+#if NET
+        private static int ArrayMaxLength => Array.MaxLength;
+#else
+        private const int ArrayMaxLength = 0X7FFFFFC7;
+#endif
+
+        private readonly bool _usePool;
+        private byte[] _bytes;
+        private int _activeStart;
+        private int _availableStart;
+
+        // Invariants:
+        // 0 <= _activeStart <= _availableStart <= bytes.Length
+
+        public ArrayBuffer(int initialSize, bool usePool = false)
+        {
+            Debug.Assert(initialSize > 0 || usePool);
+
+            _usePool = usePool;
+            _bytes = initialSize == 0
+                ? Array.Empty<byte>()
+                : usePool ? ArrayPool<byte>.Shared.Rent(initialSize) : new byte[initialSize];
+            _activeStart = 0;
+            _availableStart = 0;
+        }
+
+        public ArrayBuffer(byte[] buffer)
+        {
+            Debug.Assert(buffer.Length > 0);
+
+            _usePool = false;
+            _bytes = buffer;
+            _activeStart = 0;
+            _availableStart = 0;
+        }
+
+        public void Dispose()
+        {
+            _activeStart = 0;
+            _availableStart = 0;
+
+            byte[] array = _bytes;
+            _bytes = null!;
+
+            if (array is not null)
+            {
+                ReturnBufferIfPooled(array);
+            }
+        }
+
+        // This is different from Dispose as the instance remains usable afterwards (_bytes will not be null).
+        public void ClearAndReturnBuffer()
+        {
+            Debug.Assert(_usePool);
+            Debug.Assert(_bytes is not null);
+
+            _activeStart = 0;
+            _availableStart = 0;
+
+            byte[] bufferToReturn = _bytes;
+            _bytes = Array.Empty<byte>();
+            ReturnBufferIfPooled(bufferToReturn);
+        }
+
+        public int ActiveLength => _availableStart - _activeStart;
+        public Span<byte> ActiveSpan => new Span<byte>(_bytes, _activeStart, _availableStart - _activeStart);
+        public ReadOnlySpan<byte> ActiveReadOnlySpan => new ReadOnlySpan<byte>(_bytes, _activeStart, _availableStart - _activeStart);
+        public Memory<byte> ActiveMemory => new Memory<byte>(_bytes, _activeStart, _availableStart - _activeStart);
+
+        public int AvailableLength => _bytes.Length - _availableStart;
+        public Span<byte> AvailableSpan => _bytes.AsSpan(_availableStart);
+        public Memory<byte> AvailableMemory => _bytes.AsMemory(_availableStart);
+        public Memory<byte> AvailableMemorySliced(int length) => new Memory<byte>(_bytes, _availableStart, length);
+
+        public int Capacity => _bytes.Length;
+        public int ActiveStartOffset => _activeStart;
+
+        public byte[] DangerousGetUnderlyingBuffer() => _bytes;
+
+        public void Discard(int byteCount)
+        {
+            Debug.Assert(byteCount <= ActiveLength, $"Expected {byteCount} <= {ActiveLength}");
+            _activeStart += byteCount;
+
+            if (_activeStart == _availableStart)
+            {
+                _activeStart = 0;
+                _availableStart = 0;
+            }
+        }
+
+        public void Commit(int byteCount)
+        {
+            Debug.Assert(byteCount <= AvailableLength);
+            _availableStart += byteCount;
+        }
+
+        // Ensure at least [byteCount] bytes to write to.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public void EnsureAvailableSpace(int byteCount)
+        {
+            if (byteCount > AvailableLength)
+            {
+                EnsureAvailableSpaceCore(byteCount);
+            }
+        }
+
+        private void EnsureAvailableSpaceCore(int byteCount)
+        {
+            Debug.Assert(AvailableLength < byteCount);
+
+            if (_bytes.Length == 0)
+            {
+                Debug.Assert(_usePool && _activeStart == 0 && _availableStart == 0);
+                _bytes = ArrayPool<byte>.Shared.Rent(byteCount);
+                return;
+            }
+
+            int totalFree = _activeStart + AvailableLength;
+            if (byteCount <= totalFree)
+            {
+                // We can free up enough space by just shifting the bytes down, so do so.
+                Buffer.BlockCopy(_bytes, _activeStart, _bytes, 0, ActiveLength);
+                _availableStart = ActiveLength;
+                _activeStart = 0;
+                Debug.Assert(byteCount <= AvailableLength);
+                return;
+            }
+
+            int desiredSize = ActiveLength + byteCount;
+
+            if ((uint)desiredSize > ArrayMaxLength)
+            {
+                throw new OutOfMemoryException();
+            }
+
+            // Double the existing buffer size (capped at Array.MaxLength).
+            int newSize = Math.Max(desiredSize, (int)Math.Min(ArrayMaxLength, 2 * (uint)_bytes.Length));
+
+            byte[] newBytes = _usePool ?
+                ArrayPool<byte>.Shared.Rent(newSize) :
+                new byte[newSize];
+            byte[] oldBytes = _bytes;
+
+            if (ActiveLength != 0)
+            {
+                Buffer.BlockCopy(oldBytes, _activeStart, newBytes, 0, ActiveLength);
+            }
+
+            _availableStart = ActiveLength;
+            _activeStart = 0;
+
+            _bytes = newBytes;
+            ReturnBufferIfPooled(oldBytes);
+
+            Debug.Assert(byteCount <= AvailableLength);
+        }
+
+        public void Grow()
+        {
+            EnsureAvailableSpaceCore(AvailableLength + 1);
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private void ReturnBufferIfPooled(byte[] buffer)
+        {
+            // The buffer may be Array.Empty<byte>()
+            if (_usePool && buffer.Length > 0)
+            {
+                ArrayPool<byte>.Shared.Return(buffer);
+            }
+        }
+    }
+}

+ 142 - 71
src/Renci.SshNet/Common/PipeStream.cs

@@ -3,6 +3,7 @@ using System;
 using System.Diagnostics;
 using System.IO;
 using System.Threading;
+using System.Threading.Tasks;
 
 namespace Renci.SshNet.Common
 {
@@ -14,31 +15,9 @@ namespace Renci.SshNet.Common
     {
         private readonly object _sync = new object();
 
-        private byte[] _buffer = new byte[1024];
-        private int _head; // The index from which the data starts in _buffer.
-        private int _tail; // The index at which to add new data into _buffer.
+        private System.Net.ArrayBuffer _buffer = new(1024);
         private bool _disposed;
 
-#pragma warning disable MA0076 // Do not use implicit culture-sensitive ToString in interpolated strings
-        [Conditional("DEBUG")]
-        private void AssertValid()
-        {
-            Debug.Assert(Monitor.IsEntered(_sync), $"Should be in lock on {nameof(_sync)}");
-            Debug.Assert(_head >= 0, $"{nameof(_head)} should be non-negative but is {_head}");
-            Debug.Assert(_tail >= 0, $"{nameof(_tail)} should be non-negative but is {_tail}");
-            Debug.Assert(_head <= _buffer.Length, $"{nameof(_head)} should be <= {nameof(_buffer)}.Length but is {_head}");
-            Debug.Assert(_tail <= _buffer.Length, $"{nameof(_tail)} should be <= {nameof(_buffer)}.Length but is {_tail}");
-            Debug.Assert(_head <= _tail, $"Should have {nameof(_head)} <= {nameof(_tail)} but have {_head} <= {_tail}");
-        }
-#pragma warning restore MA0076 // Do not use implicit culture-sensitive ToString in interpolated strings
-
-        /// <summary>
-        /// This method does nothing.
-        /// </summary>
-        public override void Flush()
-        {
-        }
-
         /// <summary>
         /// This method always throws <see cref="NotSupportedException"/>.
         /// </summary>
@@ -69,27 +48,43 @@ namespace Renci.SshNet.Common
 #endif
             ValidateBufferArguments(buffer, offset, count);
 
+            return Read(buffer.AsSpan(offset, count));
+        }
+
+#if NETSTANDARD2_1 || NET
+        /// <inheritdoc/>
+        public override int Read(Span<byte> buffer)
+#else
+        private int Read(Span<byte> buffer)
+#endif
+        {
             lock (_sync)
             {
-                while (_head == _tail && !_disposed)
+                while (_buffer.ActiveLength == 0 && !_disposed)
                 {
                     _ = Monitor.Wait(_sync);
                 }
 
-                AssertValid();
+                var bytesRead = Math.Min(buffer.Length, _buffer.ActiveLength);
 
-                var bytesRead = Math.Min(count, _tail - _head);
+                _buffer.ActiveReadOnlySpan.Slice(0, bytesRead).CopyTo(buffer);
 
-                Buffer.BlockCopy(_buffer, _head, buffer, offset, bytesRead);
-
-                _head += bytesRead;
-
-                AssertValid();
+                _buffer.Discard(bytesRead);
 
                 return bytesRead;
             }
         }
 
+#if NET
+        /// <inheritdoc/>
+        public override int ReadByte()
+        {
+            byte b = default;
+            var read = Read(new Span<byte>(ref b));
+            return read == 0 ? -1 : b;
+        }
+#endif
+
         /// <inheritdoc/>
         public override void Write(byte[] buffer, int offset, int count)
         {
@@ -100,50 +95,127 @@ namespace Renci.SshNet.Common
 
             lock (_sync)
             {
-                ThrowHelper.ThrowObjectDisposedIf(_disposed, this);
+                WriteCore(buffer.AsSpan(offset, count));
+            }
+        }
+
+#if NETSTANDARD2_1 || NET
+        /// <inheritdoc/>
+        public override void Write(ReadOnlySpan<byte> buffer)
+        {
+            lock (_sync)
+            {
+                WriteCore(buffer);
+            }
+        }
+#endif
 
-                AssertValid();
+        /// <inheritdoc/>
+        public override void WriteByte(byte value)
+        {
+            lock (_sync)
+            {
+                WriteCore([value]);
+            }
+        }
 
-                // Ensure sufficient buffer space and copy the new data in.
+        private void WriteCore(ReadOnlySpan<byte> buffer)
+        {
+            Debug.Assert(Monitor.IsEntered(_sync));
 
-                if (_buffer.Length - _tail >= count)
-                {
-                    // If there is enough space after _tail for the new data,
-                    // then copy the data there.
-                    Buffer.BlockCopy(buffer, offset, _buffer, _tail, count);
-                    _tail += count;
-                }
-                else
-                {
-                    // We can't fit the new data after _tail.
-
-                    var newLength = _tail - _head + count;
-
-                    if (newLength <= _buffer.Length)
-                    {
-                        // If there is sufficient space at the start of the buffer,
-                        // then move the current data to the start of the buffer.
-                        Buffer.BlockCopy(_buffer, _head, _buffer, 0, _tail - _head);
-                    }
-                    else
-                    {
-                        // Otherwise, we're gonna need a bigger buffer.
-                        var newBuffer = new byte[Math.Max(newLength, _buffer.Length * 2)];
-                        Buffer.BlockCopy(_buffer, _head, newBuffer, 0, _tail - _head);
-                        _buffer = newBuffer;
-                    }
-
-                    // Copy the new data into the freed-up space.
-                    Buffer.BlockCopy(buffer, offset, _buffer, _tail - _head, count);
-
-                    _head = 0;
-                    _tail = newLength;
-                }
+            ThrowHelper.ThrowObjectDisposedIf(_disposed, this);
 
-                AssertValid();
+            _buffer.EnsureAvailableSpace(buffer.Length);
 
-                Monitor.PulseAll(_sync);
+            buffer.CopyTo(_buffer.AvailableSpan);
+
+            _buffer.Commit(buffer.Length);
+
+            Monitor.PulseAll(_sync);
+        }
+
+        // We provide overrides for async Write methods but not async Read.
+        // The default implementations from the base class effectively call the
+        // sync methods on a threadpool thread, but only allowing one async
+        // operation at a time (for protecting thread-unsafe implementations).
+        // This constraint is desirable for reads because if there were multiple
+        // readers and no data coming in, our current Monitor.Wait implementation
+        // would just block as many threadpool threads as there are readers.
+        // But since a write is just short-lived buffer copying and can unblock
+        // readers, it is beneficial to circumvent the one-at-a-time constraint,
+        // as otherwise a waiting async read will block the async write that could
+        // unblock it.
+
+        /// <inheritdoc/>
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+#if !NET
+            ThrowHelper.
+#endif
+            ValidateBufferArguments(buffer, offset, count);
+
+            return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+        }
+
+#if NETSTANDARD2_1 || NET
+        /// <inheritdoc/>
+        public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+#else
+        private async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
+#endif
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+
+            if (!Monitor.TryEnter(_sync))
+            {
+                // If we cannot immediately enter the lock and complete the write
+                // synchronously, then go async and wait for it there.
+                // This is not great! But since there is very little work being
+                // done under the lock, this should be a rare case and we should
+                // not be blocking threads for long.
+
+                await Task.Yield();
+
+                Monitor.Enter(_sync);
             }
+
+            try
+            {
+                WriteCore(buffer.Span);
+            }
+            finally
+            {
+                Monitor.Exit(_sync);
+            }
+        }
+
+        /// <inheritdoc/>
+        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
+        {
+            return TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state);
+        }
+
+        /// <inheritdoc/>
+        public override void EndWrite(IAsyncResult asyncResult)
+        {
+            TaskToAsyncResult.End(asyncResult);
+        }
+
+        /// <summary>
+        /// This method does nothing.
+        /// </summary>
+        public override void Flush()
+        {
+        }
+
+        /// <summary>
+        /// This method does nothing.
+        /// </summary>
+        /// <param name="cancellationToken">Unobserved cancellation token.</param>
+        /// <returns><see cref="Task.CompletedTask"/>.</returns>
+        public override Task FlushAsync(CancellationToken cancellationToken)
+        {
+            return Task.CompletedTask;
         }
 
         /// <inheritdoc/>
@@ -221,8 +293,7 @@ namespace Renci.SshNet.Common
             {
                 lock (_sync)
                 {
-                    AssertValid();
-                    return _tail - _head;
+                    return _buffer.ActiveLength;
                 }
             }
         }

+ 103 - 131
src/Renci.SshNet/ShellStream.cs

@@ -2,7 +2,6 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Globalization;
 using System.IO;
 using System.Text;
 using System.Text.RegularExpressions;
@@ -27,16 +26,13 @@ namespace Renci.SshNet
         private readonly IChannelSession _channel;
         private readonly byte[] _carriageReturnBytes;
         private readonly byte[] _lineFeedBytes;
+        private readonly bool _noTerminal;
 
         private readonly object _sync = new object();
 
-        private readonly byte[] _writeBuffer;
-        private readonly bool _noTerminal;
-        private int _writeLength; // The length of the data in _writeBuffer.
+        private System.Net.ArrayBuffer _readBuffer;
+        private System.Net.ArrayBuffer _writeBuffer;
 
-        private byte[] _readBuffer;
-        private int _readHead; // The index from which the data starts in _readBuffer.
-        private int _readTail; // The index at which to add new data into _readBuffer.
         private bool _disposed;
 
         /// <summary>
@@ -66,23 +62,11 @@ namespace Renci.SshNet
             {
                 lock (_sync)
                 {
-                    AssertValid();
-                    return _readTail != _readHead;
+                    return _readBuffer.ActiveLength > 0;
                 }
             }
         }
 
-        [Conditional("DEBUG")]
-        private void AssertValid()
-        {
-            Debug.Assert(Monitor.IsEntered(_sync), $"Should be in lock on {nameof(_sync)}");
-            Debug.Assert(_readHead >= 0, $"{nameof(_readHead)} should be non-negative but is {_readHead.ToString(CultureInfo.InvariantCulture)}");
-            Debug.Assert(_readTail >= 0, $"{nameof(_readTail)} should be non-negative but is {_readTail.ToString(CultureInfo.InvariantCulture)}");
-            Debug.Assert(_readHead <= _readBuffer.Length, $"{nameof(_readHead)} should be <= {nameof(_readBuffer)}.Length but is {_readHead.ToString(CultureInfo.InvariantCulture)}");
-            Debug.Assert(_readTail <= _readBuffer.Length, $"{nameof(_readTail)} should be <= {nameof(_readBuffer)}.Length but is {_readTail.ToString(CultureInfo.InvariantCulture)}");
-            Debug.Assert(_readHead <= _readTail, $"Should have {nameof(_readHead)} <= {nameof(_readTail)} but have {_readHead.ToString(CultureInfo.InvariantCulture)} <= {_readTail.ToString(CultureInfo.InvariantCulture)}");
-        }
-
         /// <summary>
         /// Initializes a new instance of the <see cref="ShellStream"/> class.
         /// </summary>
@@ -180,8 +164,8 @@ namespace Renci.SshNet
             _session.Disconnected += Session_Disconnected;
             _session.ErrorOccured += Session_ErrorOccured;
 
-            _readBuffer = new byte[bufferSize];
-            _writeBuffer = new byte[bufferSize];
+            _readBuffer = new System.Net.ArrayBuffer(bufferSize);
+            _writeBuffer = new System.Net.ArrayBuffer(bufferSize);
 
             _noTerminal = noTerminal;
         }
@@ -233,12 +217,14 @@ namespace Renci.SshNet
         {
             ThrowHelper.ThrowObjectDisposedIf(_disposed, this);
 
-            Debug.Assert(_writeLength >= 0 && _writeLength <= _writeBuffer.Length);
-
-            if (_writeLength > 0)
+            if (_writeBuffer.ActiveLength > 0)
             {
-                _channel.SendData(_writeBuffer, 0, _writeLength);
-                _writeLength = 0;
+                _channel.SendData(
+                    _writeBuffer.DangerousGetUnderlyingBuffer(),
+                    _writeBuffer.ActiveStartOffset,
+                    _writeBuffer.ActiveLength);
+
+                _writeBuffer.Discard(_writeBuffer.ActiveLength);
             }
         }
 
@@ -252,8 +238,7 @@ namespace Renci.SshNet
             {
                 lock (_sync)
                 {
-                    AssertValid();
-                    return _readTail - _readHead;
+                    return _readBuffer.ActiveLength;
                 }
             }
         }
@@ -385,23 +370,19 @@ namespace Renci.SshNet
             {
                 while (true)
                 {
-                    AssertValid();
-
                     var searchHead = lookback == -1
-                        ? _readHead
-                        : Math.Max(_readTail - lookback, _readHead);
-
-                    Debug.Assert(_readHead <= searchHead && searchHead <= _readTail);
+                        ? 0
+                        : Math.Max(0, _readBuffer.ActiveLength - lookback);
 
-                    var indexOfMatch = _readBuffer.AsSpan(searchHead, _readTail - searchHead).IndexOf(expectBytes);
+                    var indexOfMatch = _readBuffer.ActiveReadOnlySpan.Slice(searchHead).IndexOf(expectBytes);
 
                     if (indexOfMatch >= 0)
                     {
-                        var returnText = _encoding.GetString(_readBuffer, _readHead, searchHead - _readHead + indexOfMatch + expectBytes.Length);
+                        var readLength = searchHead + indexOfMatch + expectBytes.Length;
 
-                        _readHead = searchHead + indexOfMatch + expectBytes.Length;
+                        var returnText = GetString(readLength);
 
-                        AssertValid();
+                        _readBuffer.Discard(readLength);
 
                         return returnText;
                     }
@@ -471,9 +452,7 @@ namespace Renci.SshNet
             {
                 while (true)
                 {
-                    AssertValid();
-
-                    var bufferText = _encoding.GetString(_readBuffer, _readHead, _readTail - _readHead);
+                    var bufferText = GetString(_readBuffer.ActiveLength);
 
                     var searchStart = lookback == -1
                         ? 0
@@ -496,9 +475,7 @@ namespace Renci.SshNet
                         {
                             var returnText = bufferText.Substring(0, match.Index + match.Length);
 #endif
-                            _readHead += _encoding.GetByteCount(returnText);
-
-                            AssertValid();
+                            _readBuffer.Discard(_encoding.GetByteCount(returnText));
 
                             expectAction.Action(returnText);
 
@@ -659,48 +636,40 @@ namespace Renci.SshNet
             {
                 while (true)
                 {
-                    AssertValid();
-
-                    var indexOfCr = _readBuffer.AsSpan(_readHead, _readTail - _readHead).IndexOf(_carriageReturnBytes);
+                    var indexOfCr = _readBuffer.ActiveReadOnlySpan.IndexOf(_carriageReturnBytes);
 
                     if (indexOfCr >= 0)
                     {
                         // We have found \r. We only need to search for \n up to and just after the \r
                         // (in order to consume \r\n if we can).
-                        var indexOfLf = indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length <= _readTail - _readHead
-                            ? _readBuffer.AsSpan(_readHead, indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length).IndexOf(_lineFeedBytes)
-                            : _readBuffer.AsSpan(_readHead, indexOfCr).IndexOf(_lineFeedBytes);
+                        var indexOfLf = indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length <= _readBuffer.ActiveLength
+                            ? _readBuffer.ActiveReadOnlySpan.Slice(0, indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length).IndexOf(_lineFeedBytes)
+                            : _readBuffer.ActiveReadOnlySpan.Slice(0, indexOfCr).IndexOf(_lineFeedBytes);
 
                         if (indexOfLf >= 0 && indexOfLf < indexOfCr)
                         {
                             // If there is \n before the \r, then return up to the \n
-                            var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfLf);
-
-                            _readHead += indexOfLf + _lineFeedBytes.Length;
+                            var returnText = GetString(indexOfLf);
 
-                            AssertValid();
+                            _readBuffer.Discard(indexOfLf + _lineFeedBytes.Length);
 
                             return returnText;
                         }
                         else if (indexOfLf == indexOfCr + _carriageReturnBytes.Length)
                         {
                             // If we have \r\n, then consume both
-                            var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfCr);
+                            var returnText = GetString(indexOfCr);
 
-                            _readHead += indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length;
-
-                            AssertValid();
+                            _readBuffer.Discard(indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length);
 
                             return returnText;
                         }
                         else
                         {
                             // Return up to the \r
-                            var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfCr);
-
-                            _readHead += indexOfCr + _carriageReturnBytes.Length;
+                            var returnText = GetString(indexOfCr);
 
-                            AssertValid();
+                            _readBuffer.Discard(indexOfCr + _carriageReturnBytes.Length);
 
                             return returnText;
                         }
@@ -708,15 +677,13 @@ namespace Renci.SshNet
                     else
                     {
                         // There is no \r. What about \n?
-                        var indexOfLf = _readBuffer.AsSpan(_readHead, _readTail - _readHead).IndexOf(_lineFeedBytes);
+                        var indexOfLf = _readBuffer.ActiveReadOnlySpan.IndexOf(_lineFeedBytes);
 
                         if (indexOfLf >= 0)
                         {
-                            var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfLf);
+                            var returnText = GetString(indexOfLf);
 
-                            _readHead += indexOfLf + _lineFeedBytes.Length;
-
-                            AssertValid();
+                            _readBuffer.Discard(indexOfLf + _lineFeedBytes.Length);
 
                             return returnText;
                         }
@@ -724,11 +691,11 @@ namespace Renci.SshNet
 
                     if (_disposed)
                     {
-                        var lastLine = _readHead == _readTail
+                        var lastLine = _readBuffer.ActiveLength == 0
                             ? null
-                            : _encoding.GetString(_readBuffer, _readHead, _readTail - _readHead);
+                            : GetString(_readBuffer.ActiveLength);
 
-                        _readHead = _readTail = 0;
+                        _readBuffer.Discard(_readBuffer.ActiveLength);
 
                         return lastLine;
                     }
@@ -776,11 +743,9 @@ namespace Renci.SshNet
         {
             lock (_sync)
             {
-                AssertValid();
-
-                var text = _encoding.GetString(_readBuffer, _readHead, _readTail - _readHead);
+                var text = GetString(_readBuffer.ActiveLength);
 
-                _readHead = _readTail = 0;
+                _readBuffer.Discard(_readBuffer.ActiveLength);
 
                 return text;
             }
@@ -794,27 +759,54 @@ namespace Renci.SshNet
 #endif
             ValidateBufferArguments(buffer, offset, count);
 
+            return Read(buffer.AsSpan(offset, count));
+        }
+
+#if NETSTANDARD2_1 || NET
+        /// <inheritdoc/>
+        public override int Read(Span<byte> buffer)
+#else
+        private int Read(Span<byte> buffer)
+#endif
+        {
             lock (_sync)
             {
-                while (_readHead == _readTail && !_disposed)
+                while (_readBuffer.ActiveLength == 0 && !_disposed)
                 {
                     _ = Monitor.Wait(_sync);
                 }
 
-                AssertValid();
+                var bytesRead = Math.Min(buffer.Length, _readBuffer.ActiveLength);
 
-                var bytesRead = Math.Min(count, _readTail - _readHead);
+                _readBuffer.ActiveReadOnlySpan.Slice(0, bytesRead).CopyTo(buffer);
 
-                Buffer.BlockCopy(_readBuffer, _readHead, buffer, offset, bytesRead);
-
-                _readHead += bytesRead;
-
-                AssertValid();
+                _readBuffer.Discard(bytesRead);
 
                 return bytesRead;
             }
         }
 
+#if NET
+        /// <inheritdoc/>
+        public override int ReadByte()
+        {
+            byte b = default;
+            var read = Read(new Span<byte>(ref b));
+            return read == 0 ? -1 : b;
+        }
+#endif
+
+        private string GetString(int length)
+        {
+            Debug.Assert(Monitor.IsEntered(_sync));
+            Debug.Assert(length <= _readBuffer.ActiveLength);
+
+            return _encoding.GetString(
+                _readBuffer.DangerousGetUnderlyingBuffer(),
+                _readBuffer.ActiveStartOffset,
+                length);
+        }
+
         /// <summary>
         /// Writes the specified text to the shell.
         /// </summary>
@@ -831,9 +823,7 @@ namespace Renci.SshNet
                 return;
             }
 
-            var data = _encoding.GetBytes(text);
-
-            Write(data, 0, data.Length);
+            Write(_encoding.GetBytes(text));
             Flush();
         }
 
@@ -845,27 +835,43 @@ namespace Renci.SshNet
 #endif
             ValidateBufferArguments(buffer, offset, count);
 
+            Write(buffer.AsSpan(offset, count));
+        }
+
+#if NETSTANDARD2_1 || NET
+        /// <inheritdoc/>
+        public override void Write(ReadOnlySpan<byte> buffer)
+#else
+        private void Write(ReadOnlySpan<byte> buffer)
+#endif
+        {
             ThrowHelper.ThrowObjectDisposedIf(_disposed, this);
 
-            while (count > 0)
+            while (!buffer.IsEmpty)
             {
-                if (_writeLength == _writeBuffer.Length)
+                if (_writeBuffer.AvailableLength == 0)
                 {
                     Flush();
                 }
 
-                var bytesToCopy = Math.Min(count, _writeBuffer.Length - _writeLength);
+                var bytesToCopy = Math.Min(buffer.Length, _writeBuffer.AvailableLength);
 
-                Buffer.BlockCopy(buffer, offset, _writeBuffer, _writeLength, bytesToCopy);
+                Debug.Assert(bytesToCopy > 0);
 
-                offset += bytesToCopy;
-                count -= bytesToCopy;
-                _writeLength += bytesToCopy;
+                buffer.Slice(0, bytesToCopy).CopyTo(_writeBuffer.AvailableSpan);
 
-                Debug.Assert(_writeLength >= 0 && _writeLength <= _writeBuffer.Length);
+                _writeBuffer.Commit(bytesToCopy);
+
+                buffer = buffer.Slice(bytesToCopy);
             }
         }
 
+        /// <inheritdoc/>
+        public override void WriteByte(byte value)
+        {
+            Write([value]);
+        }
+
         /// <summary>
         /// Writes the line to the shell.
         /// </summary>
@@ -940,45 +946,11 @@ namespace Renci.SshNet
         {
             lock (_sync)
             {
-                AssertValid();
-
-                // Ensure sufficient buffer space and copy the new data in.
-
-                if (_readBuffer.Length - _readTail >= e.Data.Length)
-                {
-                    // If there is enough space after _tail for the new data,
-                    // then copy the data there.
-                    Buffer.BlockCopy(e.Data, 0, _readBuffer, _readTail, e.Data.Length);
-                    _readTail += e.Data.Length;
-                }
-                else
-                {
-                    // We can't fit the new data after _tail.
+                _readBuffer.EnsureAvailableSpace(e.Data.Length);
 
-                    var newLength = _readTail - _readHead + e.Data.Length;
-
-                    if (newLength <= _readBuffer.Length)
-                    {
-                        // If there is sufficient space at the start of the buffer,
-                        // then move the current data to the start of the buffer.
-                        Buffer.BlockCopy(_readBuffer, _readHead, _readBuffer, 0, _readTail - _readHead);
-                    }
-                    else
-                    {
-                        // Otherwise, we're gonna need a bigger buffer.
-                        var newBuffer = new byte[Math.Max(newLength, _readBuffer.Length * 2)];
-                        Buffer.BlockCopy(_readBuffer, _readHead, newBuffer, 0, _readTail - _readHead);
-                        _readBuffer = newBuffer;
-                    }
-
-                    // Copy the new data into the freed-up space.
-                    Buffer.BlockCopy(e.Data, 0, _readBuffer, _readTail - _readHead, e.Data.Length);
-
-                    _readHead = 0;
-                    _readTail = newLength;
-                }
+                e.Data.AsSpan().CopyTo(_readBuffer.AvailableSpan);
 
-                AssertValid();
+                _readBuffer.Commit(e.Data.Length);
 
                 Monitor.PulseAll(_sync);
             }

+ 54 - 36
test/Renci.SshNet.Tests/Classes/Common/PipeStreamTest.cs

@@ -16,7 +16,6 @@ namespace Renci.SshNet.Tests.Classes.Common
     public class PipeStreamTest : TestBase
     {
         [TestMethod]
-        [TestCategory("PipeStream")]
         public void Test_PipeStream_Write_Read_Buffer()
         {
             var testBuffer = new byte[1024];
@@ -39,7 +38,6 @@ namespace Renci.SshNet.Tests.Classes.Common
         }
 
         [TestMethod]
-        [TestCategory("PipeStream")]
         public void Test_PipeStream_Write_Read_Byte()
         {
             var testBuffer = new byte[1024];
@@ -133,14 +131,32 @@ namespace Renci.SshNet.Tests.Classes.Common
 
             Assert.IsFalse(readTask.IsCompleted);
 
-            // not using WriteAsync here because it deadlocks the test
-#pragma warning disable S6966 // Awaitable method should be used
-            pipeStream.Write(new byte[] { 1, 2, 3, 4 }, 0, 4);
-#pragma warning restore S6966 // Awaitable method should be used
+            await pipeStream.WriteAsync(new byte[] { 1, 2, 3, 4 }, 0, 4);
 
             Assert.AreEqual(0, await readTask);
         }
 
+#if NET
+        [TestMethod]
+        public async Task Read_EmptySpan_OnlyReturnsZeroWhenDataAvailable()
+        {
+            // And zero byte reads should block but then return 0 once data
+            // is available (the span version).
+
+            var pipeStream = new PipeStream();
+
+            ValueTask<int> readTask = pipeStream.ReadAsync(Memory<byte>.Empty);
+
+            await Task.Delay(50);
+
+            Assert.IsFalse(readTask.IsCompleted);
+
+            await pipeStream.WriteAsync(new byte[] { 1, 2, 3, 4 });
+
+            Assert.AreEqual(0, await readTask);
+        }
+#endif
+
         [TestMethod]
         public void Read_AfterDispose_StillWorks()
         {
@@ -153,6 +169,8 @@ namespace Renci.SshNet.Tests.Classes.Common
             pipeStream.Dispose(); // Check that multiple Dispose is OK.
 #pragma warning restore S3966 // Objects should not be disposed more than once
 
+            Assert.IsTrue(pipeStream.CanRead);
+
             Assert.AreEqual(4, pipeStream.Read(new byte[5], 0, 5));
             Assert.AreEqual(0, pipeStream.Read(new byte[5], 0, 5));
         }
@@ -160,34 +178,15 @@ namespace Renci.SshNet.Tests.Classes.Common
         [TestMethod]
         public void SeekShouldThrowNotSupportedException()
         {
-            const long offset = 0;
-            const SeekOrigin origin = new SeekOrigin();
             var target = new PipeStream();
-
-            try
-            {
-                _ = target.Seek(offset, origin);
-                Assert.Fail();
-            }
-            catch (NotSupportedException)
-            {
-            }
-
+            Assert.Throws<NotSupportedException>(() => target.Seek(offset: 0, SeekOrigin.Begin));
         }
 
         [TestMethod]
         public void SetLengthShouldThrowNotSupportedException()
         {
             var target = new PipeStream();
-
-            try
-            {
-                target.SetLength(1);
-                Assert.Fail();
-            }
-            catch (NotSupportedException)
-            {
-            }
+            Assert.Throws<NotSupportedException>(() => target.SetLength(1));
         }
 
         [TestMethod]
@@ -213,6 +212,31 @@ namespace Renci.SshNet.Tests.Classes.Common
             Assert.AreEqual(0x00, readBuffer[5]);
         }
 
+#if NET
+        [TestMethod]
+        public void WriteTest_Span()
+        {
+            var target = new PipeStream();
+
+            var writeBuffer = new byte[] { 0x0a, 0x05, 0x0d };
+            target.Write(writeBuffer.AsSpan(0, 2));
+
+            writeBuffer = new byte[] { 0x02, 0x04, 0x03, 0x06, 0x09 };
+            target.Write(writeBuffer.AsSpan(1, 2));
+
+            var readBuffer = new byte[6];
+            var bytesRead = target.Read(readBuffer.AsSpan(0, 4));
+
+            Assert.AreEqual(4, bytesRead);
+            Assert.AreEqual(0x0a, readBuffer[0]);
+            Assert.AreEqual(0x05, readBuffer[1]);
+            Assert.AreEqual(0x04, readBuffer[2]);
+            Assert.AreEqual(0x03, readBuffer[3]);
+            Assert.AreEqual(0x00, readBuffer[4]);
+            Assert.AreEqual(0x00, readBuffer[5]);
+        }
+#endif
+
         [TestMethod]
         public void CanReadTest()
         {
@@ -232,6 +256,8 @@ namespace Renci.SshNet.Tests.Classes.Common
         {
             var target = new PipeStream();
             Assert.IsTrue(target.CanWrite);
+            target.Dispose();
+            Assert.IsFalse(target.CanWrite);
         }
 
         [TestMethod]
@@ -265,15 +291,7 @@ namespace Renci.SshNet.Tests.Classes.Common
         public void Position_SetterAlwaysThrowsNotSupportedException()
         {
             var target = new PipeStream();
-
-            try
-            {
-                target.Position = 0;
-                Assert.Fail();
-            }
-            catch (NotSupportedException)
-            {
-            }
+            Assert.Throws<NotSupportedException>(() => target.Position = 0);
         }
     }
 }

+ 34 - 0
test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs

@@ -71,6 +71,23 @@ namespace Renci.SshNet.Tests.Classes
             CollectionAssert.AreEqual(Encoding.UTF8.GetBytes("orld!llo W\0\0"), buffer);
         }
 
+#if NET
+        [TestMethod]
+        public void Read_Bytes_Span()
+        {
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello "));
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes("World!"));
+
+            byte[] buffer = new byte[12];
+
+            Assert.AreEqual(7, _shellStream.Read(buffer.AsSpan(3, 7)));
+            CollectionAssert.AreEqual(Encoding.UTF8.GetBytes("\0\0\0Hello W\0\0"), buffer);
+
+            Assert.AreEqual(5, _shellStream.Read(buffer));
+            CollectionAssert.AreEqual(Encoding.UTF8.GetBytes("orld!llo W\0\0"), buffer);
+        }
+#endif
+
         [TestMethod]
         public void Channel_DataReceived_MoreThanBufferSize()
         {
@@ -172,6 +189,22 @@ namespace Renci.SshNet.Tests.Classes
             Assert.AreEqual(0, await readTask);
         }
 
+#if NET
+        [TestMethod]
+        public async Task Read_EmptySpan_OnlyReturnsZeroWhenDataAvailable()
+        {
+            ValueTask<int> readTask = _shellStream.ReadAsync(Memory<byte>.Empty);
+
+            await Task.Delay(50);
+
+            Assert.IsFalse(readTask.IsCompleted);
+
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello World!"));
+
+            Assert.AreEqual(0, await readTask);
+        }
+#endif
+
         [TestMethod]
         public void Expect()
         {
@@ -196,6 +229,7 @@ namespace Renci.SshNet.Tests.Classes
             _shellStream.Dispose(); // Check that multiple Dispose is OK.
 #pragma warning restore S3966 // Objects should not be disposed more than once
 
+            Assert.IsTrue(_shellStream.CanRead);
             Assert.AreEqual("Hello World!", _shellStream.ReadLine());
             Assert.IsNull(_shellStream.ReadLine());
         }