2
0
Эх сурвалжийг харах

Added internal ctor to ScpClient that takes an IServiceFactory to allow for mocking of PipeStream.
Added missing read (and check) of the returncode after a file copy command. Fixes issue #1382.

Gert Driesen 11 жил өмнө
parent
commit
6e11950d60

+ 140 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ScpClientTest_Upload_FileInfoAndPath_SendExecRequestReturnsFalse.cs

@@ -0,0 +1,140 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Channels;
+using Renci.SshNet.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ScpClientTest_Upload_FileInfoAndPath_SendExecRequestReturnsFalse
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private Mock<IChannelSession> _channelSessionMock;
+        private Mock<PipeStream> _pipeStreamMock;
+        private ConnectionInfo _connectionInfo;
+        private ScpClient _scpClient;
+        private FileInfo _fileInfo;
+        private string _path;
+        private string _fileName;
+        private IList<ScpUploadEventArgs> _uploadingRegister;
+        private SshException _actualException;
+
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_fileName != null)
+            {
+                File.Delete(_fileName);
+                _fileName = null;
+            }
+        }
+
+        protected void Arrange()
+        {
+            var random = new Random();
+            _fileName = CreateTemporaryFile(new byte[] {1});
+            _connectionInfo = new ConnectionInfo("host", 22, "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _fileInfo = new FileInfo(_fileName);
+            _path = random.Next().ToString(CultureInfo.InvariantCulture);
+            _uploadingRegister = new List<ScpUploadEventArgs>();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _channelSessionMock = new Mock<IChannelSession>(MockBehavior.Strict);
+            _pipeStreamMock = new Mock<PipeStream>(MockBehavior.Strict);
+
+            var sequence = new MockSequence();
+            _serviceFactoryMock.InSequence(sequence)
+                .Setup(p => p.CreateSession(_connectionInfo))
+                .Returns(_sessionMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.Connect());
+            _serviceFactoryMock.InSequence(sequence).Setup(p => p.CreatePipeStream()).Returns(_pipeStreamMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.CreateChannelSession()).Returns(_channelSessionMock.Object);
+            _channelSessionMock.InSequence(sequence).Setup(p => p.Open());
+            _channelSessionMock.InSequence(sequence)
+                .Setup(
+                    p => p.SendExecRequest(string.Format("scp -t \"{0}\"", _path))).Returns(false);
+            _channelSessionMock.InSequence(sequence).Setup(p => p.Dispose());
+            _pipeStreamMock.As<IDisposable>().InSequence(sequence).Setup(p => p.Dispose());
+
+            _scpClient = new ScpClient(_connectionInfo, false, _serviceFactoryMock.Object);
+            _scpClient.Uploading += (sender, args) => _uploadingRegister.Add(args);
+            _scpClient.Connect();
+        }
+
+        protected virtual void Act()
+        {
+            try
+            {
+                _scpClient.Upload(_fileInfo, _path);
+                Assert.Fail();
+            }
+            catch (SshException ex)
+            {
+                _actualException = ex;
+            }
+        }
+
+        [TestMethod]
+        public void UploadShouldHaveThrownSshException()
+        {
+            Assert.IsNotNull(_actualException);
+            Assert.IsNull(_actualException.InnerException);
+            Assert.AreEqual("Secure copy execution request was rejected by the server. Please consult the server logs.", _actualException.Message);
+        }
+
+        [TestMethod]
+        public void SendExecREquestOnChannelSessionShouldBeInvokedOnce()
+        {
+            _channelSessionMock.Verify(p => p.SendExecRequest(string.Format("scp -t \"{0}\"", _path)), Times.Once);
+        }
+
+        [TestMethod]
+        public void CloseOnChannelShouldNeverBeInvoked()
+        {
+            _channelSessionMock.Verify(p => p.Close(), Times.Never);
+        }
+
+        [TestMethod]
+        public void DisposeOnChannelShouldBeInvokedOnce()
+        {
+            _channelSessionMock.Verify(p => p.Dispose(), Times.Once);
+        }
+
+        [TestMethod]
+        public void DisposeOnPipeStreamShouldBeInvokedOnce()
+        {
+            _pipeStreamMock.As<IDisposable>().Verify(p => p.Dispose(), Times.Once);
+        }
+
+        [TestMethod]
+        public void UploadingShouldNeverHaveFired()
+        {
+            Assert.AreEqual(0, _uploadingRegister.Count);
+        }
+
+        private string CreateTemporaryFile(byte[] content)
+        {
+            var tempFile = Path.GetTempFileName();
+            using (var fs = File.OpenWrite(tempFile))
+            {
+                fs.Write(content, 0, content.Length);
+            }
+            return tempFile;
+        }
+    }
+}

