Browse Source

Limit TimeSpan timeouts to Int32 MaxValue (#1321)

* Added guard clauses to various timeouts to ensure they don't exceed an Int32 in milliseconds.

* Fixed guard clauses.

* Updated build tags.

* Added guard clauses to various timeouts to ensure they don't exceed an Int32 in milliseconds.

* Fixed tests.

* Added additional tests.

* Replaced NoWarn with .editorconfig setting

* Fixed references to parameter names.
Jean-Sebastien Carle 1 year ago
parent
commit
4b9c3bf144

+ 4 - 0
src/Renci.SshNet/.editorconfig

@@ -159,3 +159,7 @@ dotnet_diagnostic.IDE0048.severity = none
 # IDE0305: Collection initialization can be simplified
 # https://learn.microsoft.com/en-us/dotnet/fundamentals/code-analysis/style-rules/ide0305
 dotnet_diagnostic.IDE0305.severity = none
+
+# IDE0005: Remove unnecessary using directives
+# https://learn.microsoft.com/en-us/dotnet/fundamentals/code-analysis/style-rules/ide0005
+dotnet_diagnostic.IDE0005.severity = suggestion

+ 2 - 2
src/Renci.SshNet/Abstractions/SocketAbstraction.cs

@@ -129,7 +129,7 @@ namespace Renci.SshNet.Abstractions
 
         public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
         {
-            socket.ReceiveTimeout = (int) timeout.TotalMilliseconds;
+            socket.ReceiveTimeout = timeout.AsTimeout(nameof(timeout));
 
             try
             {
@@ -274,7 +274,7 @@ namespace Renci.SshNet.Abstractions
             var totalBytesRead = 0;
             var totalBytesToRead = size;
 
-            socket.ReceiveTimeout = (int) readTimeout.TotalMilliseconds;
+            socket.ReceiveTimeout = readTimeout.AsTimeout(nameof(readTimeout));
 
             do
             {

+ 2 - 0
src/Renci.SshNet/BaseClient.cs

@@ -101,6 +101,8 @@ namespace Renci.SshNet
             {
                 CheckDisposed();
 
+                value.EnsureValidTimeout(nameof(KeepAliveInterval));
+
                 if (value == _keepAliveInterval)
                 {
                     return;

+ 46 - 0
src/Renci.SshNet/Common/TimeSpanExtensions.cs

@@ -0,0 +1,46 @@
+using System;
+
+namespace Renci.SshNet.Common
+{
+    /// <summary>
+    /// Provides extension methods for <see cref="TimeSpan"/>.
+    /// </summary>
+    internal static class TimeSpanExtensions
+    {
+        private const string OutOfRangeTimeoutMessage =
+            $"The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.";
+
+        /// <summary>
+        /// Returns the specified <paramref name="timeSpan"/> as a valid timeout in milliseconds.
+        /// </summary>
+        /// <param name="timeSpan">The <see cref="TimeSpan"/> to ensure validity.</param>
+        /// <param name="callerMemberName">The name of the calling member.</param>
+        /// <exception cref="ArgumentOutOfRangeException">
+        /// Thrown when <paramref name="timeSpan"/> does not represent a value between -1 and <see cref="int.MaxValue"/>, inclusive.
+        /// </exception>
+        public static int AsTimeout(this TimeSpan timeSpan, string callerMemberName)
+        {
+            var timeoutInMilliseconds = timeSpan.TotalMilliseconds;
+            return timeoutInMilliseconds is < -1d or > int.MaxValue
+                       ? throw new ArgumentOutOfRangeException(callerMemberName, OutOfRangeTimeoutMessage)
+                       : (int) timeoutInMilliseconds;
+        }
+
+        /// <summary>
+        /// Ensures that the specified <paramref name="timeSpan"/> represents a valid timeout in milliseconds.
+        /// </summary>
+        /// <param name="timeSpan">The <see cref="TimeSpan"/> to ensure validity.</param>
+        /// <param name="callerMemberName">The name of the calling member.</param>
+        /// <exception cref="ArgumentOutOfRangeException">
+        /// Thrown when <paramref name="timeSpan"/> does not represent a value between -1 and <see cref="int.MaxValue"/>, inclusive.
+        /// </exception>
+        public static void EnsureValidTimeout(this TimeSpan timeSpan, string callerMemberName)
+        {
+            var timeoutInMilliseconds = timeSpan.TotalMilliseconds;
+            if (timeoutInMilliseconds is < -1d or > int.MaxValue)
+            {
+                throw new ArgumentOutOfRangeException(callerMemberName, OutOfRangeTimeoutMessage);
+            }
+        }
+    }
+}

+ 29 - 2
src/Renci.SshNet/ConnectionInfo.cs

@@ -44,6 +44,9 @@ namespace Renci.SshNet
         /// </value>
         private static readonly TimeSpan DefaultChannelCloseTimeout = TimeSpan.FromSeconds(1);
 
+        private TimeSpan _timeout;
+        private TimeSpan _channelCloseTimeout;
+
         /// <summary>
         /// Gets supported key exchange algorithms for this connection.
         /// </summary>
@@ -145,7 +148,19 @@ namespace Renci.SshNet
         /// <value>
         /// The connection timeout. The default value is 30 seconds.
         /// </value>
-        public TimeSpan Timeout { get; set; }
+        public TimeSpan Timeout
+        {
+            get
+            {
+                return _timeout;
+            }
+            set
+            {
+                value.EnsureValidTimeout(nameof(Timeout));
+
+                _timeout = value;
+            }
+        }
 
         /// <summary>
         /// Gets or sets the timeout to use when waiting for a server to acknowledge closing a channel.
@@ -157,7 +172,19 @@ namespace Renci.SshNet
         /// If a server does not send a <c>SSH_MSG_CHANNEL_CLOSE</c> message before the specified timeout
         /// elapses, the channel will be closed immediately.
         /// </remarks>
-        public TimeSpan ChannelCloseTimeout { get; set; }
+        public TimeSpan ChannelCloseTimeout
+        {
+            get
+            {
+                return _channelCloseTimeout;
+            }
+            set
+            {
+                value.EnsureValidTimeout(nameof(ChannelCloseTimeout));
+
+                _channelCloseTimeout = value;
+            }
+        }
 
         /// <summary>
         /// Gets or sets the character encoding.

+ 2 - 0
src/Renci.SshNet/ForwardedPort.cs

@@ -102,6 +102,8 @@ namespace Renci.SshNet
         /// <param name="timeout">The maximum amount of time to wait for pending requests to finish processing.</param>
         protected virtual void StopPort(TimeSpan timeout)
         {
+            timeout.EnsureValidTimeout(nameof(timeout));
+
             RaiseClosing();
 
             var session = Session;

+ 2 - 0
src/Renci.SshNet/ForwardedPortDynamic.cs

@@ -101,6 +101,8 @@ namespace Renci.SshNet
         /// <param name="timeout">The maximum amount of time to wait for pending requests to finish processing.</param>
         protected override void StopPort(TimeSpan timeout)
         {
+            timeout.EnsureValidTimeout(nameof(timeout));
+
             if (!ForwardedPortStatus.ToStopping(ref _status))
             {
                 return;

+ 2 - 0
src/Renci.SshNet/ForwardedPortLocal.cs

@@ -138,6 +138,8 @@ namespace Renci.SshNet
         /// <param name="timeout">The maximum amount of time to wait for pending requests to finish processing.</param>
         protected override void StopPort(TimeSpan timeout)
         {
+            timeout.EnsureValidTimeout(nameof(timeout));
+
             if (!ForwardedPortStatus.ToStopping(ref _status))
             {
                 return;

+ 2 - 0
src/Renci.SshNet/ForwardedPortRemote.cs

@@ -188,6 +188,8 @@ namespace Renci.SshNet
         /// <param name="timeout">The maximum amount of time to wait for the port to stop.</param>
         protected override void StopPort(TimeSpan timeout)
         {
+            timeout.EnsureValidTimeout(nameof(timeout));
+
             if (!ForwardedPortStatus.ToStopping(ref _status))
             {
                 return;

+ 1 - 7
src/Renci.SshNet/NetConfClient.cs

@@ -36,13 +36,7 @@ namespace Renci.SshNet
             }
             set
             {
-                var timeoutInMilliseconds = value.TotalMilliseconds;
-                if (timeoutInMilliseconds is < -1d or > int.MaxValue)
-                {
-                    throw new ArgumentOutOfRangeException(nameof(value), "The timeout must represent a value between -1 and Int32.MaxValue, inclusive.");
-                }
-
-                _operationTimeout = (int) timeoutInMilliseconds;
+                _operationTimeout = value.AsTimeout(nameof(OperationTimeout));
             }
         }
 

+ 14 - 1
src/Renci.SshNet/ScpClient.cs

@@ -38,6 +38,7 @@ namespace Renci.SshNet
         private static readonly Regex TimestampRe = new Regex(@"T(?<mtime>\d+) 0 (?<atime>\d+) 0", RegexOptions.Compiled);
 
         private IRemotePathTransformation _remotePathTransformation;
+        private TimeSpan _operationTimeout;
 
         /// <summary>
         /// Gets or sets the operation timeout.
@@ -46,7 +47,19 @@ namespace Renci.SshNet
         /// The timeout to wait until an operation completes. The default value is negative
         /// one (-1) milliseconds, which indicates an infinite time-out period.
         /// </value>
-        public TimeSpan OperationTimeout { get; set; }
+        public TimeSpan OperationTimeout
+        {
+            get
+            {
+                return _operationTimeout;
+            }
+            set
+            {
+                value.EnsureValidTimeout(nameof(OperationTimeout));
+
+                _operationTimeout = value;
+            }
+        }
 
         /// <summary>
         /// Gets or sets the size of the buffer.

+ 14 - 1
src/Renci.SshNet/Sftp/SftpFileStream.cs

@@ -35,6 +35,7 @@ namespace Renci.SshNet.Sftp
         private bool _canRead;
         private bool _canSeek;
         private bool _canWrite;
+        private TimeSpan _timeout;
 
         /// <summary>
         /// Gets a value indicating whether the current stream supports reading.
@@ -176,7 +177,19 @@ namespace Renci.SshNet.Sftp
         /// <value>
         /// The timeout.
         /// </value>
-        public TimeSpan Timeout { get; set; }
+        public TimeSpan Timeout
+        {
+            get
+            {
+                return _timeout;
+            }
+            set
+            {
+                value.EnsureValidTimeout(nameof(Timeout));
+
+                _timeout = value;
+            }
+        }
 
         private SftpFileStream(ISftpSession session, string path, FileAccess access, int bufferSize, byte[] handle, long position)
         {

+ 1 - 7
src/Renci.SshNet/SftpClient.cs

@@ -59,13 +59,7 @@ namespace Renci.SshNet
             {
                 CheckDisposed();
 
-                var timeoutInMilliseconds = value.TotalMilliseconds;
-                if (timeoutInMilliseconds is < -1d or > int.MaxValue)
-                {
-                    throw new ArgumentOutOfRangeException(nameof(value), "The timeout must represent a value between -1 and Int32.MaxValue, inclusive.");
-                }
-
-                _operationTimeout = (int) timeoutInMilliseconds;
+                _operationTimeout = value.AsTimeout(nameof(OperationTimeout));
             }
         }
 

+ 14 - 1
src/Renci.SshNet/SshCommand.cs

@@ -32,6 +32,7 @@ namespace Renci.SshNet
         private bool _hasError;
         private bool _isDisposed;
         private ChannelInputStream _inputStream;
+        private TimeSpan _commandTimeout;
 
         /// <summary>
         /// Gets the command text.
@@ -44,7 +45,19 @@ namespace Renci.SshNet
         /// <value>
         /// The command timeout.
         /// </value>
-        public TimeSpan CommandTimeout { get; set; }
+        public TimeSpan CommandTimeout
+        {
+            get
+            {
+                return _commandTimeout;
+            }
+            set
+            {
+                value.EnsureValidTimeout(nameof(CommandTimeout));
+
+                _commandTimeout = value;
+            }
+        }
 
         /// <summary>
         /// Gets the command exit status.

+ 103 - 0
test/Renci.SshNet.Tests/Classes/Common/TimeSpanExtensionsTest.cs

@@ -0,0 +1,103 @@
+using System;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes.Common
+{
+    [TestClass]
+    public class TimeSpanExtensionsTest
+    {
+        [TestMethod]
+        public void AsTimeout_ValidTimeSpan_ReturnsExpectedMilliseconds()
+        {
+            var timeSpan = TimeSpan.FromSeconds(10);
+
+            var timeout = timeSpan.AsTimeout("TestMethodName");
+
+            Assert.AreEqual(10000, timeout);
+        }
+
+        [TestMethod]
+        [ExpectedException(typeof(ArgumentOutOfRangeException))]
+        public void AsTimeout_NegativeTimeSpan_ThrowsArgumentOutOfRangeException()
+        {
+            var timeSpan = TimeSpan.FromSeconds(-1);
+
+            timeSpan.AsTimeout("TestMethodName");
+        }
+
+        [TestMethod]
+        [ExpectedException(typeof(ArgumentOutOfRangeException))]
+        public void AsTimeout_TimeSpanExceedingMaxValue_ThrowsArgumentOutOfRangeException()
+        {
+            var timeSpan = TimeSpan.FromMilliseconds((double) int.MaxValue + 1);
+
+            timeSpan.AsTimeout("TestMethodName");
+        }
+
+        [TestMethod]
+        public void AsTimeout_ArgumentOutOfRangeException_HasCorrectInformation()
+        {
+
+            try
+            {
+                var timeSpan = TimeSpan.FromMilliseconds((double) int.MaxValue + 1);
+
+                timeSpan.AsTimeout("TestMethodName");
+            }
+            catch (ArgumentOutOfRangeException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+                Assert.AreEqual("TestMethodName", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void EnsureValidTimeout_ValidTimeSpan_DoesNotThrow()
+        {
+            var timeSpan = TimeSpan.FromSeconds(5);
+
+            timeSpan.EnsureValidTimeout("TestMethodName");
+        }
+
+        [TestMethod]
+        [ExpectedException(typeof(ArgumentOutOfRangeException))]
+        public void EnsureValidTimeout_NegativeTimeSpan_ThrowsArgumentOutOfRangeException()
+        {
+            var timeSpan = TimeSpan.FromSeconds(-1);
+
+            timeSpan.EnsureValidTimeout("TestMethodName");
+        }
+
+        [TestMethod]
+        [ExpectedException(typeof(ArgumentOutOfRangeException))]
+        public void EnsureValidTimeout_TimeSpanExceedingMaxValue_ThrowsArgumentOutOfRangeException()
+        {
+            var timeSpan = TimeSpan.FromMilliseconds((double) int.MaxValue + 1);
+
+            timeSpan.EnsureValidTimeout("TestMethodName");
+        }
+
+        [TestMethod]
+        public void EnsureValidTimeout_ArgumentOutOfRangeException_HasCorrectInformation()
+        {
+
+            try
+            {
+                var timeSpan = TimeSpan.FromMilliseconds((double) int.MaxValue + 1);
+
+                timeSpan.EnsureValidTimeout("TestMethodName");
+            }
+            catch (ArgumentOutOfRangeException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+                Assert.AreEqual("TestMethodName", ex.ParamName);
+            }
+        }
+    }
+}

+ 78 - 0
test/Renci.SshNet.Tests/Classes/ConnectionInfoTest.cs

@@ -275,6 +275,84 @@ namespace Renci.SshNet.Tests.Classes
             Assert.AreEqual(port, connectionInfo.Port);
         }
 
+        [TestMethod]
+        [TestCategory("ConnectionInfo")]
+        public void Test_ConnectionInfo_Timeout_Valid()
+        {
+            var connectionInfo = new ConnectionInfo(Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, ProxyTypes.None,
+                                                    Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME,
+                                                    Resources.PASSWORD, new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
+
+            try
+            {
+                connectionInfo.Timeout = TimeSpan.FromMilliseconds(-2);
+            }
+            catch (ArgumentOutOfRangeException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("Timeout", ex.ParamName);
+            }
+
+            connectionInfo.Timeout = TimeSpan.FromMilliseconds(-1);
+            Assert.AreEqual(connectionInfo.Timeout, TimeSpan.FromMilliseconds(-1));
+
+            connectionInfo.Timeout = TimeSpan.FromMilliseconds(int.MaxValue);
+            Assert.AreEqual(connectionInfo.Timeout, TimeSpan.FromMilliseconds(int.MaxValue));
+
+            try
+            {
+                connectionInfo.Timeout = TimeSpan.FromMilliseconds((double)int.MaxValue + 1);
+            }
+            catch (ArgumentOutOfRangeException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("Timeout", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        [TestCategory("ConnectionInfo")]
+        public void Test_ConnectionInfo_ChannelCloseTimeout_Valid()
+        {
+            var connectionInfo = new ConnectionInfo(Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME, ProxyTypes.None,
+                                                    Resources.HOST, int.Parse(Resources.PORT), Resources.USERNAME,
+                                                    Resources.PASSWORD, new KeyboardInteractiveAuthenticationMethod(Resources.USERNAME));
+
+            try
+            {
+                connectionInfo.ChannelCloseTimeout = TimeSpan.FromMilliseconds(-2);
+            }
+            catch (ArgumentOutOfRangeException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("ChannelCloseTimeout", ex.ParamName);
+            }
+
+            connectionInfo.ChannelCloseTimeout = TimeSpan.FromMilliseconds(-1);
+            Assert.AreEqual(connectionInfo.ChannelCloseTimeout, TimeSpan.FromMilliseconds(-1));
+
+            connectionInfo.ChannelCloseTimeout = TimeSpan.FromMilliseconds(int.MaxValue);
+            Assert.AreEqual(connectionInfo.ChannelCloseTimeout, TimeSpan.FromMilliseconds(int.MaxValue));
+
+            try
+            {
+                connectionInfo.ChannelCloseTimeout = TimeSpan.FromMilliseconds((double)int.MaxValue + 1);
+            }
+            catch (ArgumentOutOfRangeException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("ChannelCloseTimeout", ex.ParamName);
+            }
+        }
+
         [TestMethod]
         [TestCategory("ConnectionInfo")]
         public void ConstructorShouldThrowArgumentExceptionhenUsernameIsNull()

+ 6 - 4
test/Renci.SshNet.Tests/Classes/NetConfClientTest.cs

@@ -85,8 +85,9 @@ namespace Renci.SshNet.Tests.Classes
             catch (ArgumentOutOfRangeException ex)
             {
                 Assert.IsNull(ex.InnerException);
-                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue, inclusive.", ex);
-                Assert.AreEqual("value", ex.ParamName);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("OperationTimeout", ex.ParamName);
             }
         }
 
@@ -104,8 +105,9 @@ namespace Renci.SshNet.Tests.Classes
             catch (ArgumentOutOfRangeException ex)
             {
                 Assert.IsNull(ex.InnerException);
-                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue, inclusive.", ex);
-                Assert.AreEqual("value", ex.ParamName);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("OperationTimeout", ex.ParamName);
             }
         }
     }

+ 6 - 4
test/Renci.SshNet.Tests/Classes/SftpClientTest.cs

@@ -88,8 +88,9 @@ namespace Renci.SshNet.Tests.Classes
             catch (ArgumentOutOfRangeException ex)
             {
                 Assert.IsNull(ex.InnerException);
-                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue, inclusive.", ex);
-                Assert.AreEqual("value", ex.ParamName);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("OperationTimeout", ex.ParamName);
             }
         }
 
@@ -107,8 +108,9 @@ namespace Renci.SshNet.Tests.Classes
             catch (ArgumentOutOfRangeException ex)
             {
                 Assert.IsNull(ex.InnerException);
-                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue, inclusive.", ex);
-                Assert.AreEqual("value", ex.ParamName);
+                ArgumentExceptionAssert.MessageEquals("The timeout must represent a value between -1 and Int32.MaxValue milliseconds, inclusive.", ex);
+
+                Assert.AreEqual("OperationTimeout", ex.ParamName);
             }
         }