Prechádzať zdrojové kódy

Update Read(byte[] buffer, int offset, int count) to not write bytes to read buffer when the read buffer is empty and count is greater than the number of bytes read from the server.
Lazily initialize read and write buffer.

Gert Driesen 8 rokov pred
rodič
commit
6fd1cc78c3

+ 14 - 16
src/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_OutsideOfReadBuffer.cs → src/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_NoBuffering.cs

@@ -8,7 +8,7 @@ namespace Renci.SshNet.Tests.Classes.Sftp
 {
     [TestClass]
     //[Ignore]
-    public class SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_OutsideOfReadBuffer : SftpFileStreamTestBase
+    public class SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_NoBuffering : SftpFileStreamTestBase
     {
         private Random _random;
         private string _path;
@@ -21,8 +21,7 @@ namespace Renci.SshNet.Tests.Classes.Sftp
         private SftpFileStream _target;
         private long _actual;
         private byte[] _buffer;
-        private byte[] _serverData1;
-        private byte[] _serverData2;
+        private byte[] _serverData;
 
         protected override void SetupData()
         {
@@ -36,9 +35,8 @@ namespace Renci.SshNet.Tests.Classes.Sftp
             _readBufferSize = 20;
             _writeBufferSize = (uint) _random.Next(5, 1000);
             _handle = GenerateRandom(_random.Next(1, 10), _random);
-            _buffer = new byte[_readBufferSize + 1];
-            _serverData1 = GenerateRandom((int)  _readBufferSize, _random);
-            _serverData2 = GenerateRandom(10, _random);
+            _buffer = new byte[_readBufferSize];
+            _serverData = GenerateRandom(_buffer.Length, _random);
         }
 
         protected override void SetupMocks()
@@ -47,20 +45,17 @@ namespace Renci.SshNet.Tests.Classes.Sftp
                            .Setup(p => p.RequestOpen(_path, Flags.Read | Flags.CreateNewOrOpen, false))
                            .Returns(_handle);
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.CalculateOptimalReadLength((uint) _bufferSize))
+                           .Setup(p => p.CalculateOptimalReadLength((uint)_bufferSize))
                            .Returns(_readBufferSize);
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.CalculateOptimalWriteLength((uint) _bufferSize, _handle))
+                           .Setup(p => p.CalculateOptimalWriteLength((uint)_bufferSize, _handle))
                            .Returns(_writeBufferSize);
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.RequestRead(_handle, 0UL, _readBufferSize))
-                           .Returns(_serverData1);
-            SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.RequestRead(_handle, _readBufferSize, _readBufferSize))
-                           .Returns(_serverData2);
+                           .Returns(_serverData);
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
@@ -104,21 +99,24 @@ namespace Renci.SshNet.Tests.Classes.Sftp
         }
 
         [TestMethod]
-        public void ReadShouldReadFromServer()
+        public void ReadShouldReturnReadBytesFromServer()
         {
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.RequestRead(_handle, 0, _readBufferSize))
-                           .Returns(new byte[] {0x04});
+                           .Setup(p => p.RequestRead(_handle, 0UL, _readBufferSize))
+                           .Returns(new byte[] { 0x05, 0x04 });
 
             var buffer = new byte[1];
 
             var bytesRead = _target.Read(buffer, 0, buffer.Length);
 
             Assert.AreEqual(buffer.Length, bytesRead);
-            Assert.AreEqual(0x04, buffer[0]);
+            Assert.AreEqual(0x05, buffer[0]);
+
+            SftpSessionMock.Verify(p => p.IsOpen, Times.Exactly(3));
+            SftpSessionMock.Verify(p => p.RequestRead(_handle, 0UL, _readBufferSize), Times.Exactly(2));
         }
     }
 }

+ 22 - 35
src/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_WithinReadBuffer.cs → src/Renci.SshNet.Tests/Classes/Sftp/SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_ReadBuffer.cs