+ 188 - 0
Renci.SshClient/Renci.SshNet.Tests/Classes/ScpClientTest_Upload_FileInfoAndPath_Success.cs

@@ -0,0 +1,188 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Moq;
+using Renci.SshNet.Channels;
+using Renci.SshNet.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+    [TestClass]
+    public class ScpClientTest_Upload_FileInfoAndPath_Success
+    {
+        private Mock<IServiceFactory> _serviceFactoryMock;
+        private Mock<ISession> _sessionMock;
+        private Mock<IChannelSession> _channelSessionMock;
+        private Mock<PipeStream> _pipeStreamMock;
+        private ConnectionInfo _connectionInfo;
+        private ScpClient _scpClient;
+        private FileInfo _fileInfo;
+        private string _path;
+        private int _bufferSize;
+        private byte[] _fileContent;
+        private string _fileName;
+        private int _fileSize;
+        private IList<ScpUploadEventArgs> _uploadingRegister;
+        [TestInitialize]
+        public void Setup()
+        {
+            Arrange();
+            Act();
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            if (_fileName != null)
+            {
+                File.Delete(_fileName);
+                _fileName = null;
+            }
+        }
+
+        protected void Arrange()
+        {
+            var random = new Random();
+            _bufferSize = random.Next(5, 15);
+            _fileSize = _bufferSize + 2; //force uploading 2 chunks
+            _fileContent = CreateContent(_fileSize);
+            _fileName = CreateTemporaryFile(_fileContent);
+            _connectionInfo = new ConnectionInfo("host", 22, "user", new PasswordAuthenticationMethod("user", "pwd"));
+            _fileInfo = new FileInfo(_fileName);
+            _path = random.Next().ToString(CultureInfo.InvariantCulture);
+            _uploadingRegister = new List<ScpUploadEventArgs>();
+
+            _serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
+            _sessionMock = new Mock<ISession>(MockBehavior.Strict);
+            _channelSessionMock = new Mock<IChannelSession>(MockBehavior.Strict);
+            _pipeStreamMock = new Mock<PipeStream>(MockBehavior.Strict);
+
+            var sequence = new MockSequence();
+            _serviceFactoryMock.InSequence(sequence)
+                .Setup(p => p.CreateSession(_connectionInfo))
+                .Returns(_sessionMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.Connect());
+            _serviceFactoryMock.InSequence(sequence).Setup(p => p.CreatePipeStream()).Returns(_pipeStreamMock.Object);
+            _sessionMock.InSequence(sequence).Setup(p => p.CreateChannelSession()).Returns(_channelSessionMock.Object);
+            _channelSessionMock.InSequence(sequence).Setup(p => p.Open());
+            _channelSessionMock.InSequence(sequence)
+                .Setup(
+                    p => p.SendExecRequest(string.Format("scp -t \"{0}\"", _path))).Returns(true);
+            for (var i = 0; i < random.Next(1, 3); i++)
+                _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(-1);
+            _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(0);
+            _channelSessionMock.InSequence(sequence).Setup(p => p.SendData(It.IsAny<byte[]>()));
+            for (var i = 0; i < random.Next(1, 3); i++)
+                _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(-1);
+            _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(0);
+            _channelSessionMock.InSequence(sequence)
+                .Setup(p => p.SendData(It.Is<byte[]>(b => b.SequenceEqual(CreateData(
+                                                                            string.Format("C0644 {0} {1}\n",
+                                                                                          _fileInfo.Length,
+                                                                                           Path.GetFileName(_fileName)
+                                                                                         )
+                                                                                     )))));
+            for (var i = 0; i < random.Next(1, 3); i++)
+                _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(-1);
+            _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(0);
+            _channelSessionMock.InSequence(sequence)
+                .Setup(
+                    p => p.SendData(It.Is<byte[]>(b => b.SequenceEqual(_fileContent.Take(_bufferSize)))));
+            _channelSessionMock.InSequence(sequence)
+                .Setup(
+                    p => p.SendData(It.Is<byte[]>(b => b.SequenceEqual(_fileContent.Skip(_bufferSize)))));
+            _channelSessionMock.InSequence(sequence)
+                .Setup(
+                    p => p.SendData(It.Is<byte[]>(b => b.SequenceEqual(new byte[] {0}))));
+            for (var i = 0; i < random.Next(1, 3); i++)
+                _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(-1);
+            _pipeStreamMock.InSequence(sequence).Setup(p => p.ReadByte()).Returns(0);
+            _channelSessionMock.InSequence(sequence).Setup(p => p.Close());
+            _channelSessionMock.InSequence(sequence).Setup(p => p.Dispose());
+            _pipeStreamMock.As<IDisposable>().InSequence(sequence).Setup(p => p.Dispose());
+
+            _scpClient = new ScpClient(_connectionInfo, false, _serviceFactoryMock.Object)
+                {
+                    BufferSize = (uint) _bufferSize
+                };
+            _scpClient.Uploading += (sender, args) => _uploadingRegister.Add(args);
+            _scpClient.Connect();
+        }
+
+        protected virtual void Act()
+        {
+            _scpClient.Upload(_fileInfo, _path);
+        }
+
+        [TestMethod]
+        public void SendExecREquestOnChannelSessionShouldBeInvokedOnce()
+        {
+            _channelSessionMock.Verify(p => p.SendExecRequest(string.Format("scp -t \"{0}\"", _path)), Times.Once);
+        }
+
+        [TestMethod]
+        public void CloseOnChannelShouldBeInvokedOnce()
+        {
+            _channelSessionMock.Verify(p => p.Close(), Times.Once);
+        }
+
+        [TestMethod]
+        public void DisposeOnChannelShouldBeInvokedOnce()
+        {
+            _channelSessionMock.Verify(p => p.Dispose(), Times.Once);
+        }
+
+        [TestMethod]
+        public void DisposeOnPipeStreamShouldBeInvokedOnce()
+        {
+            _pipeStreamMock.As<IDisposable>().Verify(p => p.Dispose(), Times.Once);
+        }
+
+        [TestMethod]
+        public void UploadingShouldHaveFiredTwice()
+        {
+            Assert.AreEqual(2, _uploadingRegister.Count);
+
+            var uploading = _uploadingRegister[0];
+            Assert.IsNotNull(uploading);
+            Assert.AreSame(_fileInfo.Name, uploading.Filename);
+            Assert.AreEqual(_fileSize, uploading.Size);
+            Assert.AreEqual(_bufferSize, uploading.Uploaded);
+
+            uploading = _uploadingRegister[1];
+            Assert.IsNotNull(uploading);
+            Assert.AreSame(_fileInfo.Name, uploading.Filename);
+            Assert.AreEqual(_fileSize, uploading.Size);
+            Assert.AreEqual(_fileSize, uploading.Uploaded);
+        }
+
+        private IEnumerable<byte> CreateData(string command)
+        {
+            return Encoding.Default.GetBytes(command);
+        }
+
+        private byte[] CreateContent(int length)
+        {
+            var random = new Random();
+            var content = new byte[length];
+
+            for (var i = 0; i < length; i++)
+                content[i] = (byte) random.Next(byte.MinValue, byte.MaxValue);
+            return content;
+        }
+
+        private string CreateTemporaryFile(byte[] content)
+        {
+            var tempFile = Path.GetTempFileName();
+            using (var fs = File.OpenWrite(tempFile))
+            {
+                fs.Write(content, 0, content.Length);
+            }
+            return tempFile;
+        }
+    }
+}

