Explorar o código

Fix ShellStream when receiving data larger than buffer length (#1337)

Rob Hague hai 1 ano
pai
achega
69f145159a

+ 2 - 2
src/Renci.SshNet/ShellStream.cs

@@ -75,7 +75,7 @@ namespace Renci.SshNet
             Debug.Assert(Monitor.IsEntered(_sync), $"Should be in lock on {nameof(_sync)}");
             Debug.Assert(_readHead >= 0, $"{nameof(_readHead)} should be non-negative but is {_readHead}");
             Debug.Assert(_readTail >= 0, $"{nameof(_readTail)} should be non-negative but is {_readTail}");
-            Debug.Assert(_readHead < _readBuffer.Length || _readBuffer.Length == 0, $"{nameof(_readHead)} should be < {nameof(_readBuffer)}.Length but is {_readHead}");
+            Debug.Assert(_readHead <= _readBuffer.Length, $"{nameof(_readHead)} should be <= {nameof(_readBuffer)}.Length but is {_readHead}");
             Debug.Assert(_readTail <= _readBuffer.Length, $"{nameof(_readTail)} should be <= {nameof(_readBuffer)}.Length but is {_readTail}");
             Debug.Assert(_readHead <= _readTail, $"Should have {nameof(_readHead)} <= {nameof(_readTail)} but have {_readHead} <= {_readTail}");
         }
@@ -938,7 +938,7 @@ namespace Renci.SshNet
                     else
                     {
                         // Otherwise, we're gonna need a bigger buffer.
-                        var newBuffer = new byte[_readBuffer.Length * 2];
+                        var newBuffer = new byte[Math.Max(newLength, _readBuffer.Length * 2)];
                         Buffer.BlockCopy(_readBuffer, _readHead, newBuffer, 0, _readTail - _readHead);
                         _readBuffer = newBuffer;
                     }

+ 14 - 0
test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs

@@ -9,6 +9,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting;
 
 using Moq;
 
+using Renci.SshNet.Abstractions;
 using Renci.SshNet.Channels;
 using Renci.SshNet.Common;
 
@@ -70,6 +71,19 @@ namespace Renci.SshNet.Tests.Classes
             CollectionAssert.AreEqual(Encoding.UTF8.GetBytes("orld!llo W\0\0"), buffer);
         }
 
+        [TestMethod]
+        public void Channel_DataReceived_MoreThanBufferSize()
+        {
+            // Test buffer resizing
+            byte[] expectedData = CryptoAbstraction.GenerateRandom(BufferSize * 3);
+            _channelSessionStub.Receive(expectedData);
+
+            byte[] actualData = new byte[expectedData.Length + 1];
+
+            Assert.AreEqual(expectedData.Length, _shellStream.Read(actualData, 0, actualData.Length));
+            CollectionAssert.AreEqual(expectedData, actualData.Take(expectedData.Length));
+        }
+
         [DataTestMethod]
         [DataRow("\r\n")]
         [DataRow("\r")]