@@ -8,7 +8,7 @@ namespace Renci.SshNet.Tests.Classes.Sftp
 {
     [TestClass]
     //[Ignore]
-    public class SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_WithinReadBuffer : SftpFileStreamTestBase
+    public class SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_ReadBuffer : SftpFileStreamTestBase
     {
         private Random _random;
         private string _path;
@@ -21,7 +21,8 @@ namespace Renci.SshNet.Tests.Classes.Sftp
         private SftpFileStream _target;
         private long _actual;
         private byte[] _buffer;
-        private byte[] _serverData;
+        private byte[] _serverData1;
+        private byte[] _serverData2;
 
         protected override void SetupData()
         {
@@ -35,8 +36,9 @@ namespace Renci.SshNet.Tests.Classes.Sftp
             _readBufferSize = 20;
             _writeBufferSize = (uint) _random.Next(5, 1000);
             _handle = GenerateRandom(_random.Next(1, 10), _random);
-            _buffer = new byte[_readBufferSize - 5];
-            _serverData = GenerateRandom((int) _readBufferSize, _random);
+            _buffer = new byte[2]; // should be less than size of read buffer
+            _serverData1 = GenerateRandom((int) _readBufferSize, _random);
+            _serverData2 = GenerateRandom((int) _readBufferSize, _random);
         }
 
         protected override void SetupMocks()
@@ -45,17 +47,17 @@ namespace Renci.SshNet.Tests.Classes.Sftp
                            .Setup(p => p.RequestOpen(_path, Flags.Read | Flags.CreateNewOrOpen, false))
                            .Returns(_handle);
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.CalculateOptimalReadLength((uint)_bufferSize))
+                           .Setup(p => p.CalculateOptimalReadLength((uint) _bufferSize))
                            .Returns(_readBufferSize);
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.CalculateOptimalWriteLength((uint)_bufferSize, _handle))
+                           .Setup(p => p.CalculateOptimalWriteLength((uint) _bufferSize, _handle))
                            .Returns(_writeBufferSize);
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.RequestRead(_handle, 0UL, _readBufferSize))
-                           .Returns(_serverData);
+                           .Returns(_serverData1);
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
@@ -99,54 +101,39 @@ namespace Renci.SshNet.Tests.Classes.Sftp
         }
 
         [TestMethod]
-        public void ReadLessThanReadBufferSizeShouldReturnBytesFromReadBuffer()
+        public void ReadBytesThatWereNotBufferedBeforeSeekShouldReadBytesFromServer()
         {
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
-
-            var buffer = new byte[_readBufferSize - 3];
-
-            var bytesRead = _target.Read(buffer, 0, buffer.Length);
-
-            Assert.AreEqual(buffer.Length, bytesRead);
-            Assert.IsTrue(_serverData.Take(buffer.Length).IsEqualTo(buffer));
-        }
-
-        [TestMethod]
-        public void ReadReadBufferSizeShouldReturnBytesFromReadBuffer()
-        {
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.IsOpen)
-                           .Returns(true);
+                           .Setup(p => p.RequestRead(_handle, 0UL, _readBufferSize))
+                           .Returns(_serverData2);
 
-            var buffer = new byte[_readBufferSize];
+            var bytesRead = _target.Read(_buffer, 0, _buffer.Length);
 
-            var bytesRead = _target.Read(buffer, 0, buffer.Length);
+            Assert.AreEqual(_buffer.Length, bytesRead);
+            Assert.IsTrue(_serverData2.Take(_buffer.Length).IsEqualTo(_buffer));
 
-            Assert.AreEqual(buffer.Length, bytesRead);
-            Assert.IsTrue(_serverData.IsEqualTo(buffer));
+            SftpSessionMock.Verify(p => p.IsOpen, Times.Exactly(3));
+            SftpSessionMock.Verify(p => p.RequestRead(_handle, 0UL, _readBufferSize), Times.Exactly(2));
         }
 
         [TestMethod]
-        public void ReadMoreThanReadBufferSizeShouldReturnBytesFromReadBufferAndReadRemaningBytesFromServer()
+        public void ReadBytesThatWereBufferedBeforeSeekShouldReadBytesFromServer()
         {
-            var serverData2 = GenerateRandom(6, _random);
-
             SftpSessionMock.InSequence(MockSequence)
                            .Setup(p => p.IsOpen)
                            .Returns(true);
             SftpSessionMock.InSequence(MockSequence)
-                           .Setup(p => p.RequestRead(_handle, _readBufferSize, _readBufferSize))
-                           .Returns(serverData2);
-
-            var buffer = new byte[_readBufferSize + 4];
+                           .Setup(p => p.RequestRead(_handle, 0UL, _readBufferSize))
+                           .Returns(_serverData2);
 
+            var buffer = new byte[_buffer.Length + 1]; // we read one byte that was previously buffered
             var bytesRead = _target.Read(buffer, 0, buffer.Length);
 
             Assert.AreEqual(buffer.Length, bytesRead);
-            Assert.IsTrue(_serverData.Take((int) _readBufferSize).IsEqualTo(buffer.Take((int) _readBufferSize)));
-            Assert.IsTrue(serverData2.Take(4).IsEqualTo(buffer.Take((int) _readBufferSize, 4)));
+            Assert.IsTrue(_serverData2.Take(buffer.Length).IsEqualTo(buffer));
         }
     }
 }

+ 2 - 2
src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -439,8 +439,8 @@
     <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtBeginningOfStream_OriginBeginAndOffsetNegative.cs" />
     <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtBeginningOfStream_OriginBeginAndOffsetPositive.cs" />
     <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtBeginningOfStream_OriginBeginAndOffsetZero.cs" />