+ 2 - 0
Renci.SshClient/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

@@ -155,6 +155,8 @@
     <Compile Include="Classes\ForwardedPortRemoteTest_Stop_PortNeverStarted.cs" />
     <Compile Include="Classes\ForwardedPortRemoteTest_Stop_PortStarted_ChannelBound.cs" />
     <Compile Include="Classes\ForwardedPortRemoteTest_Stop_PortStopped.cs" />
+    <Compile Include="Classes\ScpClientTest_Upload_FileInfoAndPath_SendExecRequestReturnsFalse.cs" />
+    <Compile Include="Classes\ScpClientTest_Upload_FileInfoAndPath_Success.cs" />
     <Compile Include="Classes\Security\AlgorithmTest.cs" />
     <Compile Include="Classes\Security\CertificateHostAlgorithmTest.cs" />
     <Compile Include="Classes\Security\Cryptography\BlockCipherTest.cs" />

+ 15 - 1
Renci.SshClient/Renci.SshNet/BaseClient.cs

@@ -21,10 +21,24 @@ namespace Renci.SshNet
         private ConnectionInfo _connectionInfo;
 
         /// <summary>
-        /// Gets current session.
+        /// Gets the current session.
         /// </summary>
+        /// <value>
+        /// The current session.
+        /// </value>
         internal ISession Session { get; private set; }
 
