Przeglądaj źródła

Fixed dequeuing of incoming queue (#1319)

* Fixed dequeuing of incoming queue.

* Adjusted return of Expect to make sure it returns the full incoming queue.
Jean-Sebastien Carle 1 rok temu
rodzic
commit
d07827b6c8

+ 52 - 43
src/Renci.SshNet/ShellStream.cs

@@ -282,19 +282,14 @@ namespace Renci.SshNet
 
                             if (match.Success)
                             {
-                                var returnText = matchText.Substring(0, match.Index + match.Length);
-                                var returnLength = _encoding.GetByteCount(returnText);
+#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
+                                var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
+#else
+                                var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
+#endif
 
                                 // Remove processed items from the queue
-                                for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
-                                {
-                                    if (_expect.Count == _incoming.Count)
-                                    {
-                                        _ = _expect.Dequeue();
-                                    }
-
-                                    _ = _incoming.Dequeue();
-                                }
+                                var returnText = SyncQueuesAndReturn(returnLength);
 
                                 expectAction.Action(returnText);
                                 expectedFound = true;
@@ -385,19 +380,14 @@ namespace Renci.SshNet
 
                     if (match.Success)
                     {
-                        returnText = matchText.Substring(0, match.Index + match.Length);
-                        var returnLength = _encoding.GetByteCount(returnText);
+#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
+                        var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
+#else
+                        var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
+#endif
 
                         // Remove processed items from the queue
-                        for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
-                        {
-                            if (_expect.Count == _incoming.Count)
-                            {
-                                _ = _expect.Dequeue();
-                            }
-
-                            _ = _incoming.Dequeue();
-                        }
+                        returnText = SyncQueuesAndReturn(returnLength);
 
                         break;
                     }
@@ -501,19 +491,14 @@ namespace Renci.SshNet
 
                                     if (match.Success)
                                     {
-                                        returnText = matchText.Substring(0, match.Index + match.Length);
-                                        var returnLength = _encoding.GetByteCount(returnText);
+#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
+                                        var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
+#else
+                                        var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
+#endif
 
                                         // Remove processed items from the queue
-                                        for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
-                                        {
-                                            if (_expect.Count == _incoming.Count)
-                                            {
-                                                _ = _expect.Dequeue();
-                                            }
-
-                                            _ = _incoming.Dequeue();
-                                        }
+                                        returnText = SyncQueuesAndReturn(returnLength);
 
                                         expectAction.Action(returnText);
                                         callback?.Invoke(asyncResult);
@@ -614,15 +599,7 @@ namespace Renci.SshNet
                         var bytesProcessed = _encoding.GetByteCount(text + CrLf);
 
                         // remove processed bytes from the queue
-                        for (var i = 0; i < bytesProcessed; i++)
-                        {
-                            if (_expect.Count == _incoming.Count)
-                            {
-                                _ = _expect.Dequeue();
-                            }
-
-                            _ = _incoming.Dequeue();
-                        }
+                        SyncQueuesAndDequeue(bytesProcessed);
 
                         break;
                     }
@@ -687,7 +664,7 @@ namespace Renci.SshNet
             {
                 for (; i < count && _incoming.Count > 0; i++)
                 {
-                    if (_expect.Count == _incoming.Count)
+                    if (_incoming.Count == _expect.Count)
                     {
                         _ = _expect.Dequeue();
                     }
@@ -869,5 +846,37 @@ namespace Renci.SshNet
         {
             DataReceived?.Invoke(this, new ShellDataEventArgs(data));
         }
+
+        private string SyncQueuesAndReturn(int bytesToDequeue)
+        {
+            string incomingText;
+
+            lock (_incoming)
+            {
+                var incomingLength = _incoming.Count - _expect.Count + bytesToDequeue;
+                incomingText = _encoding.GetString(_incoming.ToArray(), 0, incomingLength);
+
+                SyncQueuesAndDequeue(bytesToDequeue);
+            }
+
+            return incomingText;
+        }
+
+        private void SyncQueuesAndDequeue(int bytesToDequeue)
+        {
+            lock (_incoming)
+            {
+                while (_incoming.Count > _expect.Count)
+                {
+                    _ = _incoming.Dequeue();
+                }
+
+                for (var count = 0; count < bytesToDequeue && _incoming.Count > 0; count++)
+                {
+                    _ = _incoming.Dequeue();
+                    _ = _expect.Dequeue();
+                }
+            }
+        }
     }
 }

+ 29 - 2
test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs

@@ -17,6 +17,8 @@ namespace Renci.SshNet.Tests.Classes
     [TestClass]
     public class ShellStreamTest_ReadExpect
     {
+        private const int BufferSize = 1024;
+        private const int ExpectSize = BufferSize * 2;
         private ShellStream _shellStream;
         private ChannelSessionStub _channelSessionStub;
 
@@ -42,8 +44,8 @@ namespace Renci.SshNet.Tests.Classes
                 width: 800,
                 height: 600,
                 terminalModeValues: null,
-                bufferSize: 1024,
-                expectSize: 2048);
+                bufferSize: BufferSize,
+                expectSize: ExpectSize);
         }
 
         [TestMethod]
@@ -244,6 +246,31 @@ namespace Renci.SshNet.Tests.Classes
             Assert.AreEqual($"{new string('c', 100)}", _shellStream.Read());
         }
 
+        [TestMethod]
+        public void Expect_String_DequeueChecks()
+        {
+            const string expected = "ccccc";
+
+            // Prime buffer
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', BufferSize)));
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', ExpectSize)));
+
+            // Test data
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('a', 100)));
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('b', 100)));
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(expected));
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('d', 100)));
+            _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('e', 100)));
+
+            // Expected result
+            var expectedResult = $"{new string(' ', BufferSize)}{new string(' ', ExpectSize)}{new string('a', 100)}{new string('b', 100)}{expected}";
+            var expectedRead = $"{new string('d', 100)}{new string('e', 100)}";
+
+            Assert.AreEqual(expectedResult, _shellStream.Expect(expected));
+
+            Assert.AreEqual(expectedRead, _shellStream.Read());
+        }
+
         [TestMethod]
         public void Expect_Timeout()
         {