-    <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_OutsideOfReadBuffer.cs" />
-    <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_WithinReadBuffer.cs" />
+    <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_NoBuffering.cs" />
+    <Compile Include="Classes\Sftp\SftpFileStreamTest_Seek_PositionedAtMiddleOfStream_OriginBeginAndOffsetZero_ReadBuffer.cs" />
     <Compile Include="Classes\Sftp\SftpFileStreamTest_SetLength_Closed.cs" />
     <Compile Include="Classes\Sftp\SftpFileStreamTest_SetLength_Disposed.cs" />
     <Compile Include="Classes\Sftp\SftpFileStreamTest_SetLength_SessionNotOpen.cs" />

+ 71 - 40
src/Renci.SshNet/Sftp/SftpFileStream.cs

@@ -18,9 +18,9 @@ namespace Renci.SshNet.Sftp
 
         // Buffer information.
         private readonly int _readBufferSize;
-        private readonly byte[] _readBuffer;
+        private byte[] _readBuffer;
         private readonly int _writeBufferSize;
-        private readonly byte[] _writeBuffer;
+        private byte[] _writeBuffer;
         private int _bufferPosition;
         private int _bufferLen;
         private long _position;
@@ -253,9 +253,7 @@ namespace Renci.SshNet.Sftp
             // or SSH_FXP_WRITE message
 
             _readBufferSize = (int) session.CalculateOptimalReadLength((uint) bufferSize);
-            _readBuffer = new byte[_readBufferSize];
             _writeBufferSize = (int) session.CalculateOptimalWriteLength((uint) bufferSize, _handle);
-            _writeBuffer = new byte[_writeBufferSize];
 
             if (mode == FileMode.Append)
             {
@@ -342,52 +340,69 @@ namespace Renci.SshNet.Sftp
                     {
                         var data = _session.RequestRead(_handle, (ulong) _position, (uint) _readBufferSize);
 
-                        _bufferPosition = 0;
-                        _bufferLen = data.Length;
-
-                        if (_bufferLen == 0)
+                        if (data.Length == 0)
                         {
+                            _bufferPosition = 0;
+                            _bufferLen = 0;
+
                             break;
                         }
 
-                        // determine number of bytes that we can write into caller-provided buffer
-                        var bytesToWriteToCallerBuffer = Math.Min(_bufferLen, count);
+                        var bytesToWriteToCallerBuffer = count;
+                        if (bytesToWriteToCallerBuffer >= data.Length)
+                        {
+                            // write all data read to caller-provided buffer
+                            bytesToWriteToCallerBuffer = data.Length;
+                            // reset buffer since we will skip buffering
+                            _bufferPosition = 0;
+                            _bufferLen = 0;
+                        }
+                        else
+                        {
+                            // determine number of bytes that we should write into read buffer
+                            var bytesToWriteToReadBuffer = data.Length - bytesToWriteToCallerBuffer;
+                            // write remaining bytes to read buffer
+                            Buffer.BlockCopy(data, count, GetOrCreateReadBuffer(), 0, bytesToWriteToReadBuffer);
+                            // update position in read buffer
+                            _bufferPosition = 0;
+                            // update number of bytes in read buffer
+                            _bufferLen = bytesToWriteToReadBuffer;
+                        }
+
                         // write bytes to caller-provided buffer
                         Buffer.BlockCopy(data, 0, buffer, offset, bytesToWriteToCallerBuffer);
                         // advance offset to start writing bytes into caller-provided buffer
                         offset += bytesToWriteToCallerBuffer;
-                        // update number of bytes left to read
-                        count -= bytesToWriteToCallerBuffer;
                         // record total number of bytes read into caller-provided buffer
                         readLen += bytesToWriteToCallerBuffer;
+                        // signal that all caller-requested bytes are read
+                        count -= bytesToWriteToCallerBuffer;
                         // update stream position
                         _position += bytesToWriteToCallerBuffer;
-                        // update position in read buffer
-                        _bufferPosition = bytesToWriteToCallerBuffer;
-                        // write read bytes to read buffer
-                        Buffer.BlockCopy(data, 0, _readBuffer, 0, _bufferLen);
                     }
                     else
                     {
-                        // determine number of bytes that we can write from read buffer to caller-provided buffer
-                        var bytesToWriteToCallerBuffer = Math.Min(bytesAvailableInBuffer, count);
+                        // limit the number of bytes to use from read buffer to the caller-request number of bytes
+                        if (bytesAvailableInBuffer > count)
+                            bytesAvailableInBuffer = count;
+
                         // copy data from read buffer to the caller-provided buffer
-                        Buffer.BlockCopy(_readBuffer, _bufferPosition, buffer, offset, bytesToWriteToCallerBuffer);
+                        Buffer.BlockCopy(GetOrCreateReadBuffer(), _bufferPosition, buffer, offset, bytesAvailableInBuffer);
                         // update position in read buffer
-                        _bufferPosition += bytesToWriteToCallerBuffer;
+                        _bufferPosition += bytesAvailableInBuffer;
                         // advance offset to start writing bytes into caller-provided buffer
                         offset += bytesAvailableInBuffer;
-                        // update number of bytes left to read
-                        count -= bytesToWriteToCallerBuffer;
                         // record total number of bytes read into caller-provided buffer
-                        readLen += bytesToWriteToCallerBuffer;
+                        readLen += bytesAvailableInBuffer;
+                        // update number of bytes left to read
+                        count -= bytesAvailableInBuffer;
                         // update stream position
-                        _position += bytesToWriteToCallerBuffer;
+                        _position += bytesAvailableInBuffer;
                     }
                 }
             }
 