+        /// <summary>
+        /// Gets the factory for creating new services.
+        /// </summary>
+        /// <value>
+        /// The factory for creating new services.
+        /// </value>
+        internal IServiceFactory ServiceFactory
+        {
+            get { return _serviceFactory; }
+        }
+
         /// <summary>
         /// Gets the connection info.
         /// </summary>

+ 23 - 1
Renci.SshClient/Renci.SshNet/IServiceFactory.cs

@@ -1,7 +1,29 @@
-namespace Renci.SshNet
+using System;
+using Renci.SshNet.Common;
+
+namespace Renci.SshNet
 {
+    /// <summary>
+    /// Factory for creating new services.
+    /// </summary>
     internal interface IServiceFactory
     {
+        /// <summary>
+        /// Creates a new <see cref="ISession"/> with the specified <see cref="ConnectionInfo"/>.
+        /// </summary>
+        /// <param name="connectionInfo">The <see cref="ConnectionInfo"/> to use for creating a new session.</param>
+        /// <returns>
+        /// An <see cref="ISession"/> for the specified <see cref="ConnectionInfo"/>.
+        /// </returns>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
         ISession CreateSession(ConnectionInfo connectionInfo);
+
+        /// <summary>
+        /// Create a new <see cref="PipeStream"/>.
+        /// </summary>
+        /// <returns>
+        /// A <see cref="PipeStream"/>.
+        /// </returns>
+        PipeStream CreatePipeStream();
     }
 }

+ 4 - 3
Renci.SshClient/Renci.SshNet/ScpClient.NET.cs

@@ -28,7 +28,7 @@ namespace Renci.SshNet
             if (string.IsNullOrEmpty(path))
                 throw new ArgumentException("path");
 
-            using (var input = new PipeStream())
+            using (var input = ServiceFactory.CreatePipeStream())
             using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
@@ -39,8 +39,9 @@ namespace Renci.SshNet
 
                 channel.Open();
 
-                //  Send channel command request
-                channel.SendExecRequest(string.Format("scp -t \"{0}\"", path));
+                if (!channel.SendExecRequest(string.Format("scp -t \"{0}\"", path)))
+                    throw new SshException("Secure copy execution request was rejected by the server. Please consult the server logs.");
+
                 this.CheckReturnCode(input);
 
                 this.InternalUpload(channel, input, fileInfo, fileInfo.Name);

+ 20 - 2
Renci.SshClient/Renci.SshNet/ScpClient.cs

@@ -126,7 +126,24 @@ namespace Renci.SshNet
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
         private ScpClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
-            : base(connectionInfo, ownsConnectionInfo)
+            : this(connectionInfo, ownsConnectionInfo, new ServiceFactory())
+        {
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="ScpClient"/> class.
+        /// </summary>
+        /// <param name="connectionInfo">The connection info.</param>
+        /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
+        /// <param name="serviceFactory">The factory to use for creating new services.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is null.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is null.</exception>
+        /// <remarks>
+        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// connection info will be disposed when this instance is disposed.
+        /// </remarks>
+        internal ScpClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory)
+            : base(connectionInfo, ownsConnectionInfo, serviceFactory)
         {
             this.OperationTimeout = new TimeSpan(0, 0, 0, 0, -1);
             this.BufferSize = 1024 * 16;
@@ -151,7 +168,7 @@ namespace Renci.SshNet
         /// <param name="path">Remote host file name.</param>
         public void Upload(Stream source, string path)
         {
-            using (var input = new PipeStream())
+            using (var input = ServiceFactory.CreatePipeStream())
             using (var channel = this.Session.CreateChannelSession())
             {
                 channel.DataReceived += delegate(object sender, ChannelDataEventArgs e)
@@ -249,6 +266,7 @@ namespace Renci.SshNet
             var length = source.Length;
 
             this.SendData(channel, string.Format("C0644 {0} {1}\n", length, Path.GetFileName(filename)));
+            this.CheckReturnCode(input);
 
             var buffer = new byte[this.BufferSize];
 

+ 26 - 1
Renci.SshClient/Renci.SshNet/ServiceFactory.cs

@@ -1,10 +1,35 @@
-namespace Renci.SshNet
+using System;
+using Renci.SshNet.Common;
+
+namespace Renci.SshNet
 {
+    /// <summary>
+    /// Basic factory for creating new services.
+    /// </summary>
     internal class ServiceFactory : IServiceFactory
     {
+        /// <summary>
+        /// Creates a new <see cref="ISession"/> with the specified <see cref="ConnectionInfo"/>.
+        /// </summary>
+        /// <param name="connectionInfo">The <see cref="ConnectionInfo"/> to use for creating a new session.</param>
+        /// <returns>
+        /// An <see cref="ISession"/> for the specified <see cref="ConnectionInfo"/>.
+        /// </returns>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
         public ISession CreateSession(ConnectionInfo connectionInfo)
         {
             return new Session(connectionInfo);
         }
+
+        /// <summary>
+        /// Create a new <see cref="PipeStream"/>.
+        /// </summary>
+        /// <returns>
+        /// A <see cref="PipeStream"/>.
+        /// </returns>
+        public PipeStream CreatePipeStream()
+        {
+            return new PipeStream();
+        }
     }
 }

+ 2 - 3
Renci.SshClient/Renci.SshNet/SshClient.cs

@@ -135,10 +135,9 @@ namespace Renci.SshNet
         /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
-        internal SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
-            : base(connectionInfo, ownsConnectionInfo, new ServiceFactory())
+        private SshClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
+            : this(connectionInfo, ownsConnectionInfo, new ServiceFactory())
         {
-            _forwardedPorts = new List<ForwardedPort>();
         }
 
         /// <summary>