-            // Return the number of bytes that were read to the caller.
+            // return the number of bytes that were read to the caller.
             return readLen;
         }
 
@@ -410,28 +425,32 @@ namespace Renci.SshNet.Sftp
                 // Setup the object for reading.
                 SetupRead();
 
+                byte[] readBuffer;
+
                 // Read more data into the internal buffer if necessary.
                 if (_bufferPosition >= _bufferLen)
                 {
-                    _bufferPosition = 0;
-                    _bufferLen = 0;
-
                     var data = _session.RequestRead(_handle, (ulong) _position, (uint) _readBufferSize);
-
-                    _bufferLen = data.Length;
-
-                    if (_bufferLen == 0)
+                    if (data.Length == 0)
                     {
                         // We've reached EOF.
                         return -1;
                     }
 
-                    Buffer.BlockCopy(data, 0, _readBuffer, 0, _bufferLen);
+                    readBuffer = GetOrCreateReadBuffer();
+                    Buffer.BlockCopy(data, 0, readBuffer, 0, data.Length);
+
+                    _bufferPosition = 0;
+                    _bufferLen = data.Length;
+                }
+                else
+                {
+                    readBuffer = GetOrCreateReadBuffer();
                 }
 
                 // Extract the next byte from the buffer.
                 ++_position;
-                return _readBuffer[_bufferPosition++];
+                return readBuffer[_bufferPosition++];
             }
         }
 
@@ -501,8 +520,7 @@ namespace Renci.SshNet.Sftp
                     if (origin == SeekOrigin.Begin)
                     {
                         newPosn = _position - _bufferPosition;
-                        if (offset >= newPosn && offset <
-                                (newPosn + _bufferLen))
+                        if (offset >= newPosn && offset < (newPosn + _bufferLen))
                         {
                             _bufferPosition = (int)(offset - newPosn);
                             _position = offset;
@@ -515,8 +533,7 @@ namespace Renci.SshNet.Sftp
                         if (newPosn >= (_position - _bufferPosition) &&
                            newPosn < (_position - _bufferPosition + _bufferLen))
                         {
-                            _bufferPosition =
-                                (int)(newPosn - (_position - _bufferPosition));
+                            _bufferPosition = (int) (newPosn - (_position - _bufferPosition));
                             _position = newPosn;
                             return _position;
                         }
@@ -642,7 +659,7 @@ namespace Renci.SshNet.Sftp
                     else
                     {
                         // No: copy the data to the write buffer first.
-                        Buffer.BlockCopy(buffer, offset, _writeBuffer, _bufferPosition, tempLen);
+                        Buffer.BlockCopy(buffer, offset, GetOrCreateWriteBuffer(), _bufferPosition, tempLen);
                         _bufferPosition += tempLen;
                     }
 
@@ -658,7 +675,7 @@ namespace Renci.SshNet.Sftp
                 {
                     using (var wait = new AutoResetEvent(false))
                     {
-                        _session.RequestWrite(_handle, (ulong) (_position - _bufferPosition), _writeBuffer, 0, _bufferPosition, wait);
+                        _session.RequestWrite(_handle, (ulong) (_position - _bufferPosition), GetOrCreateWriteBuffer(), 0, _bufferPosition, wait);
                     }
 
                     _bufferPosition = 0;
@@ -742,6 +759,20 @@ namespace Renci.SshNet.Sftp
             }
         }
 
+        private byte[] GetOrCreateReadBuffer()
+        {
+            if (_readBuffer == null)
+                _readBuffer = new byte[_readBufferSize];
+            return _readBuffer;
+        }
+
+        private byte[] GetOrCreateWriteBuffer()
+        {
+            if (_writeBuffer == null)
+                _writeBuffer = new byte[_writeBufferSize];
+            return _writeBuffer;
+        }
+
         /// <summary>
         /// Flushes the read data from the buffer.
         /// </summary>