Browse Source

Merge partial classes and complete IForwardedPort interface (#1223)

* Merge partial classes.
Complete IForwardedPort interface.
Expoe Closing event publicly on ForwardedPort.
Disable CA1030 (use events where appropriate).

* Merge partial classes.
Complete IForwardedPort interface.
Expoe Closing event publicly on ForwardedPort.
Disable CA1030 (use events where appropriate).

* Use local function and remove suppression.

* Compile regular expressions.

* Fix build, pfff.

---------

Co-authored-by: Wojciech Nagórski <wojtpl2@gmail.com>
Gert Driesen 2 years ago
parent
commit
42ff9206d2

+ 4 - 0
src/Renci.SshNet/.editorconfig

@@ -21,6 +21,10 @@ MA0053.public_class_should_be_sealed = false
 
 #### .NET Compiler Platform analysers rules ####
 
+# CA1030: Use events where appropriate
+# https://learn.microsoft.com/en-us/dotnet/fundamentals/code-analysis/quality-rules/ca1030
+dotnet_diagnostic.CA10310.severity = none
+
 # CA1031: Do not catch general exception types
 # https://learn.microsoft.com/en-us/dotnet/fundamentals/code-analysis/quality-rules/ca1031
 dotnet_diagnostic.CA1031.severity = none

+ 24 - 19
src/Renci.SshNet/ForwardedPort.cs

@@ -1,4 +1,5 @@
 using System;
+
 using Renci.SshNet.Common;
 
 namespace Renci.SshNet
@@ -16,28 +17,19 @@ namespace Renci.SshNet
         /// </value>
         internal ISession Session { get; set; }
 
-        /// <summary>
-        /// The <see cref="Closing"/> event occurs as the forwarded port is being stopped.
-        /// </summary>
-        internal event EventHandler Closing;
-
-        /// <summary>
-        /// The <see cref="IForwardedPort.Closing"/> event occurs as the forwarded port is being stopped.
-        /// </summary>
-        event EventHandler IForwardedPort.Closing
-        {
-            add { Closing += value; }
-            remove { Closing -= value; }
-        }
-
         /// <summary>
         /// Gets a value indicating whether port forwarding is started.
         /// </summary>
         /// <value>
-        /// <c>true</c> if port forwarding is started; otherwise, <c>false</c>.
+        /// <see langword="true"/> if port forwarding is started; otherwise, <see langword="false"/>.
         /// </value>
         public abstract bool IsStarted { get; }
 
+        /// <summary>
+        /// The <see cref="Closing"/> event occurs as the forwarded port is being stopped.
+        /// </summary>
+        public event EventHandler Closing;
+
         /// <summary>
         /// Occurs when an exception is thrown.
         /// </summary>
@@ -51,6 +43,8 @@ namespace Renci.SshNet
         /// <summary>
         /// Starts port forwarding.
         /// </summary>
+        /// <exception cref="InvalidOperationException">The current <see cref="ForwardedPort"/> is already started -or- is not linked to a SSH session.</exception>
+        /// <exception cref="SshConnectionException">The client is not connected.</exception>
         public virtual void Start()
         {
             CheckDisposed();
@@ -77,7 +71,9 @@ namespace Renci.SshNet
         /// <summary>
         /// Stops port forwarding.
         /// </summary>
+#pragma warning disable CA1716 // Identifiers should not match keywords
         public virtual void Stop()
+#pragma warning restore CA1716 // Identifiers should not match keywords
         {
             if (IsStarted)
             {
@@ -85,6 +81,15 @@ namespace Renci.SshNet
             }
         }
 
+        /// <summary>
+        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// </summary>
+        public void Dispose()
+        {
+            Dispose(disposing: true);
+            GC.SuppressFinalize(this);
+        }
+
         /// <summary>
         /// Starts port forwarding.
         /// </summary>
@@ -100,22 +105,22 @@ namespace Renci.SshNet
             RaiseClosing();
 
             var session = Session;
-            if (session != null)
+            if (session is not null)
             {
                 session.ErrorOccured -= Session_ErrorOccured;
             }
         }
 
         /// <summary>
-        /// Releases unmanaged and - optionally - managed resources
+        /// Releases unmanaged and - optionally - managed resources.
         /// </summary>
-        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
+        /// <param name="disposing"><see langowrd="true"/> to release both managed and unmanaged resources; <see langowrd="false"/> to release only unmanaged resources.</param>
         protected virtual void Dispose(bool disposing)
         {
             if (disposing)
             {
                 var session = Session;
-                if (session != null)
+                if (session is not null)
                 {
                     StopPort(session.ConnectionInfo.Timeout);
                     Session = null;

+ 0 - 603
src/Renci.SshNet/ForwardedPortDynamic.NET.cs

@@ -1,603 +0,0 @@
-using System;
-using System.Linq;
-using System.Net;
-using System.Net.Sockets;
-using System.Text;
-using System.Threading;
-
-using Renci.SshNet.Abstractions;
-using Renci.SshNet.Channels;
-using Renci.SshNet.Common;
-
-namespace Renci.SshNet
-{
-    public partial class ForwardedPortDynamic
-    {
-        private Socket _listener;
-        private CountdownEvent _pendingChannelCountdown;
-
-        partial void InternalStart()
-        {
-            InitializePendingChannelCountdown();
-
-            var ip = IPAddress.Any;
-            if (!string.IsNullOrEmpty(BoundHost))
-            {
-                ip = DnsAbstraction.GetHostAddresses(BoundHost)[0];
-            }
-
-            var ep = new IPEndPoint(ip, (int) BoundPort);
-
-            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {NoDelay = true};
-            _listener.Bind(ep);
-            _listener.Listen(5);
-
-            Session.ErrorOccured += Session_ErrorOccured;
-            Session.Disconnected += Session_Disconnected;
-
-            // consider port started when we're listening for inbound connections
-            _status = ForwardedPortStatus.Started;
-
-            StartAccept(e: null);
-        }
-
-        private void StartAccept(SocketAsyncEventArgs e)
-        {
-            if (e is null)
-            {
-                e = new SocketAsyncEventArgs();
-                e.Completed += AcceptCompleted;
-            }
-            else
-            {
-                // clear the socket as we're reusing the context object
-                e.AcceptSocket = null;
-            }
-
-            // only accept new connections while we are started
-            if (IsStarted)
-            {
-                try
-                {
-                    if (!_listener.AcceptAsync(e))
-                    {
-                        AcceptCompleted(sender: null, e);
-                    }
-                }
-                catch (ObjectDisposedException)
-                {
-                    if (_status == ForwardedPortStatus.Stopping || _status == ForwardedPortStatus.Stopped)
-                    {
-                        // ignore ObjectDisposedException while stopping or stopped
-                        return;
-                    }
-
-                    throw;
-                }
-            }
-        }
-
-        private void AcceptCompleted(object sender, SocketAsyncEventArgs e)
-        {
-            if (e.SocketError is SocketError.OperationAborted or SocketError.NotSocket)
-            {
-                // server was stopped
-                return;
-            }
-
-            // capture client socket
-            var clientSocket = e.AcceptSocket;
-
-            if (e.SocketError != SocketError.Success)
-            {
-                // accept new connection
-                StartAccept(e);
-
-                // dispose broken client socket
-                CloseClientSocket(clientSocket);
-                return;
-            }
-
-            // accept new connection
-            StartAccept(e);
-
-            // process connection
-            ProcessAccept(clientSocket);
-        }
-
-        private void ProcessAccept(Socket clientSocket)
-        {
-            // close the client socket if we're no longer accepting new connections
-            if (!IsStarted)
-            {
-                CloseClientSocket(clientSocket);
-                return;
-            }
-
-            // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling
-            // that same instance; the instance field for the countdown event is re-initialized when the port is restarted
-            // and at that time there may still be pending requests
-            var pendingChannelCountdown = _pendingChannelCountdown;
-
-            pendingChannelCountdown.AddCount();
-
-            try
-            {
-                using (var channel = Session.CreateChannelDirectTcpip())
-                {
-                    channel.Exception += Channel_Exception;
-
-                    if (!HandleSocks(channel, clientSocket, Session.ConnectionInfo.Timeout))
-                    {
-                        CloseClientSocket(clientSocket);
-                        return;
-                    }
-
-                    // start receiving from client socket (and sending to server)
-                    channel.Bind();
-                }
-            }
-            catch (Exception exp)
-            {
-                RaiseExceptionEvent(exp);
-                CloseClientSocket(clientSocket);
-            }
-            finally
-            {
-                // take into account that CountdownEvent has since been disposed; when stopping the port we
-                // wait for a given time for the channels to close, but once that timeout period has elapsed
-                // the CountdownEvent will be disposed
-                try
-                {
-                    _ = pendingChannelCountdown.Signal();
-                }
-                catch (ObjectDisposedException)
-                {
-                }
-            }
-        }
-
-        /// <summary>
-        /// Initializes the <see cref="CountdownEvent"/>.
-        /// </summary>
-        /// <remarks>
-        /// <para>
-        /// When the port is started for the first time, a <see cref="CountdownEvent"/> is created with an initial count
-        /// of <c>1</c>.
-        /// </para>
-        /// <para>
-        /// On subsequent (re)starts, we'll dispose the current <see cref="CountdownEvent"/> and create a new one with
-        /// initial count of <c>1</c>.
-        /// </para>
-        /// </remarks>
-        private void InitializePendingChannelCountdown()
-        {
-            var original = Interlocked.Exchange(ref _pendingChannelCountdown, new CountdownEvent(1));
-            original?.Dispose();
-        }
-
-        private bool HandleSocks(IChannelDirectTcpip channel, Socket clientSocket, TimeSpan timeout)
-        {
-
-#pragma warning disable IDE0039 // Use lambda instead of local function to reduce allocations
-            // Create eventhandler which is to be invoked to interrupt a blocking receive
-            // when we're closing the forwarded port.
-            EventHandler closeClientSocket = (_, args) => CloseClientSocket(clientSocket);
-#pragma warning restore IDE0039 // Use lambda instead of local function to reduce allocations
-
-            Closing += closeClientSocket;
-
-            try
-            {
-                var version = SocketAbstraction.ReadByte(clientSocket, timeout);
-                switch (version)
-                {
-                    case -1:
-                        // SOCKS client closed connection
-                        return false;
-                    case 4:
-                        return HandleSocks4(clientSocket, channel, timeout);
-                    case 5:
-                        return HandleSocks5(clientSocket, channel, timeout);
-                    default:
-                        throw new NotSupportedException(string.Format("SOCKS version {0} is not supported.", version));
-                }
-            }
-            catch (SocketException ex)
-            {
-                // ignore exception thrown by interrupting the blocking receive as part of closing
-                // the forwarded port
-#if NETFRAMEWORK
-                if (ex.SocketErrorCode != SocketError.Interrupted)
-                {
-                    RaiseExceptionEvent(ex);
-                }
-#else
-                // Since .NET 5 the exception has been changed. 
-                // more info https://github.com/dotnet/runtime/issues/41585
-                if (ex.SocketErrorCode != SocketError.ConnectionAborted)
-                {
-                    RaiseExceptionEvent(ex);
-                }
-#endif
-                return false;
-            }
-            finally
-            {
-                // interrupt of blocking receive is now handled by channel (SOCKS4 and SOCKS5)
-                // or no longer necessary
-                Closing -= closeClientSocket;
-            }
-
-        }
-
-        private static void CloseClientSocket(Socket clientSocket)
-        {
-            if (clientSocket.Connected)
-            {
-                try
-                {
-                    clientSocket.Shutdown(SocketShutdown.Send);
-                }
-                catch (Exception)
-                {
-                    // ignore exception when client socket was already closed
-                }
-            }
-
-            clientSocket.Dispose();
-        }
-
-        /// <summary>
-        /// Interrupts the listener, and unsubscribes from <see cref="Session"/> events.
-        /// </summary>
-        partial void StopListener()
-        {
-            // close listener socket
-            _listener?.Dispose();
-
-            // unsubscribe from session events
-            var session = Session;
-            if (session != null)
-            {
-                session.ErrorOccured -= Session_ErrorOccured;
-                session.Disconnected -= Session_Disconnected;
-            }
-        }
-
-        /// <summary>
-        /// Waits for pending channels to close.
-        /// </summary>
-        /// <param name="timeout">The maximum time to wait for the pending channels to close.</param>
-        partial void InternalStop(TimeSpan timeout)
-        {
-            _ = _pendingChannelCountdown.Signal();
-
-            if (!_pendingChannelCountdown.Wait(timeout))
-            {
-                // TODO: log as warning
-                DiagnosticAbstraction.Log("Timeout waiting for pending channels in dynamic forwarded port to close.");
-            }
-
-        }
-
-        partial void InternalDispose(bool disposing)
-        {
-            if (disposing)
-            {
-                var listener = _listener;
-                if (listener != null)
-                {
-                    _listener = null;
-                    listener.Dispose();
-                }
-
-                var pendingRequestsCountdown = _pendingChannelCountdown;
-                if (pendingRequestsCountdown != null)
-                {
-                    _pendingChannelCountdown = null;
-                    pendingRequestsCountdown.Dispose();
-                }
-            }
-        }
-
-        private void Session_Disconnected(object sender, EventArgs e)
-        {
-            var session = Session;
-            if (session != null)
-            {
-                StopPort(session.ConnectionInfo.Timeout);
-            }
-        }
-
-        private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
-        {
-            var session = Session;
-            if (session != null)
-            {
-                StopPort(session.ConnectionInfo.Timeout);
-            }
-        }
-
-        private void Channel_Exception(object sender, ExceptionEventArgs e)
-        {
-            RaiseExceptionEvent(e.Exception);
-        }
-
-        private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
-        {
-            var commandCode = SocketAbstraction.ReadByte(socket, timeout);
-            if (commandCode == -1)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            //  TODO:   See what need to be done depends on the code
-
-            var portBuffer = new byte[2];
-            if (SocketAbstraction.Read(socket, portBuffer, 0, portBuffer.Length, timeout) == 0)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var port = Pack.BigEndianToUInt16(portBuffer);
-
-            var ipBuffer = new byte[4];
-            if (SocketAbstraction.Read(socket, ipBuffer, 0, ipBuffer.Length, timeout) == 0)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var ipAddress = new IPAddress(ipBuffer);
-
-            var username = ReadString(socket, timeout);
-            if (username is null)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var host = ipAddress.ToString();
-
-            RaiseRequestReceived(host, port);
-
-            channel.Open(host, port, this, socket);
-
-            SocketAbstraction.SendByte(socket, 0x00);
-
-            if (channel.IsOpen)
-            {
-                SocketAbstraction.SendByte(socket, 0x5a);
-                SocketAbstraction.Send(socket, portBuffer, 0, portBuffer.Length);
-                SocketAbstraction.Send(socket, ipBuffer, 0, ipBuffer.Length);
-                return true;
-            }
-
-            // signal that request was rejected or failed
-            SocketAbstraction.SendByte(socket, 0x5b);
-            return false;
-        }
-
-        private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
-        {
-            var authenticationMethodsCount = SocketAbstraction.ReadByte(socket, timeout);
-            if (authenticationMethodsCount == -1)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var authenticationMethods = new byte[authenticationMethodsCount];
-            if (SocketAbstraction.Read(socket, authenticationMethods, 0, authenticationMethods.Length, timeout) == 0)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            if (authenticationMethods.Min() == 0)
-            {
-                // no user authentication is one of the authentication methods supported
-                // by the SOCKS client
-                SocketAbstraction.Send(socket, new byte[] { 0x05, 0x00 }, 0, 2);
-            }
-            else
-            {
-                // the SOCKS client requires authentication, which we currently do not support
-                SocketAbstraction.Send(socket, new byte[] { 0x05, 0xFF }, 0, 2);
-
-                // we continue business as usual but expect the client to close the connection
-                // so one of the subsequent reads should return -1 signaling that the client
-                // has effectively closed the connection
-            }
-
-            var version = SocketAbstraction.ReadByte(socket, timeout);
-            if (version == -1)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            if (version != 5)
-            {
-                throw new ProxyException("SOCKS5: Version 5 is expected.");
-            }
-
-            var commandCode = SocketAbstraction.ReadByte(socket, timeout);
-            if (commandCode == -1)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var reserved = SocketAbstraction.ReadByte(socket, timeout);
-            if (reserved == -1)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            if (reserved != 0)
-            {
-                throw new ProxyException("SOCKS5: 0 is expected for reserved byte.");
-            }
-
-            var addressType = SocketAbstraction.ReadByte(socket, timeout);
-            if (addressType == -1)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var host = GetSocks5Host(addressType, socket, timeout);
-            if (host is null)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var portBuffer = new byte[2];
-            if (SocketAbstraction.Read(socket, portBuffer, 0, portBuffer.Length, timeout) == 0)
-            {
-                // SOCKS client closed connection
-                return false;
-            }
-
-            var port = Pack.BigEndianToUInt16(portBuffer);
-
-            RaiseRequestReceived(host, port);
-
-            channel.Open(host, port, this, socket);
-
-            var socksReply = CreateSocks5Reply(channel.IsOpen);
-
-            SocketAbstraction.Send(socket, socksReply, 0, socksReply.Length);
-
-            return true;
-        }
-
-        private static string GetSocks5Host(int addressType, Socket socket, TimeSpan timeout)
-        {
-            switch (addressType)
-            {
-                case 0x01: // IPv4
-                    {
-                        var addressBuffer = new byte[4];
-                        if (SocketAbstraction.Read(socket, addressBuffer, 0, 4, timeout) == 0)
-                        {
-                            // SOCKS client closed connection
-                            return null;
-                        }
-
-                        var ipv4 = new IPAddress(addressBuffer);
-                        return ipv4.ToString();
-                    }
-                case 0x03: // Domain name
-                    {
-                        var length = SocketAbstraction.ReadByte(socket, timeout);
-                        if (length == -1)
-                        {
-                            // SOCKS client closed connection
-                            return null;
-                        }
-                        var addressBuffer = new byte[length];
-                        if (SocketAbstraction.Read(socket, addressBuffer, 0, addressBuffer.Length, timeout) == 0)
-                        {
-                            // SOCKS client closed connection
-                            return null;
-                        }
-
-                        var hostName = SshData.Ascii.GetString(addressBuffer, 0, addressBuffer.Length);
-                        return hostName;
-                    }
-                case 0x04: // IPv6
-                    {
-                        var addressBuffer = new byte[16];
-                        if (SocketAbstraction.Read(socket, addressBuffer, 0, 16, timeout) == 0)
-                        {
-                            // SOCKS client closed connection
-                            return null;
-                        }
-
-                        var ipv6 = new IPAddress(addressBuffer);
-                        return ipv6.ToString();
-                    }
-                default:
-                    throw new ProxyException(string.Format("SOCKS5: Address type '{0}' is not supported.", addressType));
-            }
-        }
-
-        private static byte[] CreateSocks5Reply(bool channelOpen)
-        {
-            var socksReply = new byte
-                [
-                    // SOCKS version
-                    1 +
-                    // Reply field
-                    1 +
-                    // Reserved; fixed: 0x00
-                    1 +
-                    // Address type; fixed: 0x01
-                    1 +
-                    // IPv4 server bound address; fixed: {0x00, 0x00, 0x00, 0x00}
-                    4 +
-                    // server bound port; fixed: {0x00, 0x00}
-                    2
-                ];
-
-            socksReply[0] = 0x05;
-
-            if (channelOpen)
-            {
-                socksReply[1] = 0x00; // succeeded
-            }
-            else
-            {
-                socksReply[1] = 0x01; // general SOCKS server failure
-            }
-
-            // reserved
-            socksReply[2] = 0x00;
-
-            // IPv4 address type
-            socksReply[3] = 0x01;
-
-            return socksReply;
-        }
-
-        /// <summary>
-        /// Reads a null terminated string from a socket.
-        /// </summary>
-        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
-        /// <param name="timeout">The timeout to apply to individual reads.</param>
-        /// <returns>
-        /// The <see cref="string"/> read, or <c>null</c> when the socket was closed.
-        /// </returns>
-        private static string ReadString(Socket socket, TimeSpan timeout)
-        {
-            var text = new StringBuilder();
-            var buffer = new byte[1];
-            while (true)
-            {
-                if (SocketAbstraction.Read(socket, buffer, 0, 1, timeout) == 0)
-                {
-                    // SOCKS client closed connection
-                    return null;
-                }
-
-                var byteRead = buffer[0];
-                if (byteRead == 0)
-                {
-                    // end of the string
-                    break;
-                }
-
-                _ = text.Append((char) byteRead);
-            }
-
-            return text.ToString();
-        }
-    }
-}

+ 604 - 23
src/Renci.SshNet/ForwardedPortDynamic.cs

@@ -1,4 +1,14 @@
 using System;
+using System.Globalization;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+
+using Renci.SshNet.Abstractions;
+using Renci.SshNet.Channels;
+using Renci.SshNet.Common;
 
 namespace Renci.SshNet
 {
@@ -6,7 +16,7 @@ namespace Renci.SshNet
     /// Provides functionality for forwarding connections from the client to destination servers via the SSH server,
     /// also known as dynamic port forwarding.
     /// </summary>
-    public partial class ForwardedPortDynamic : ForwardedPort
+    public class ForwardedPortDynamic : ForwardedPort
     {
         private ForwardedPortStatus _status;
 
@@ -14,7 +24,7 @@ namespace Renci.SshNet
         /// Holds a value indicating whether the current instance is disposed.
         /// </summary>
         /// <value>
-        /// <c>true</c> if the current instance is disposed; otherwise, <c>false</c>.
+        /// <see langword="true"/> if the current instance is disposed; otherwise, <see langword="false"/>.
         /// </value>
         private bool _isDisposed;
 
@@ -28,11 +38,14 @@ namespace Renci.SshNet
         /// </summary>
         public uint BoundPort { get; }
 
+        private Socket _listener;
+        private CountdownEvent _pendingChannelCountdown;
+
         /// <summary>
         /// Gets a value indicating whether port forwarding is started.
         /// </summary>
         /// <value>
-        /// <c>true</c> if port forwarding is started; otherwise, <c>false</c>.
+        /// <see langword="true"/> if port forwarding is started; otherwise, <see langword="false"/>.
         /// </value>
         public override bool IsStarted
         {
@@ -112,51 +125,619 @@ namespace Renci.SshNet
         /// <exception cref="ObjectDisposedException">The current instance is disposed.</exception>
         protected override void CheckDisposed()
         {
+#if NET7_0_OR_GREATER
+            ObjectDisposedException.ThrowIf(_isDisposed, this);
+#else
             if (_isDisposed)
             {
                 throw new ObjectDisposedException(GetType().FullName);
             }
+#endif // NET7_0_OR_GREATER
         }
 
-        partial void InternalStart();
+        /// <summary>
+        /// Releases unmanaged and - optionally - managed resources.
+        /// </summary>
+        /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
+        protected override void Dispose(bool disposing)
+        {
+            if (_isDisposed)
+            {
+                return;
+            }
+
+            base.Dispose(disposing);
+
+            InternalDispose(disposing);
+            _isDisposed = true;
+        }
+
+        private void InternalStart()
+        {
+            InitializePendingChannelCountdown();
+
+            var ip = IPAddress.Any;
+            if (!string.IsNullOrEmpty(BoundHost))
+            {
+                ip = DnsAbstraction.GetHostAddresses(BoundHost)[0];
+            }
+
+            var ep = new IPEndPoint(ip, (int) BoundPort);
+
+            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
+            _listener.Bind(ep);
+            _listener.Listen(5);
+
+            Session.ErrorOccured += Session_ErrorOccured;
+            Session.Disconnected += Session_Disconnected;
+
+            // consider port started when we're listening for inbound connections
+            _status = ForwardedPortStatus.Started;
+
+            StartAccept(e: null);
+        }
+
+        private void StartAccept(SocketAsyncEventArgs e)
+        {
+            if (e is null)
+            {
+#pragma warning disable CA2000 // Dispose objects before losing scope
+                e = new SocketAsyncEventArgs();
+#pragma warning restore CA2000 // Dispose objects before losing scope
+                e.Completed += AcceptCompleted;
+            }
+            else
+            {
+                // clear the socket as we're reusing the context object
+                e.AcceptSocket = null;
+            }
+
+            // only accept new connections while we are started
+            if (IsStarted)
+            {
+                try
+                {
+                    if (!_listener.AcceptAsync(e))
+                    {
+                        AcceptCompleted(sender: null, e);
+                    }
+                }
+                catch (ObjectDisposedException)
+                {
+                    if (_status == ForwardedPortStatus.Stopping || _status == ForwardedPortStatus.Stopped)
+                    {
+                        // ignore ObjectDisposedException while stopping or stopped
+                        return;
+                    }
+
+                    throw;
+                }
+            }
+        }
+
+        private void AcceptCompleted(object sender, SocketAsyncEventArgs e)
+        {
+            if (e.SocketError is SocketError.OperationAborted or SocketError.NotSocket)
+            {
+                // server was stopped
+                return;
+            }
+
+            // capture client socket
+            var clientSocket = e.AcceptSocket;
+
+            if (e.SocketError != SocketError.Success)
+            {
+                // accept new connection
+                StartAccept(e);
+
+                // dispose broken client socket
+                CloseClientSocket(clientSocket);
+                return;
+            }
+
+            // accept new connection
+            StartAccept(e);
+
+            // process connection
+            ProcessAccept(clientSocket);
+        }
+
+        private void ProcessAccept(Socket clientSocket)
+        {
+            // close the client socket if we're no longer accepting new connections
+            if (!IsStarted)
+            {
+                CloseClientSocket(clientSocket);
+                return;
+            }
+
+            // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling
+            // that same instance; the instance field for the countdown event is re-initialized when the port is restarted
+            // and at that time there may still be pending requests
+            var pendingChannelCountdown = _pendingChannelCountdown;
+
+            pendingChannelCountdown.AddCount();
+
+            try
+            {
+                using (var channel = Session.CreateChannelDirectTcpip())
+                {
+                    channel.Exception += Channel_Exception;
+
+                    if (!HandleSocks(channel, clientSocket, Session.ConnectionInfo.Timeout))
+                    {
+                        CloseClientSocket(clientSocket);
+                        return;
+                    }
+
+                    // start receiving from client socket (and sending to server)
+                    channel.Bind();
+                }
+            }
+            catch (Exception exp)
+            {
+                RaiseExceptionEvent(exp);
+                CloseClientSocket(clientSocket);
+            }
+            finally
+            {
+                // take into account that CountdownEvent has since been disposed; when stopping the port we
+                // wait for a given time for the channels to close, but once that timeout period has elapsed
+                // the CountdownEvent will be disposed
+                try
+                {
+                    _ = pendingChannelCountdown.Signal();
+                }
+                catch (ObjectDisposedException)
+                {
+                    // Ignore any ObjectDisposedException
+                }
+            }
+        }
 
         /// <summary>
-        /// Stops the listener.
+        /// Initializes the <see cref="CountdownEvent"/>.
         /// </summary>
-        partial void StopListener();
+        /// <remarks>
+        /// <para>
+        /// When the port is started for the first time, a <see cref="CountdownEvent"/> is created with an initial count
+        /// of <c>1</c>.
+        /// </para>
+        /// <para>
+        /// On subsequent (re)starts, we'll dispose the current <see cref="CountdownEvent"/> and create a new one with
+        /// initial count of <c>1</c>.
+        /// </para>
+        /// </remarks>
+        private void InitializePendingChannelCountdown()
+        {
+            var original = Interlocked.Exchange(ref _pendingChannelCountdown, new CountdownEvent(1));
+            original?.Dispose();
+        }
+
+        private bool HandleSocks(IChannelDirectTcpip channel, Socket clientSocket, TimeSpan timeout)
+        {
+            Closing += closeClientSocket;
+
+            try
+            {
+                var version = SocketAbstraction.ReadByte(clientSocket, timeout);
+                switch (version)
+                {
+                    case -1:
+                        // SOCKS client closed connection
+                        return false;
+                    case 4:
+                        return HandleSocks4(clientSocket, channel, timeout);
+                    case 5:
+                        return HandleSocks5(clientSocket, channel, timeout);
+                    default:
+                        throw new NotSupportedException(string.Format(CultureInfo.InvariantCulture, "SOCKS version {0} is not supported.", version));
+                }
+            }
+            catch (SocketException ex)
+            {
+                // ignore exception thrown by interrupting the blocking receive as part of closing
+                // the forwarded port
+#if NETFRAMEWORK
+                if (ex.SocketErrorCode != SocketError.Interrupted)
+                {
+                    RaiseExceptionEvent(ex);
+                }
+#else
+                // Since .NET 5 the exception has been changed.
+                // more info https://github.com/dotnet/runtime/issues/41585
+                if (ex.SocketErrorCode != SocketError.ConnectionAborted)
+                {
+                    RaiseExceptionEvent(ex);
+                }
+#endif
+                return false;
+            }
+            finally
+            {
+                // interrupt of blocking receive is now handled by channel (SOCKS4 and SOCKS5)
+                // or no longer necessary
+                Closing -= closeClientSocket;
+            }
+
+            void closeClientSocket(object _, EventArgs args)
+            {
+                CloseClientSocket(clientSocket);
+            };
+        }
+
+        private static void CloseClientSocket(Socket clientSocket)
+        {
+            if (clientSocket.Connected)
+            {
+                try
+                {
+                    clientSocket.Shutdown(SocketShutdown.Send);
+                }
+                catch (Exception)
+                {
+                    // ignore exception when client socket was already closed
+                }
+            }
+
+            clientSocket.Dispose();
+        }
 
         /// <summary>
-        /// Waits for pending requests to finish, and channels to close.
+        /// Interrupts the listener, and unsubscribes from <see cref="Session"/> events.
         /// </summary>
-        /// <param name="timeout">The maximum time to wait for the forwarded port to stop.</param>
-        partial void InternalStop(TimeSpan timeout);
+        private void StopListener()
+        {
+            // close listener socket
+            _listener?.Dispose();
+
+            // unsubscribe from session events
+            var session = Session;
+            if (session is not null)
+            {
+                session.ErrorOccured -= Session_ErrorOccured;
+                session.Disconnected -= Session_Disconnected;
+            }
+        }
 
         /// <summary>
-        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// Waits for pending channels to close.
         /// </summary>
-        public void Dispose()
+        /// <param name="timeout">The maximum time to wait for the pending channels to close.</param>
+        private void InternalStop(TimeSpan timeout)
+        {
+            _ = _pendingChannelCountdown.Signal();
+
+            if (!_pendingChannelCountdown.Wait(timeout))
+            {
+                // TODO: log as warning
+                DiagnosticAbstraction.Log("Timeout waiting for pending channels in dynamic forwarded port to close.");
+            }
+        }
+
+        private void InternalDispose(bool disposing)
+        {
+            if (disposing)
+            {
+                var listener = _listener;
+                if (listener is not null)
+                {
+                    _listener = null;
+                    listener.Dispose();
+                }
+
+                var pendingRequestsCountdown = _pendingChannelCountdown;
+                if (pendingRequestsCountdown is not null)
+                {
+                    _pendingChannelCountdown = null;
+                    pendingRequestsCountdown.Dispose();
+                }
+            }
+        }
+
+        private void Session_Disconnected(object sender, EventArgs e)
+        {
+            var session = Session;
+            if (session is not null)
+            {
+                StopPort(session.ConnectionInfo.Timeout);
+            }
+        }
+
+        private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
+        {
+            var session = Session;
+            if (session is not null)
+            {
+                StopPort(session.ConnectionInfo.Timeout);
+            }
+        }
+
+        private void Channel_Exception(object sender, ExceptionEventArgs e)
         {
-            Dispose(disposing: true);
-            GC.SuppressFinalize(this);
+            RaiseExceptionEvent(e.Exception);
         }
 
-        partial void InternalDispose(bool disposing);
+        private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
+        {
+            var commandCode = SocketAbstraction.ReadByte(socket, timeout);
+            if (commandCode == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var portBuffer = new byte[2];
+            if (SocketAbstraction.Read(socket, portBuffer, 0, portBuffer.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var port = Pack.BigEndianToUInt16(portBuffer);
+
+            var ipBuffer = new byte[4];
+            if (SocketAbstraction.Read(socket, ipBuffer, 0, ipBuffer.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var ipAddress = new IPAddress(ipBuffer);
+
+            var username = ReadString(socket, timeout);
+            if (username is null)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var host = ipAddress.ToString();
+
+            RaiseRequestReceived(host, port);
+
+            channel.Open(host, port, this, socket);
+
+            SocketAbstraction.SendByte(socket, 0x00);
+
+            if (channel.IsOpen)
+            {
+                SocketAbstraction.SendByte(socket, 0x5a);
+                SocketAbstraction.Send(socket, portBuffer, 0, portBuffer.Length);
+                SocketAbstraction.Send(socket, ipBuffer, 0, ipBuffer.Length);
+                return true;
+            }
+
+            // signal that request was rejected or failed
+            SocketAbstraction.SendByte(socket, 0x5b);
+            return false;
+        }
+
+        private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
+        {
+            var authenticationMethodsCount = SocketAbstraction.ReadByte(socket, timeout);
+            if (authenticationMethodsCount == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var authenticationMethods = new byte[authenticationMethodsCount];
+            if (SocketAbstraction.Read(socket, authenticationMethods, 0, authenticationMethods.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            if (authenticationMethods.Min() == 0)
+            {
+                // no user authentication is one of the authentication methods supported
+                // by the SOCKS client
+                SocketAbstraction.Send(socket, new byte[] { 0x05, 0x00 }, 0, 2);
+            }
+            else
+            {
+                // the SOCKS client requires authentication, which we currently do not support
+                SocketAbstraction.Send(socket, new byte[] { 0x05, 0xFF }, 0, 2);
+
+                // we continue business as usual but expect the client to close the connection
+                // so one of the subsequent reads should return -1 signaling that the client
+                // has effectively closed the connection
+            }
+
+            var version = SocketAbstraction.ReadByte(socket, timeout);
+            if (version == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            if (version != 5)
+            {
+                throw new ProxyException("SOCKS5: Version 5 is expected.");
+            }
+
+            var commandCode = SocketAbstraction.ReadByte(socket, timeout);
+            if (commandCode == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var reserved = SocketAbstraction.ReadByte(socket, timeout);
+            if (reserved == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            if (reserved != 0)
+            {
+                throw new ProxyException("SOCKS5: 0 is expected for reserved byte.");
+            }
+
+            var addressType = SocketAbstraction.ReadByte(socket, timeout);
+            if (addressType == -1)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var host = GetSocks5Host(addressType, socket, timeout);
+            if (host is null)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var portBuffer = new byte[2];
+            if (SocketAbstraction.Read(socket, portBuffer, 0, portBuffer.Length, timeout) == 0)
+            {
+                // SOCKS client closed connection
+                return false;
+            }
+
+            var port = Pack.BigEndianToUInt16(portBuffer);
+
+            RaiseRequestReceived(host, port);
+
+            channel.Open(host, port, this, socket);
+
+            var socksReply = CreateSocks5Reply(channel.IsOpen);
+
+            SocketAbstraction.Send(socket, socksReply, 0, socksReply.Length);
+
+            return true;
+        }
+
+        private static string GetSocks5Host(int addressType, Socket socket, TimeSpan timeout)
+        {
+            switch (addressType)
+            {
+                case 0x01: // IPv4
+                    {
+                        var addressBuffer = new byte[4];
+                        if (SocketAbstraction.Read(socket, addressBuffer, 0, 4, timeout) == 0)
+                        {
+                            // SOCKS client closed connection
+                            return null;
+                        }
+
+                        var ipv4 = new IPAddress(addressBuffer);
+                        return ipv4.ToString();
+                    }
+
+                case 0x03: // Domain name
+                    {
+                        var length = SocketAbstraction.ReadByte(socket, timeout);
+                        if (length == -1)
+                        {
+                            // SOCKS client closed connection
+                            return null;
+                        }
+
+                        var addressBuffer = new byte[length];
+                        if (SocketAbstraction.Read(socket, addressBuffer, 0, addressBuffer.Length, timeout) == 0)
+                        {
+                            // SOCKS client closed connection
+                            return null;
+                        }
+
+                        var hostName = SshData.Ascii.GetString(addressBuffer, 0, addressBuffer.Length);
+                        return hostName;
+                    }
+
+                case 0x04: // IPv6
+                    {
+                        var addressBuffer = new byte[16];
+                        if (SocketAbstraction.Read(socket, addressBuffer, 0, 16, timeout) == 0)
+                        {
+                            // SOCKS client closed connection
+                            return null;
+                        }
+
+                        var ipv6 = new IPAddress(addressBuffer);
+                        return ipv6.ToString();
+                    }
+
+                default:
+                    throw new ProxyException(string.Format(CultureInfo.InvariantCulture, "SOCKS5: Address type '{0}' is not supported.", addressType));
+            }
+        }
+
+        private static byte[] CreateSocks5Reply(bool channelOpen)
+        {
+            var socksReply = new byte[// SOCKS version
+                                      1 +
+
+                                      // Reply field
+                                      1 +
+
+                                      // Reserved; fixed: 0x00
+                                      1 +
+
+                                      // Address type; fixed: 0x01
+                                      1 +
+
+                                      // IPv4 server bound address; fixed: {0x00, 0x00, 0x00, 0x00}
+                                      4 +
+
+                                      // server bound port; fixed: {0x00, 0x00}
+                                      2];
+
+            socksReply[0] = 0x05;
+
+            if (channelOpen)
+            {
+                socksReply[1] = 0x00; // succeeded
+            }
+            else
+            {
+                socksReply[1] = 0x01; // general SOCKS server failure
+            }
+
+            // reserved
+            socksReply[2] = 0x00;
+
+            // IPv4 address type
+            socksReply[3] = 0x01;
+
+            return socksReply;
+        }
 
         /// <summary>
-        /// Releases unmanaged and - optionally - managed resources.
+        /// Reads a null terminated string from a socket.
         /// </summary>
-        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
-        protected override void Dispose(bool disposing)
+        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
+        /// <param name="timeout">The timeout to apply to individual reads.</param>
+        /// <returns>
+        /// The <see cref="string"/> read, or <see langword="null"/> when the socket was closed.
+        /// </returns>
+        private static string ReadString(Socket socket, TimeSpan timeout)
         {
-            if (_isDisposed)
+            var text = new StringBuilder();
+            var buffer = new byte[1];
+
+            while (true)
             {
-                return;
-            }
+                if (SocketAbstraction.Read(socket, buffer, 0, 1, timeout) == 0)
+                {
+                    // SOCKS client closed connection
+                    return null;
+                }
 
-            base.Dispose(disposing);
+                var byteRead = buffer[0];
+                if (byteRead == 0)
+                {
+                    // end of the string
+                    break;
+                }
 
-            InternalDispose(disposing);
-            _isDisposed = true;
+                _ = text.Append((char) byteRead);
+            }
+
+            return text.ToString();
         }
 
         /// <summary>

+ 0 - 264
src/Renci.SshNet/ForwardedPortLocal.NET.cs

@@ -1,264 +0,0 @@
-using System;
-using System.Net;
-using System.Net.Sockets;
-using System.Threading;
-
-using Renci.SshNet.Abstractions;
-using Renci.SshNet.Common;
-
-namespace Renci.SshNet
-{
-    public partial class ForwardedPortLocal
-    {
-        private Socket _listener;
-        private CountdownEvent _pendingChannelCountdown;
-
-        partial void InternalStart()
-        {
-            var addr = DnsAbstraction.GetHostAddresses(BoundHost)[0];
-            var ep = new IPEndPoint(addr, (int) BoundPort);
-
-            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {NoDelay = true};
-            _listener.Bind(ep);
-            _listener.Listen(5);
-
-            // update bound port (in case original was passed as zero)
-            BoundPort = (uint)((IPEndPoint)_listener.LocalEndPoint).Port;
-
-            Session.ErrorOccured += Session_ErrorOccured;
-            Session.Disconnected += Session_Disconnected;
-
-            InitializePendingChannelCountdown();
-
-            // consider port started when we're listening for inbound connections
-            _status = ForwardedPortStatus.Started;
-
-            StartAccept(e: null);
-        }
-
-        private void StartAccept(SocketAsyncEventArgs e)
-        {
-            if (e is null)
-            {
-                e = new SocketAsyncEventArgs();
-                e.Completed += AcceptCompleted;
-            }
-            else
-            {
-                // clear the socket as we're reusing the context object
-                e.AcceptSocket = null;
-            }
-
-            // only accept new connections while we are started
-            if (IsStarted)
-            {
-                try
-                {
-                    if (!_listener.AcceptAsync(e))
-                    {
-                        AcceptCompleted(sender: null, e);
-                    }
-                }
-                catch (ObjectDisposedException)
-                {
-                    if (_status == ForwardedPortStatus.Stopping || _status == ForwardedPortStatus.Stopped)
-                    {
-                        // ignore ObjectDisposedException while stopping or stopped
-                        return;
-                    }
-
-                    throw;
-                }
-            }
-        }
-
-        private void AcceptCompleted(object sender, SocketAsyncEventArgs e)
-        {
-            if (e.SocketError is SocketError.OperationAborted or SocketError.NotSocket)
-            {
-                // server was stopped
-                return;
-            }
-
-            // capture client socket
-            var clientSocket = e.AcceptSocket;
-
-            if (e.SocketError != SocketError.Success)
-            {
-                // accept new connection
-                StartAccept(e);
-
-                // dispose broken client socket
-                CloseClientSocket(clientSocket);
-                return;
-            }
-
-            // accept new connection
-            StartAccept(e);
-
-            // process connection
-            ProcessAccept(clientSocket);
-        }
-
-        private void ProcessAccept(Socket clientSocket)
-        {
-            // close the client socket if we're no longer accepting new connections
-            if (!IsStarted)
-            {
-                CloseClientSocket(clientSocket);
-                return;
-            }
-
-            // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling
-            // that same instance; the instance field for the countdown event is re-initialized when the port is restarted
-            // and at that time there may still be pending requests
-            var pendingChannelCountdown = _pendingChannelCountdown;
-
-            pendingChannelCountdown.AddCount();
-
-            try
-            {
-                var originatorEndPoint = (IPEndPoint) clientSocket.RemoteEndPoint;
-
-                RaiseRequestReceived(originatorEndPoint.Address.ToString(),
-                    (uint)originatorEndPoint.Port);
-
-                using (var channel = Session.CreateChannelDirectTcpip())
-                {
-                    channel.Exception += Channel_Exception;
-                    channel.Open(Host, Port, this, clientSocket);
-                    channel.Bind();
-                }
-            }
-            catch (Exception exp)
-            {
-                RaiseExceptionEvent(exp);
-                CloseClientSocket(clientSocket);
-            }
-            finally
-            {
-                // take into account that CountdownEvent has since been disposed; when stopping the port we
-                // wait for a given time for the channels to close, but once that timeout period has elapsed
-                // the CountdownEvent will be disposed
-                try
-                {
-                    _ = pendingChannelCountdown.Signal();
-                }
-                catch (ObjectDisposedException)
-                {
-                }
-            }
-        }
-
-        /// <summary>
-        /// Initializes the <see cref="CountdownEvent"/>.
-        /// </summary>
-        /// <remarks>
-        /// <para>
-        /// When the port is started for the first time, a <see cref="CountdownEvent"/> is created with an initial count
-        /// of <c>1</c>.
-        /// </para>
-        /// <para>
-        /// On subsequent (re)starts, we'll dispose the current <see cref="CountdownEvent"/> and create a new one with
-        /// initial count of <c>1</c>.
-        /// </para>
-        /// </remarks>
-        private void InitializePendingChannelCountdown()
-        {
-            var original = Interlocked.Exchange(ref _pendingChannelCountdown, new CountdownEvent(1));
-            original?.Dispose();
-        }
-
-        private static void CloseClientSocket(Socket clientSocket)
-        {
-            if (clientSocket.Connected)
-            {
-                try
-                {
-                    clientSocket.Shutdown(SocketShutdown.Send);
-                }
-                catch (Exception)
-                {
-                    // ignore exception when client socket was already closed
-                }
-            }
-
-            clientSocket.Dispose();
-        }
-
-        /// <summary>
-        /// Interrupts the listener, and unsubscribes from <see cref="Session"/> events.
-        /// </summary>
-        partial void StopListener()
-        {
-            // close listener socket
-            _listener?.Dispose();
-
-            // unsubscribe from session events
-            var session = Session;
-            if (session != null)
-            {
-                session.ErrorOccured -= Session_ErrorOccured;
-                session.Disconnected -= Session_Disconnected;
-            }
-        }
-
-        /// <summary>
-        /// Waits for pending channels to close.
-        /// </summary>
-        /// <param name="timeout">The maximum time to wait for the pending channels to close.</param>
-        partial void InternalStop(TimeSpan timeout)
-        {
-            _ = _pendingChannelCountdown.Signal();
-
-            if (!_pendingChannelCountdown.Wait(timeout))
-            {
-                // TODO: log as warning
-                DiagnosticAbstraction.Log("Timeout waiting for pending channels in local forwarded port to close.");
-            }
-        }
-
-        partial void InternalDispose(bool disposing)
-        {
-            if (disposing)
-            {
-                var listener = _listener;
-                if (listener != null)
-                {
-                    _listener = null;
-                    listener.Dispose();
-                }
-
-                var pendingRequestsCountdown = _pendingChannelCountdown;
-                if (pendingRequestsCountdown != null)
-                {
-                    _pendingChannelCountdown = null;
-                    pendingRequestsCountdown.Dispose();
-                }
-            }
-        }
-
-        private void Session_Disconnected(object sender, EventArgs e)
-        {
-            var session = Session;
-            if (session != null)
-            {
-                StopPort(session.ConnectionInfo.Timeout);
-            }
-        }
-
-        private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
-        {
-            var session = Session;
-            if (session != null)
-            {
-                StopPort(session.ConnectionInfo.Timeout);
-            }
-        }
-
-        private void Channel_Exception(object sender, ExceptionEventArgs e)
-        {
-            RaiseExceptionEvent(e.Exception);
-        }
-    }
-}

+ 267 - 31
src/Renci.SshNet/ForwardedPortLocal.cs

@@ -1,6 +1,9 @@
 using System;
 using System.Net;
+using System.Net.Sockets;
+using System.Threading;
 
+using Renci.SshNet.Abstractions;
 using Renci.SshNet.Common;
 
 namespace Renci.SshNet
@@ -8,10 +11,12 @@ namespace Renci.SshNet
     /// <summary>
     /// Provides functionality for local port forwarding.
     /// </summary>
-    public partial class ForwardedPortLocal : ForwardedPort, IDisposable
+    public partial class ForwardedPortLocal : ForwardedPort
     {
         private ForwardedPortStatus _status;
         private bool _isDisposed;
+        private Socket _listener;
+        private CountdownEvent _pendingChannelCountdown;
 
         /// <summary>
         /// Gets the bound host.
@@ -37,7 +42,7 @@ namespace Renci.SshNet
         /// Gets a value indicating whether port forwarding is started.
         /// </summary>
         /// <value>
-        /// <c>true</c> if port forwarding is started; otherwise, <c>false</c>.
+        /// <see langword="true"/> if port forwarding is started; otherwise, <see langword="false"/>.
         /// </value>
         public override bool IsStarted
         {
@@ -51,7 +56,7 @@ namespace Renci.SshNet
         /// <param name="host">The host.</param>
         /// <param name="port">The port.</param>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="host"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="host"/> is <see langword="null"/>.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
         /// <example>
         ///     <code source="..\..\src\Renci.SshNet.Tests\Classes\ForwardedPortLocalTest.cs" region="Example SshClient AddForwardedPort Start Stop ForwardedPortLocal" language="C#" title="Local port forwarding" />
@@ -67,8 +72,8 @@ namespace Renci.SshNet
         /// <param name="boundHost">The bound host.</param>
         /// <param name="host">The host.</param>
         /// <param name="port">The port.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="host"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="host"/> is <see langword="null"/>.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
         public ForwardedPortLocal(string boundHost, string host, uint port)
             : this(boundHost, 0, host, port)
@@ -82,8 +87,8 @@ namespace Renci.SshNet
         /// <param name="boundPort">The bound port.</param>
         /// <param name="host">The host.</param>
         /// <param name="port">The port.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="host"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="host"/> is <see langword="null"/>.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
         public ForwardedPortLocal(string boundHost, uint boundPort, string host, uint port)
@@ -160,58 +165,289 @@ namespace Renci.SshNet
         /// <exception cref="ObjectDisposedException">The current instance is disposed.</exception>
         protected override void CheckDisposed()
         {
+#if NET7_0_OR_GREATER
+            ObjectDisposedException.ThrowIf(_isDisposed, this);
+#else
             if (_isDisposed)
             {
                 throw new ObjectDisposedException(GetType().FullName);
             }
+#endif // NET7_0_OR_GREATER
         }
 
-        partial void InternalStart();
-
         /// <summary>
-        /// Interrupts the listener, and waits for the listener loop to finish.
+        /// Releases unmanaged and - optionally - managed resources.
         /// </summary>
-        /// <remarks>
-        /// When the forwarded port is stopped, then any further action is skipped.
-        /// </remarks>
-        partial void StopListener();
+        /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
+        protected override void Dispose(bool disposing)
+        {
+            if (_isDisposed)
+            {
+                return;
+            }
 
-        partial void InternalStop(TimeSpan timeout);
+            base.Dispose(disposing);
+            InternalDispose(disposing);
+            _isDisposed = true;
+        }
 
         /// <summary>
-        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// Finalizes an instance of the <see cref="ForwardedPortLocal"/> class.
         /// </summary>
-        public void Dispose()
+        ~ForwardedPortLocal()
+        {
+            Dispose(disposing: false);
+        }
+
+        private void InternalStart()
+        {
+            var addr = DnsAbstraction.GetHostAddresses(BoundHost)[0];
+            var ep = new IPEndPoint(addr, (int) BoundPort);
+
+            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
+            _listener.Bind(ep);
+            _listener.Listen(5);
+
+            // update bound port (in case original was passed as zero)
+            BoundPort = (uint) ((IPEndPoint) _listener.LocalEndPoint).Port;
+
+            Session.ErrorOccured += Session_ErrorOccured;
+            Session.Disconnected += Session_Disconnected;
+
+            InitializePendingChannelCountdown();
+
+            // consider port started when we're listening for inbound connections
+            _status = ForwardedPortStatus.Started;
+
+            StartAccept(e: null);
+        }
+
+        private void StartAccept(SocketAsyncEventArgs e)
+        {
+            if (e is null)
+            {
+#pragma warning disable CA2000 // Dispose objects before losing scope
+                e = new SocketAsyncEventArgs();
+#pragma warning restore CA2000 // Dispose objects before losing scope
+                e.Completed += AcceptCompleted;
+            }
+            else
+            {
+                // clear the socket as we're reusing the context object
+                e.AcceptSocket = null;
+            }
+
+            // only accept new connections while we are started
+            if (IsStarted)
+            {
+                try
+                {
+                    if (!_listener.AcceptAsync(e))
+                    {
+                        AcceptCompleted(sender: null, e);
+                    }
+                }
+                catch (ObjectDisposedException)
+                {
+                    if (_status == ForwardedPortStatus.Stopping || _status == ForwardedPortStatus.Stopped)
+                    {
+                        // ignore ObjectDisposedException while stopping or stopped
+                        return;
+                    }
+
+                    throw;
+                }
+            }
+        }
+
+        private void AcceptCompleted(object sender, SocketAsyncEventArgs e)
         {
-            Dispose(disposing: true);
-            GC.SuppressFinalize(this);
+            if (e.SocketError is SocketError.OperationAborted or SocketError.NotSocket)
+            {
+                // server was stopped
+                return;
+            }
+
+            // capture client socket
+            var clientSocket = e.AcceptSocket;
+
+            if (e.SocketError != SocketError.Success)
+            {
+                // accept new connection
+                StartAccept(e);
+
+                // dispose broken client socket
+                CloseClientSocket(clientSocket);
+                return;
+            }
+
+            // accept new connection
+            StartAccept(e);
+
+            // process connection
+            ProcessAccept(clientSocket);
         }
 
-        partial void InternalDispose(bool disposing);
+        private void ProcessAccept(Socket clientSocket)
+        {
+            // close the client socket if we're no longer accepting new connections
+            if (!IsStarted)
+            {
+                CloseClientSocket(clientSocket);
+                return;
+            }
+
+            // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling
+            // that same instance; the instance field for the countdown event is re-initialized when the port is restarted
+            // and at that time there may still be pending requests
+            var pendingChannelCountdown = _pendingChannelCountdown;
+
+            pendingChannelCountdown.AddCount();
+
+            try
+            {
+                var originatorEndPoint = (IPEndPoint) clientSocket.RemoteEndPoint;
+
+                RaiseRequestReceived(originatorEndPoint.Address.ToString(),
+                    (uint) originatorEndPoint.Port);
+
+                using (var channel = Session.CreateChannelDirectTcpip())
+                {
+                    channel.Exception += Channel_Exception;
+                    channel.Open(Host, Port, this, clientSocket);
+                    channel.Bind();
+                }
+            }
+            catch (Exception exp)
+            {
+                RaiseExceptionEvent(exp);
+                CloseClientSocket(clientSocket);
+            }
+            finally
+            {
+                // take into account that CountdownEvent has since been disposed; when stopping the port we
+                // wait for a given time for the channels to close, but once that timeout period has elapsed
+                // the CountdownEvent will be disposed
+                try
+                {
+                    _ = pendingChannelCountdown.Signal();
+                }
+                catch (ObjectDisposedException)
+                {
+                    // Ignore any ObjectDisposedException
+                }
+            }
+        }
 
         /// <summary>
-        /// Releases unmanaged and - optionally - managed resources.
+        /// Initializes the <see cref="CountdownEvent"/>.
         /// </summary>
-        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
-        protected override void Dispose(bool disposing)
+        /// <remarks>
+        /// <para>
+        /// When the port is started for the first time, a <see cref="CountdownEvent"/> is created with an initial count
+        /// of <c>1</c>.
+        /// </para>
+        /// <para>
+        /// On subsequent (re)starts, we'll dispose the current <see cref="CountdownEvent"/> and create a new one with
+        /// initial count of <c>1</c>.
+        /// </para>
+        /// </remarks>
+        private void InitializePendingChannelCountdown()
         {
-            if (_isDisposed)
+            var original = Interlocked.Exchange(ref _pendingChannelCountdown, new CountdownEvent(1));
+            original?.Dispose();
+        }
+
+        private static void CloseClientSocket(Socket clientSocket)
+        {
+            if (clientSocket.Connected)
             {
-                return;
+                try
+                {
+                    clientSocket.Shutdown(SocketShutdown.Send);
+                }
+                catch (Exception)
+                {
+                    // ignore exception when client socket was already closed
+                }
             }
 
-            base.Dispose(disposing);
+            clientSocket.Dispose();
+        }
 
-            InternalDispose(disposing);
-            _isDisposed = true;
+        /// <summary>
+        /// Interrupts the listener, and unsubscribes from <see cref="Session"/> events.
+        /// </summary>
+        private void StopListener()
+        {
+            // close listener socket
+            _listener?.Dispose();
+
+            // unsubscribe from session events
+            var session = Session;
+            if (session != null)
+            {
+                session.ErrorOccured -= Session_ErrorOccured;
+                session.Disconnected -= Session_Disconnected;
+            }
         }
 
         /// <summary>
-        /// Finalizes an instance of the <see cref="ForwardedPortLocal"/> class.
+        /// Waits for pending channels to close.
         /// </summary>
-        ~ForwardedPortLocal()
+        /// <param name="timeout">The maximum time to wait for the pending channels to close.</param>
+        private void InternalStop(TimeSpan timeout)
         {
-            Dispose(disposing: false);
+            _ = _pendingChannelCountdown.Signal();
+
+            if (!_pendingChannelCountdown.Wait(timeout))
+            {
+                // TODO: log as warning
+                DiagnosticAbstraction.Log("Timeout waiting for pending channels in local forwarded port to close.");
+            }
+        }
+
+        private void InternalDispose(bool disposing)
+        {
+            if (disposing)
+            {
+                var listener = _listener;
+                if (listener is not null)
+                {
+                    _listener = null;
+                    listener.Dispose();
+                }
+
+                var pendingRequestsCountdown = _pendingChannelCountdown;
+                if (pendingRequestsCountdown is not null)
+                {
+                    _pendingChannelCountdown = null;
+                    pendingRequestsCountdown.Dispose();
+                }
+            }
+        }
+
+        private void Session_Disconnected(object sender, EventArgs e)
+        {
+            var session = Session;
+            if (session is not null)
+            {
+                StopPort(session.ConnectionInfo.Timeout);
+            }
+        }
+
+        private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
+        {
+            var session = Session;
+            if (session is not null)
+            {
+                StopPort(session.ConnectionInfo.Timeout);
+            }
+        }
+
+        private void Channel_Exception(object sender, ExceptionEventArgs e)
+        {
+            RaiseExceptionEvent(e.Exception);
         }
     }
 }

+ 11 - 15
src/Renci.SshNet/ForwardedPortRemote.cs

@@ -12,7 +12,7 @@ namespace Renci.SshNet
     /// <summary>
     /// Provides functionality for remote port forwarding.
     /// </summary>
-    public class ForwardedPortRemote : ForwardedPort, IDisposable
+    public class ForwardedPortRemote : ForwardedPort
     {
         private ForwardedPortStatus _status;
         private bool _requestStatus;
@@ -24,7 +24,7 @@ namespace Renci.SshNet
         /// Gets a value indicating whether port forwarding is started.
         /// </summary>
         /// <value>
-        /// <c>true</c> if port forwarding is started; otherwise, <c>false</c>.
+        /// <see langword="true"/> if port forwarding is started; otherwise, <see langword="false"/>.
         /// </value>
         public override bool IsStarted
         {
@@ -80,8 +80,8 @@ namespace Renci.SshNet
         /// <param name="boundPort">The bound port.</param>
         /// <param name="hostAddress">The host address.</param>
         /// <param name="port">The port.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="boundHostAddress"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="hostAddress"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="boundHostAddress"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="hostAddress"/> is <see langword="null"/>.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
         public ForwardedPortRemote(IPAddress boundHostAddress, uint boundPort, IPAddress hostAddress, uint port)
@@ -229,10 +229,14 @@ namespace Renci.SshNet
         /// <exception cref="ObjectDisposedException">The current instance is disposed.</exception>
         protected override void CheckDisposed()
         {
+#if NET7_0_OR_GREATER
+            ObjectDisposedException.ThrowIf(_isDisposed, this);
+#else
             if (_isDisposed)
             {
                 throw new ObjectDisposedException(GetType().FullName);
             }
+#endif // NET7_0_OR_GREATER
         }
 
         private void Session_ChannelOpening(object sender, MessageEventArgs<ChannelOpenMessage> e)
@@ -285,6 +289,7 @@ namespace Renci.SshNet
                                 }
                                 catch (ObjectDisposedException)
                                 {
+                                    // Ignore any ObjectDisposedException
                                 }
                             }
                         });
@@ -335,18 +340,9 @@ namespace Renci.SshNet
         }
 
         /// <summary>
-        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// Releases unmanaged and - optionally - managed resources.
         /// </summary>
-        public void Dispose()
-        {
-            Dispose(disposing: true);
-            GC.SuppressFinalize(this);
-        }
-
-        /// <summary>
-        /// Releases unmanaged and - optionally - managed resources
-        /// </summary>
-        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
+        /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
         protected override void Dispose(bool disposing)
         {
             if (_isDisposed)

+ 31 - 1
src/Renci.SshNet/IForwardedPort.cs

@@ -1,15 +1,45 @@
 using System;
 
+using Renci.SshNet.Common;
+
 namespace Renci.SshNet
 {
     /// <summary>
     /// Supports port forwarding functionality.
     /// </summary>
-    public interface IForwardedPort
+    public interface IForwardedPort : IDisposable
     {
         /// <summary>
         /// The <see cref="Closing"/> event occurs as the forwarded port is being stopped.
         /// </summary>
         event EventHandler Closing;
+
+        /// <summary>
+        /// Occurs when an exception is thrown.
+        /// </summary>
+        event EventHandler<ExceptionEventArgs> Exception;
+
+        /// <summary>
+        /// Occurs when a port forwarding request is received.
+        /// </summary>
+        event EventHandler<PortForwardEventArgs> RequestReceived;
+
+        /// <summary>
+        /// Gets a value indicating whether port forwarding is started.
+        /// </summary>
+        /// <value>
+        /// <see langword="true"/> if port forwarding is started; otherwise, <see langword="false"/>.
+        /// </value>
+        bool IsStarted { get; }
+
+        /// <summary>
+        /// Starts port forwarding.
+        /// </summary>
+        void Start();
+
+        /// <summary>
+        /// Stops port forwarding.
+        /// </summary>
+        void Stop();
     }
 }

+ 0 - 18
src/Renci.SshNet/IServiceFactory.NET.cs

@@ -1,18 +0,0 @@
-using Renci.SshNet.NetConf;
-
-namespace Renci.SshNet
-{
-    internal partial interface IServiceFactory
-    {
-        /// <summary>
-        /// Creates a new <see cref="INetConfSession"/> in a given <see cref="ISession"/>
-        /// and with the specified operation timeout.
-        /// </summary>
-        /// <param name="session">The <see cref="ISession"/> to create the <see cref="INetConfSession"/> in.</param>
-        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or -1 to wait indefinitely.</param>
-        /// <returns>
-        /// An <see cref="INetConfSession"/>.
-        /// </returns>
-        INetConfSession CreateNetConfSession(ISession session, int operationTimeout);
-    }
-}

+ 40 - 5
src/Renci.SshNet/IServiceFactory.cs

@@ -2,8 +2,10 @@
 using System.Collections.Generic;
 using System.Net.Sockets;
 using System.Text;
+
 using Renci.SshNet.Common;
 using Renci.SshNet.Connection;
+using Renci.SshNet.NetConf;
 using Renci.SshNet.Security;
 using Renci.SshNet.Sftp;
 
@@ -14,8 +16,25 @@ namespace Renci.SshNet
     /// </summary>
     internal partial interface IServiceFactory
     {
+        /// <summary>
+        /// Creates an <see cref="IClientAuthentication"/>.
+        /// </summary>
+        /// <returns>
+        /// An <see cref="IClientAuthentication"/>.
+        /// </returns>
         IClientAuthentication CreateClientAuthentication();
 
+        /// <summary>
+        /// Creates a new <see cref="INetConfSession"/> in a given <see cref="ISession"/>
+        /// and with the specified operation timeout.
+        /// </summary>
+        /// <param name="session">The <see cref="ISession"/> to create the <see cref="INetConfSession"/> in.</param>
+        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or <c>-1</c> to wait indefinitely.</param>
+        /// <returns>
+        /// An <see cref="INetConfSession"/>.
+        /// </returns>
+        INetConfSession CreateNetConfSession(ISession session, int operationTimeout);
+
         /// <summary>
         /// Creates a new <see cref="ISession"/> with the specified <see cref="ConnectionInfo"/> and
         /// <see cref="ISocketFactory"/>.
@@ -25,8 +44,8 @@ namespace Renci.SshNet
         /// <returns>
         /// An <see cref="ISession"/> for the specified <see cref="ConnectionInfo"/>.
         /// </returns>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="socketFactory"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="socketFactory"/> is <see langword="null"/>.</exception>
         ISession CreateSession(ConnectionInfo connectionInfo, ISocketFactory socketFactory);
 
         /// <summary>
@@ -34,7 +53,7 @@ namespace Renci.SshNet
         /// the specified operation timeout and encoding.
         /// </summary>
         /// <param name="session">The <see cref="ISession"/> to create the <see cref="ISftpSession"/> in.</param>
-        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or -1 to wait indefinitely.</param>
+        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or <c>-1</c> to wait indefinitely.</param>
         /// <param name="encoding">The encoding.</param>
         /// <param name="sftpMessageFactory">The factory to use for creating SFTP messages.</param>
         /// <returns>
@@ -59,13 +78,29 @@ namespace Renci.SshNet
         /// <returns>
         /// A <see cref="IKeyExchange"/> that was negotiated between client and server.
         /// </returns>
-        /// <exception cref="ArgumentNullException"><paramref name="clientAlgorithms"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="serverAlgorithms"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="clientAlgorithms"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serverAlgorithms"/> is <see langword="null"/>.</exception>
         /// <exception cref="SshConnectionException">No key exchange algorithm is supported by both client and server.</exception>
         IKeyExchange CreateKeyExchange(IDictionary<string, Type> clientAlgorithms, string[] serverAlgorithms);
 
+        /// <summary>
+        /// Creates an <see cref="ISftpFileReader"/> for the specified file and with the specified
+        /// buffer size.
+        /// </summary>
+        /// <param name="fileName">The file to read.</param>
+        /// <param name="sftpSession">The SFTP session to use.</param>
+        /// <param name="bufferSize">The size of buffer.</param>
+        /// <returns>
+        /// An <see cref="ISftpFileReader"/>.
+        /// </returns>
         ISftpFileReader CreateSftpFileReader(string fileName, ISftpSession sftpSession, uint bufferSize);
 
+        /// <summary>
+        /// Creates a new <see cref="ISftpResponseFactory"/> instance.
+        /// </summary>
+        /// <returns>
+        /// An <see cref="ISftpResponseFactory"/>.
+        /// </returns>
         ISftpResponseFactory CreateSftpResponseFactory();
 
         /// <summary>

+ 0 - 354
src/Renci.SshNet/ScpClient.NET.cs

@@ -1,354 +0,0 @@
-using System;
-using System.IO;
-using System.Text.RegularExpressions;
-
-using Renci.SshNet.Channels;
-using Renci.SshNet.Common;
-
-namespace Renci.SshNet
-{
-    /// <summary>
-    /// Provides SCP client functionality.
-    /// </summary>
-    public partial class ScpClient
-    {
-        private static readonly Regex DirectoryInfoRe = new Regex(@"D(?<mode>\d{4}) (?<length>\d+) (?<filename>.+)");
-        private static readonly Regex TimestampRe = new Regex(@"T(?<mtime>\d+) 0 (?<atime>\d+) 0");
-
-        /// <summary>
-        /// Uploads the specified file to the remote host.
-        /// </summary>
-        /// <param name="fileInfo">The file system info.</param>
-        /// <param name="path">A relative or absolute path for the remote file.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="fileInfo" /> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="path" /> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="path"/> is a zero-length <see cref="string"/>.</exception>
-        /// <exception cref="ScpException">A directory with the specified path exists on the remote host.</exception>
-        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
-        public void Upload(FileInfo fileInfo, string path)
-        {
-            if (fileInfo is null)
-            {
-                throw new ArgumentNullException(nameof(fileInfo));
-            }
-
-            var posixPath = PosixPath.CreateAbsoluteOrRelativeFilePath(path);
-
-            using (var input = ServiceFactory.CreatePipeStream())
-            using (var channel = Session.CreateChannelSession())
-            {
-                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
-                channel.Open();
-
-                // Pass only the directory part of the path to the server, and use the (hidden) -d option to signal
-                // that we expect the target to be a directory.
-                if (!channel.SendExecRequest(string.Format("scp -t -d {0}", _remotePathTransformation.Transform(posixPath.Directory))))
-                {
-                    throw SecureExecutionRequestRejectedException();
-                }
-
-                CheckReturnCode(input);
-
-                using (var source = fileInfo.OpenRead())
-                {
-                    UploadTimes(channel, input, fileInfo);
-                    UploadFileModeAndName(channel, input, source.Length, posixPath.File);
-                    UploadFileContent(channel, input, source, fileInfo.Name);
-                }
-            }
-        }
-
-        /// <summary>
-        /// Uploads the specified directory to the remote host.
-        /// </summary>
-        /// <param name="directoryInfo">The directory info.</param>
-        /// <param name="path">A relative or absolute path for the remote directory.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="directoryInfo"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="path"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="path"/> is a zero-length string.</exception>
-        /// <exception cref="ScpException"><paramref name="path"/> does not exist on the remote host, is not a directory or the user does not have the required permission.</exception>
-        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
-        public void Upload(DirectoryInfo directoryInfo, string path)
-        {
-            if (directoryInfo is null)
-            {
-                throw new ArgumentNullException(nameof(directoryInfo));
-            }
-
-            if (path is null)
-            {
-                throw new ArgumentNullException(nameof(path));
-            }
-
-            if (path.Length == 0)
-            {
-                throw new ArgumentException("The path cannot be a zero-length string.", nameof(path));
-            }
-
-            using (var input = ServiceFactory.CreatePipeStream())
-            using (var channel = Session.CreateChannelSession())
-            {
-                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
-                channel.Open();
-
-                // start copy with the following options:
-                // -p preserve modification and access times
-                // -r copy directories recursively
-                // -d expect path to be a directory
-                // -t copy to remote
-                if (!channel.SendExecRequest(string.Format("scp -r -p -d -t {0}", _remotePathTransformation.Transform(path))))
-                {
-                    throw SecureExecutionRequestRejectedException();
-                }
-
-                CheckReturnCode(input);
-
-                UploadDirectoryContent(channel, input, directoryInfo);
-            }
-        }
-
-        /// <summary>
-        /// Downloads the specified file from the remote host to local file.
-        /// </summary>
-        /// <param name="filename">Remote host file name.</param>
-        /// <param name="fileInfo">Local file information.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="fileInfo"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="filename"/> is <c>null</c> or empty.</exception>
-        /// <exception cref="ScpException"><paramref name="filename"/> exists on the remote host, and is not a regular file.</exception>
-        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
-        public void Download(string filename, FileInfo fileInfo)
-        {
-            if (string.IsNullOrEmpty(filename))
-            {
-                throw new ArgumentException("filename");
-            }
-
-            if (fileInfo is null)
-            {
-                throw new ArgumentNullException(nameof(fileInfo));
-            }
-
-            using (var input = ServiceFactory.CreatePipeStream())
-            using (var channel = Session.CreateChannelSession())
-            {
-                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
-                channel.Open();
-
-                // Send channel command request
-                if (!channel.SendExecRequest(string.Format("scp -pf {0}", _remotePathTransformation.Transform(filename))))
-                {
-                    throw SecureExecutionRequestRejectedException();
-                }
-
-                // Send reply
-                SendSuccessConfirmation(channel);
-
-                InternalDownload(channel, input, fileInfo);
-            }
-        }
-
-        /// <summary>
-        /// Downloads the specified directory from the remote host to local directory.
-        /// </summary>
-        /// <param name="directoryName">Remote host directory name.</param>
-        /// <param name="directoryInfo">Local directory information.</param>
-        /// <exception cref="ArgumentException"><paramref name="directoryName"/> is <c>null</c> or empty.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="directoryInfo"/> is <c>null</c>.</exception>
-        /// <exception cref="ScpException">File or directory with the specified path does not exist on the remote host.</exception>
-        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
-        public void Download(string directoryName, DirectoryInfo directoryInfo)
-        {
-            if (string.IsNullOrEmpty(directoryName))
-            {
-                throw new ArgumentException("directoryName");
-            }
-
-            if (directoryInfo is null)
-            {
-                throw new ArgumentNullException(nameof(directoryInfo));
-            }
-
-            using (var input = ServiceFactory.CreatePipeStream())
-            using (var channel = Session.CreateChannelSession())
-            {
-                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
-                channel.Open();
-
-                // Send channel command request
-                if (!channel.SendExecRequest(string.Format("scp -prf {0}", _remotePathTransformation.Transform(directoryName))))
-                {
-                    throw SecureExecutionRequestRejectedException();
-                }
-
-                // Send reply
-                SendSuccessConfirmation(channel);
-
-                InternalDownload(channel, input, directoryInfo);
-            }
-        }
-
-        /// <summary>
-        /// Uploads the <see cref="FileSystemInfo.LastWriteTimeUtc"/> and <see cref="FileSystemInfo.LastAccessTimeUtc"/>
-        /// of the next file or directory to upload.
-        /// </summary>
-        /// <param name="channel">The channel to perform the upload in.</param>
-        /// <param name="input">A <see cref="Stream"/> from which any feedback from the server can be read.</param>
-        /// <param name="fileOrDirectory">The file or directory to upload.</param>
-        private void UploadTimes(IChannelSession channel, Stream input, FileSystemInfo fileOrDirectory)
-        {
-            var zeroTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
-            var modificationSeconds = (long) (fileOrDirectory.LastWriteTimeUtc - zeroTime).TotalSeconds;
-            var accessSeconds = (long) (fileOrDirectory.LastAccessTimeUtc - zeroTime).TotalSeconds;
-            SendData(channel, string.Format("T{0} 0 {1} 0\n", modificationSeconds, accessSeconds));
-            CheckReturnCode(input);
-        }
-
-        /// <summary>
-        /// Upload the files and subdirectories in the specified directory.
-        /// </summary>
-        /// <param name="channel">The channel to perform the upload in.</param>
-        /// <param name="input">A <see cref="Stream"/> from which any feedback from the server can be read.</param>
-        /// <param name="directoryInfo">The directory to upload.</param>
-        private void UploadDirectoryContent(IChannelSession channel, Stream input, DirectoryInfo directoryInfo)
-        {
-            // Upload files
-            var files = directoryInfo.GetFiles();
-            foreach (var file in files)
-            {
-                using (var source = file.OpenRead())
-                {
-                    UploadTimes(channel, input, file);
-                    UploadFileModeAndName(channel, input, source.Length, file.Name);
-                    UploadFileContent(channel, input, source, file.Name);
-                }
-            }
-
-            // Upload directories
-            var directories = directoryInfo.GetDirectories();
-            foreach (var directory in directories)
-            {
-                UploadTimes(channel, input, directory);
-                UploadDirectoryModeAndName(channel, input, directory.Name);
-                UploadDirectoryContent(channel, input, directory);
-            }
-
-            // Mark upload of current directory complete
-            SendData(channel, "E\n");
-            CheckReturnCode(input);
-        }
-
-        /// <summary>
-        /// Sets mode and name of the directory being upload.
-        /// </summary>
-        private void UploadDirectoryModeAndName(IChannelSession channel, Stream input, string directoryName)
-        {
-            SendData(channel, string.Format("D0755 0 {0}\n", directoryName));
-            CheckReturnCode(input);
-        }
-
-        private void InternalDownload(IChannelSession channel, Stream input, FileSystemInfo fileSystemInfo)
-        {
-            var modifiedTime = DateTime.Now;
-            var accessedTime = DateTime.Now;
-
-            var startDirectoryFullName = fileSystemInfo.FullName;
-            var currentDirectoryFullName = startDirectoryFullName;
-            var directoryCounter = 0;
-
-            while (true)
-            {
-                var message = ReadString(input);
-
-                if (message == "E")
-                {
-                    SendSuccessConfirmation(channel); // Send reply
-
-                    directoryCounter--;
-
-                    currentDirectoryFullName = new DirectoryInfo(currentDirectoryFullName).Parent.FullName;
-
-                    if (directoryCounter == 0)
-                    {
-                        break;
-                    }
-
-                    continue;
-                }
-
-                var match = DirectoryInfoRe.Match(message);
-                if (match.Success)
-                {
-                    SendSuccessConfirmation(channel); // Send reply
-
-                    // Read directory
-                    var filename = match.Result("${filename}");
-
-                    DirectoryInfo newDirectoryInfo;
-                    if (directoryCounter > 0)
-                    {
-                        newDirectoryInfo = Directory.CreateDirectory(Path.Combine(currentDirectoryFullName, filename));
-                        newDirectoryInfo.LastAccessTime = accessedTime;
-                        newDirectoryInfo.LastWriteTime = modifiedTime;
-                    }
-                    else
-                    {
-                        // Don't create directory for first level
-                        newDirectoryInfo = fileSystemInfo as DirectoryInfo;
-                    }
-
-                    directoryCounter++;
-
-                    currentDirectoryFullName = newDirectoryInfo.FullName;
-                    continue;
-                }
-
-                match = FileInfoRe.Match(message);
-                if (match.Success)
-                {
-                    // Read file
-                    SendSuccessConfirmation(channel); //  Send reply
-
-                    var length = long.Parse(match.Result("${length}"));
-                    var fileName = match.Result("${filename}");
-
-                    if (fileSystemInfo is not FileInfo fileInfo)
-                    {
-                        fileInfo = new FileInfo(Path.Combine(currentDirectoryFullName, fileName));
-                    }
-
-                    using (var output = fileInfo.OpenWrite())
-                    {
-                        InternalDownload(channel, input, output, fileName, length);
-                    }
-
-                    fileInfo.LastAccessTime = accessedTime;
-                    fileInfo.LastWriteTime = modifiedTime;
-
-                    if (directoryCounter == 0)
-                    {
-                        break;
-                    }
-
-                    continue;
-                }
-
-                match = TimestampRe.Match(message);
-                if (match.Success)
-                {
-                    // Read timestamp
-                    SendSuccessConfirmation(channel); //  Send reply
-
-                    var mtime = long.Parse(match.Result("${mtime}"));
-                    var atime = long.Parse(match.Result("${atime}"));
-
-                    var zeroTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
-                    modifiedTime = zeroTime.AddSeconds(mtime);
-                    accessedTime = zeroTime.AddSeconds(atime);
-                    continue;
-                }
-
-                SendErrorConfirmation(channel, string.Format("\"{0}\" is not valid protocol message.", message));
-            }
-        }
-    }
-}

+ 414 - 74
src/Renci.SshNet/ScpClient.cs

@@ -16,8 +16,7 @@ namespace Renci.SshNet
     /// </summary>
     /// <remarks>
     /// <para>
-    /// More information on the SCP protocol is available here:
-    /// https://github.com/net-ssh/net-scp/blob/master/lib/net/scp.rb
+    /// More information on the SCP protocol is available here: https://github.com/net-ssh/net-scp/blob/master/lib/net/scp.rb.
     /// </para>
     /// <para>
     /// Known issues in OpenSSH:
@@ -32,9 +31,11 @@ namespace Renci.SshNet
     public partial class ScpClient : BaseClient
     {
         private const string Message = "filename";
-        private static readonly Regex FileInfoRe = new Regex(@"C(?<mode>\d{4}) (?<length>\d+) (?<filename>.+)");
+        private static readonly Regex FileInfoRe = new Regex(@"C(?<mode>\d{4}) (?<length>\d+) (?<filename>.+)", RegexOptions.Compiled);
         private static readonly byte[] SuccessConfirmationCode = { 0 };
         private static readonly byte[] ErrorConfirmationCode = { 1 };
+        private static readonly Regex DirectoryInfoRe = new Regex(@"D(?<mode>\d{4}) (?<length>\d+) (?<filename>.+)", RegexOptions.Compiled);
+        private static readonly Regex TimestampRe = new Regex(@"T(?<mtime>\d+) 0 (?<atime>\d+) 0", RegexOptions.Compiled);
 
         private IRemotePathTransformation _remotePathTransformation;
 
@@ -61,7 +62,7 @@ namespace Renci.SshNet
         /// <value>
         /// The transformation to apply to remote paths. The default is <see cref="RemotePathTransformation.DoubleQuote"/>.
         /// </value>
-        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
         /// <remarks>
         /// <para>
         /// This transformation is applied to the remote file or directory path that is passed to the
@@ -74,7 +75,10 @@ namespace Renci.SshNet
         /// </remarks>
         public IRemotePathTransformation RemotePathTransformation
         {
-            get { return _remotePathTransformation; }
+            get
+            {
+                return _remotePathTransformation;
+            }
             set
             {
                 if (value is null)
@@ -100,7 +104,7 @@ namespace Renci.SshNet
         /// Initializes a new instance of the <see cref="ScpClient"/> class.
         /// </summary>
         /// <param name="connectionInfo">The connection info.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <see langword="null"/>.</exception>
         public ScpClient(ConnectionInfo connectionInfo)
             : this(connectionInfo, ownsConnectionInfo: false)
         {
@@ -113,8 +117,8 @@ namespace Renci.SshNet
         /// <param name="port">Connection port.</param>
         /// <param name="username">Authentication username.</param>
         /// <param name="password">Authentication password.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="password"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, or <paramref name="username"/> is <c>null</c> or contains only whitespace characters.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="password"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, or <paramref name="username"/> is <see langword="null"/> or contains only whitespace characters.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="port"/> is not within <see cref="IPEndPoint.MinPort"/> and <see cref="IPEndPoint.MaxPort"/>.</exception>
         [SuppressMessage("Microsoft.Reliability", "CA2000:DisposeObjectsBeforeLosingScope", Justification = "Disposed in Dispose(bool) method.")]
         public ScpClient(string host, int port, string username, string password)
@@ -128,8 +132,8 @@ namespace Renci.SshNet
         /// <param name="host">Connection host.</param>
         /// <param name="username">Authentication username.</param>
         /// <param name="password">Authentication password.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="password"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, or <paramref name="username"/> is <c>null</c> or contains only whitespace characters.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="password"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, or <paramref name="username"/> is <see langword="null"/> or contains only whitespace characters.</exception>
         public ScpClient(string host, string username, string password)
             : this(host, ConnectionInfo.DefaultPort, username, password)
         {
@@ -142,8 +146,8 @@ namespace Renci.SshNet
         /// <param name="port">Connection port.</param>
         /// <param name="username">Authentication username.</param>
         /// <param name="keyFiles">Authentication private key file(s) .</param>
-        /// <exception cref="ArgumentNullException"><paramref name="keyFiles"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, -or- <paramref name="username"/> is <c>null</c> or contains only whitespace characters.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="keyFiles"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, -or- <paramref name="username"/> is <see langword="null"/> or contains only whitespace characters.</exception>
         /// <exception cref="ArgumentOutOfRangeException"><paramref name="port"/> is not within <see cref="IPEndPoint.MinPort"/> and <see cref="IPEndPoint.MaxPort"/>.</exception>
         [SuppressMessage("Microsoft.Reliability", "CA2000:DisposeObjectsBeforeLosingScope", Justification = "Disposed in Dispose(bool) method.")]
         public ScpClient(string host, int port, string username, params IPrivateKeySource[] keyFiles)
@@ -157,8 +161,8 @@ namespace Renci.SshNet
         /// <param name="host">Connection host.</param>
         /// <param name="username">Authentication username.</param>
         /// <param name="keyFiles">Authentication private key file(s) .</param>
-        /// <exception cref="ArgumentNullException"><paramref name="keyFiles"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, -or- <paramref name="username"/> is <c>null</c> or contains only whitespace characters.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="keyFiles"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="host"/> is invalid, -or- <paramref name="username"/> is <see langword="null"/> or contains only whitespace characters.</exception>
         public ScpClient(string host, string username, params IPrivateKeySource[] keyFiles)
             : this(host, ConnectionInfo.DefaultPort, username, keyFiles)
         {
@@ -169,9 +173,9 @@ namespace Renci.SshNet
         /// </summary>
         /// <param name="connectionInfo">The connection info.</param>
         /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <see langword="null"/>.</exception>
         /// <remarks>
-        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// If <paramref name="ownsConnectionInfo"/> is <see langword="true"/>, then the
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
         private ScpClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo)
@@ -185,10 +189,10 @@ namespace Renci.SshNet
         /// <param name="connectionInfo">The connection info.</param>
         /// <param name="ownsConnectionInfo">Specified whether this instance owns the connection info.</param>
         /// <param name="serviceFactory">The factory to use for creating new services.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serviceFactory"/> is <see langword="null"/>.</exception>
         /// <remarks>
-        /// If <paramref name="ownsConnectionInfo"/> is <c>true</c>, then the
+        /// If <paramref name="ownsConnectionInfo"/> is <see langword="true"/>, then the
         /// connection info will be disposed when this instance is disposed.
         /// </remarks>
         internal ScpClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory)
@@ -204,7 +208,7 @@ namespace Renci.SshNet
         /// </summary>
         /// <param name="source">The <see cref="Stream"/> to upload.</param>
         /// <param name="path">A relative or absolute path for the remote file.</param>
-        /// <exception cref="ArgumentNullException"><paramref name="path" /> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="path" /> is <see langword="null"/>.</exception>
         /// <exception cref="ArgumentException"><paramref name="path"/> is a zero-length <see cref="string"/>.</exception>
         /// <exception cref="ScpException">A directory with the specified path exists on the remote host.</exception>
         /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
@@ -232,13 +236,185 @@ namespace Renci.SshNet
             }
         }
 
+        /// <summary>
+        /// Uploads the specified file to the remote host.
+        /// </summary>
+        /// <param name="fileInfo">The file system info.</param>
+        /// <param name="path">A relative or absolute path for the remote file.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="fileInfo" /> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="path" /> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="path"/> is a zero-length <see cref="string"/>.</exception>
+        /// <exception cref="ScpException">A directory with the specified path exists on the remote host.</exception>
+        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
+        public void Upload(FileInfo fileInfo, string path)
+        {
+            if (fileInfo is null)
+            {
+                throw new ArgumentNullException(nameof(fileInfo));
+            }
+
+            var posixPath = PosixPath.CreateAbsoluteOrRelativeFilePath(path);
+
+            using (var input = ServiceFactory.CreatePipeStream())
+            using (var channel = Session.CreateChannelSession())
+            {
+                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
+                channel.Open();
+
+                // Pass only the directory part of the path to the server, and use the (hidden) -d option to signal
+                // that we expect the target to be a directory.
+                if (!channel.SendExecRequest($"scp -t -d {_remotePathTransformation.Transform(posixPath.Directory)}"))
+                {
+                    throw SecureExecutionRequestRejectedException();
+                }
+
+                CheckReturnCode(input);
+
+                using (var source = fileInfo.OpenRead())
+                {
+                    UploadTimes(channel, input, fileInfo);
+                    UploadFileModeAndName(channel, input, source.Length, posixPath.File);
+                    UploadFileContent(channel, input, source, fileInfo.Name);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Uploads the specified directory to the remote host.
+        /// </summary>
+        /// <param name="directoryInfo">The directory info.</param>
+        /// <param name="path">A relative or absolute path for the remote directory.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="directoryInfo"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="path"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="path"/> is a zero-length string.</exception>
+        /// <exception cref="ScpException"><paramref name="path"/> does not exist on the remote host, is not a directory or the user does not have the required permission.</exception>
+        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
+        public void Upload(DirectoryInfo directoryInfo, string path)
+        {
+            if (directoryInfo is null)
+            {
+                throw new ArgumentNullException(nameof(directoryInfo));
+            }
+
+            if (path is null)
+            {
+                throw new ArgumentNullException(nameof(path));
+            }
+
+            if (path.Length == 0)
+            {
+                throw new ArgumentException("The path cannot be a zero-length string.", nameof(path));
+            }
+
+            using (var input = ServiceFactory.CreatePipeStream())
+            using (var channel = Session.CreateChannelSession())
+            {
+                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
+                channel.Open();
+
+                // start copy with the following options:
+                // -p preserve modification and access times
+                // -r copy directories recursively
+                // -d expect path to be a directory
+                // -t copy to remote
+                if (!channel.SendExecRequest($"scp -r -p -d -t {_remotePathTransformation.Transform(path)}"))
+                {
+                    throw SecureExecutionRequestRejectedException();
+                }
+
+                CheckReturnCode(input);
+
+                UploadDirectoryContent(channel, input, directoryInfo);
+            }
+        }
+
+        /// <summary>
+        /// Downloads the specified file from the remote host to local file.
+        /// </summary>
+        /// <param name="filename">Remote host file name.</param>
+        /// <param name="fileInfo">Local file information.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="fileInfo"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="filename"/> is <see langword="null"/> or empty.</exception>
+        /// <exception cref="ScpException"><paramref name="filename"/> exists on the remote host, and is not a regular file.</exception>
+        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
+        public void Download(string filename, FileInfo fileInfo)
+        {
+            if (string.IsNullOrEmpty(filename))
+            {
+                throw new ArgumentException("filename");
+            }
+
+            if (fileInfo is null)
+            {
+                throw new ArgumentNullException(nameof(fileInfo));
+            }
+
+            using (var input = ServiceFactory.CreatePipeStream())
+            using (var channel = Session.CreateChannelSession())
+            {
+                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
+                channel.Open();
+
+                // Send channel command request
+                if (!channel.SendExecRequest($"scp -pf {_remotePathTransformation.Transform(filename)}"))
+                {
+                    throw SecureExecutionRequestRejectedException();
+                }
+
+                // Send reply
+                SendSuccessConfirmation(channel);
+
+                InternalDownload(channel, input, fileInfo);
+            }
+        }
+
+        /// <summary>
+        /// Downloads the specified directory from the remote host to local directory.
+        /// </summary>
+        /// <param name="directoryName">Remote host directory name.</param>
+        /// <param name="directoryInfo">Local directory information.</param>
+        /// <exception cref="ArgumentException"><paramref name="directoryName"/> is <see langword="null"/> or empty.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="directoryInfo"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ScpException">File or directory with the specified path does not exist on the remote host.</exception>
+        /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
+        public void Download(string directoryName, DirectoryInfo directoryInfo)
+        {
+            if (string.IsNullOrEmpty(directoryName))
+            {
+                throw new ArgumentException("directoryName");
+            }
+
+            if (directoryInfo is null)
+            {
+                throw new ArgumentNullException(nameof(directoryInfo));
+            }
+
+            using (var input = ServiceFactory.CreatePipeStream())
+            using (var channel = Session.CreateChannelSession())
+            {
+                channel.DataReceived += (sender, e) => input.Write(e.Data, 0, e.Data.Length);
+                channel.Open();
+
+                // Send channel command request
+                if (!channel.SendExecRequest($"scp -prf {_remotePathTransformation.Transform(directoryName)}"))
+                {
+                    throw SecureExecutionRequestRejectedException();
+                }
+
+                // Send reply
+                SendSuccessConfirmation(channel);
+
+                InternalDownload(channel, input, directoryInfo);
+            }
+        }
+
         /// <summary>
         /// Downloads the specified file from the remote host to the stream.
         /// </summary>
         /// <param name="filename">A relative or absolute path for the remote file.</param>
         /// <param name="destination">The <see cref="Stream"/> to download the remote file to.</param>
-        /// <exception cref="ArgumentException"><paramref name="filename"/> is <c>null</c> or contains only whitespace characters.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="destination"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentException"><paramref name="filename"/> is <see langword="null"/> or contains only whitespace characters.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="destination"/> is <see langword="null"/>.</exception>
         /// <exception cref="ScpException"><paramref name="filename"/> exists on the remote host, and is not a regular file.</exception>
         /// <exception cref="SshException">The secure copy execution request was rejected by the server.</exception>
         public void Download(string filename, Stream destination)
@@ -287,6 +463,33 @@ namespace Renci.SshNet
             }
         }
 
+        private static void SendData(IChannel channel, byte[] buffer, int length)
+        {
+            channel.SendData(buffer, 0, length);
+        }
+
+        private static void SendData(IChannel channel, byte[] buffer)
+        {
+            channel.SendData(buffer);
+        }
+
+        private static int ReadByte(Stream stream)
+        {
+            var b = stream.ReadByte();
+
+            if (b == -1)
+            {
+                throw new SshException("Stream has been closed.");
+            }
+
+            return b;
+        }
+
+        private static SshException SecureExecutionRequestRejectedException()
+        {
+            throw new SshException("Secure copy execution request was rejected by the server. Please consult the server logs.");
+        }
+
         /// <summary>
         /// Sets mode, size and name of file being upload.
         /// </summary>
@@ -342,34 +545,6 @@ namespace Renci.SshNet
             CheckReturnCode(input);
         }
 
-        private void InternalDownload(IChannel channel, Stream input, Stream output, string filename, long length)
-        {
-            var buffer = new byte[Math.Min(length, BufferSize)];
-            var needToRead = length;
-
-            do
-            {
-                var read = input.Read(buffer, 0, (int) Math.Min(needToRead, BufferSize));
-
-                output.Write(buffer, 0, read);
-
-                RaiseDownloadingEvent(filename, length, length - needToRead);
-
-                needToRead -= read;
-            }
-            while (needToRead > 0);
-
-            output.Flush();
-
-            //  Raise one more time when file downloaded
-            RaiseDownloadingEvent(filename, length, length - needToRead);
-
-            //  Send confirmation byte after last data byte was read
-            SendSuccessConfirmation(channel);
-
-            CheckReturnCode(input);
-        }
-
         private void RaiseDownloadingEvent(string filename, long size, long downloaded)
         {
             Downloading?.Invoke(this, new ScpDownloadEventArgs(filename, size, downloaded));
@@ -412,28 +587,6 @@ namespace Renci.SshNet
             channel.SendData(ConnectionInfo.Encoding.GetBytes(command));
         }
 
-        private static void SendData(IChannel channel, byte[] buffer, int length)
-        {
-            channel.SendData(buffer, 0, length);
-        }
-
-        private static void SendData(IChannel channel, byte[] buffer)
-        {
-            channel.SendData(buffer);
-        }
-
-        private static int ReadByte(Stream stream)
-        {
-            var b = stream.ReadByte();
-
-            if (b == -1)
-            {
-                throw new SshException("Stream has been closed.");
-            }
-
-            return b;
-        }
-
         /// <summary>
         /// Read a LF-terminated string from the <see cref="Stream"/>.
         /// </summary>
@@ -470,9 +623,196 @@ namespace Renci.SshNet
             return ConnectionInfo.Encoding.GetString(readBytes, 0, readBytes.Length);
         }
 
-        private static SshException SecureExecutionRequestRejectedException()
+        /// <summary>
+        /// Uploads the <see cref="FileSystemInfo.LastWriteTimeUtc"/> and <see cref="FileSystemInfo.LastAccessTimeUtc"/>
+        /// of the next file or directory to upload.
+        /// </summary>
+        /// <param name="channel">The channel to perform the upload in.</param>
+        /// <param name="input">A <see cref="Stream"/> from which any feedback from the server can be read.</param>
+        /// <param name="fileOrDirectory">The file or directory to upload.</param>
+        private void UploadTimes(IChannelSession channel, Stream input, FileSystemInfo fileOrDirectory)
         {
-            throw new SshException("Secure copy execution request was rejected by the server. Please consult the server logs.");
+            var zeroTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
+            var modificationSeconds = (long) (fileOrDirectory.LastWriteTimeUtc - zeroTime).TotalSeconds;
+            var accessSeconds = (long) (fileOrDirectory.LastAccessTimeUtc - zeroTime).TotalSeconds;
+            SendData(channel, string.Format(CultureInfo.InvariantCulture, "T{0} 0 {1} 0\n", modificationSeconds, accessSeconds));
+            CheckReturnCode(input);
+        }
+
+        /// <summary>
+        /// Upload the files and subdirectories in the specified directory.
+        /// </summary>
+        /// <param name="channel">The channel to perform the upload in.</param>
+        /// <param name="input">A <see cref="Stream"/> from which any feedback from the server can be read.</param>
+        /// <param name="directoryInfo">The directory to upload.</param>
+        private void UploadDirectoryContent(IChannelSession channel, Stream input, DirectoryInfo directoryInfo)
+        {
+            // Upload files
+            var files = directoryInfo.GetFiles();
+            foreach (var file in files)
+            {
+                using (var source = file.OpenRead())
+                {
+                    UploadTimes(channel, input, file);
+                    UploadFileModeAndName(channel, input, source.Length, file.Name);
+                    UploadFileContent(channel, input, source, file.Name);
+                }
+            }
+
+            // Upload directories
+            var directories = directoryInfo.GetDirectories();
+            foreach (var directory in directories)
+            {
+                UploadTimes(channel, input, directory);
+                UploadDirectoryModeAndName(channel, input, directory.Name);
+                UploadDirectoryContent(channel, input, directory);
+            }
+
+            // Mark upload of current directory complete
+            SendData(channel, "E\n");
+            CheckReturnCode(input);
+        }
+
+        /// <summary>
+        /// Sets mode and name of the directory being upload.
+        /// </summary>
+        private void UploadDirectoryModeAndName(IChannelSession channel, Stream input, string directoryName)
+        {
+            SendData(channel, string.Format("D0755 0 {0}\n", directoryName));
+            CheckReturnCode(input);
+        }
+
+        private void InternalDownload(IChannel channel, Stream input, Stream output, string filename, long length)
+        {
+            var buffer = new byte[Math.Min(length, BufferSize)];
+            var needToRead = length;
+
+            do
+            {
+                var read = input.Read(buffer, 0, (int) Math.Min(needToRead, BufferSize));
+
+                output.Write(buffer, 0, read);
+
+                RaiseDownloadingEvent(filename, length, length - needToRead);
+
+                needToRead -= read;
+            }
+            while (needToRead > 0);
+
+            output.Flush();
+
+            // Raise one more time when file downloaded
+            RaiseDownloadingEvent(filename, length, length - needToRead);
+
+            // Send confirmation byte after last data byte was read
+            SendSuccessConfirmation(channel);
+
+            CheckReturnCode(input);
+        }
+
+        private void InternalDownload(IChannelSession channel, Stream input, FileSystemInfo fileSystemInfo)
+        {
+            var modifiedTime = DateTime.Now;
+            var accessedTime = DateTime.Now;
+
+            var startDirectoryFullName = fileSystemInfo.FullName;
+            var currentDirectoryFullName = startDirectoryFullName;
+            var directoryCounter = 0;
+
+            while (true)
+            {
+                var message = ReadString(input);
+
+                if (message == "E")
+                {
+                    SendSuccessConfirmation(channel); // Send reply
+
+                    directoryCounter--;
+
+                    currentDirectoryFullName = new DirectoryInfo(currentDirectoryFullName).Parent.FullName;
+
+                    if (directoryCounter == 0)
+                    {
+                        break;
+                    }
+
+                    continue;
+                }
+
+                var match = DirectoryInfoRe.Match(message);
+                if (match.Success)
+                {
+                    SendSuccessConfirmation(channel); // Send reply
+
+                    // Read directory
+                    var filename = match.Result("${filename}");
+
+                    DirectoryInfo newDirectoryInfo;
+                    if (directoryCounter > 0)
+                    {
+                        newDirectoryInfo = Directory.CreateDirectory(Path.Combine(currentDirectoryFullName, filename));
+                        newDirectoryInfo.LastAccessTime = accessedTime;
+                        newDirectoryInfo.LastWriteTime = modifiedTime;
+                    }
+                    else
+                    {
+                        // Don't create directory for first level
+                        newDirectoryInfo = fileSystemInfo as DirectoryInfo;
+                    }
+
+                    directoryCounter++;
+
+                    currentDirectoryFullName = newDirectoryInfo.FullName;
+                    continue;
+                }
+
+                match = FileInfoRe.Match(message);
+                if (match.Success)
+                {
+                    // Read file
+                    SendSuccessConfirmation(channel); //  Send reply
+
+                    var length = long.Parse(match.Result("${length}"), CultureInfo.InvariantCulture);
+                    var fileName = match.Result("${filename}");
+
+                    if (fileSystemInfo is not FileInfo fileInfo)
+                    {
+                        fileInfo = new FileInfo(Path.Combine(currentDirectoryFullName, fileName));
+                    }
+
+                    using (var output = fileInfo.OpenWrite())
+                    {
+                        InternalDownload(channel, input, output, fileName, length);
+                    }
+
+                    fileInfo.LastAccessTime = accessedTime;
+                    fileInfo.LastWriteTime = modifiedTime;
+
+                    if (directoryCounter == 0)
+                    {
+                        break;
+                    }
+
+                    continue;
+                }
+
+                match = TimestampRe.Match(message);
+                if (match.Success)
+                {
+                    // Read timestamp
+                    SendSuccessConfirmation(channel); //  Send reply
+
+                    var mtime = long.Parse(match.Result("${mtime}"), CultureInfo.InvariantCulture);
+                    var atime = long.Parse(match.Result("${atime}"), CultureInfo.InvariantCulture);
+
+                    var zeroTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
+                    modifiedTime = zeroTime.AddSeconds(mtime);
+                    accessedTime = zeroTime.AddSeconds(atime);
+                    continue;
+                }
+
+                SendErrorConfirmation(channel, string.Format("\"{0}\" is not valid protocol message.", message));
+            }
         }
     }
 }

+ 0 - 21
src/Renci.SshNet/ServiceFactory.NET.cs

@@ -1,21 +0,0 @@
-using Renci.SshNet.NetConf;
-
-namespace Renci.SshNet
-{
-    internal partial class ServiceFactory
-    {
-        /// <summary>
-        /// Creates a new <see cref="INetConfSession"/> in a given <see cref="ISession"/>
-        /// and with the specified operation timeout.
-        /// </summary>
-        /// <param name="session">The <see cref="ISession"/> to create the <see cref="INetConfSession"/> in.</param>
-        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or -1 to wait indefinitely.</param>
-        /// <returns>
-        /// An <see cref="INetConfSession"/>.
-        /// </returns>
-        public INetConfSession CreateNetConfSession(ISession session, int operationTimeout)
-        {
-            return new NetConfSession(session, operationTimeout);
-        }
-    }
-}

+ 39 - 8
src/Renci.SshNet/ServiceFactory.cs

@@ -8,6 +8,7 @@ using Renci.SshNet.Abstractions;
 using Renci.SshNet.Common;
 using Renci.SshNet.Connection;
 using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.NetConf;
 using Renci.SshNet.Security;
 using Renci.SshNet.Sftp;
 
@@ -22,13 +23,13 @@ namespace Renci.SshNet
         /// Defines the number of times an authentication attempt with any given <see cref="IAuthenticationMethod"/>
         /// can result in <see cref="AuthenticationResult.PartialSuccess"/> before it is disregarded.
         /// </summary>
-        private static readonly int PartialSuccessLimit = 5;
+        private const int PartialSuccessLimit = 5;
 
         /// <summary>
-        /// Creates a <see cref="IClientAuthentication"/>.
+        /// Creates an <see cref="IClientAuthentication"/>.
         /// </summary>
         /// <returns>
-        /// A <see cref="IClientAuthentication"/>.
+        /// An <see cref="IClientAuthentication"/>.
         /// </returns>
         public IClientAuthentication CreateClientAuthentication()
         {
@@ -44,8 +45,8 @@ namespace Renci.SshNet
         /// <returns>
         /// An <see cref="ISession"/> for the specified <see cref="ConnectionInfo"/>.
         /// </returns>
-        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="socketFactory"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="socketFactory"/> is <see langword="null"/>.</exception>
         public ISession CreateSession(ConnectionInfo connectionInfo, ISocketFactory socketFactory)
         {
             return new Session(connectionInfo, this, socketFactory);
@@ -56,7 +57,7 @@ namespace Renci.SshNet
         /// the specified operation timeout and encoding.
         /// </summary>
         /// <param name="session">The <see cref="ISession"/> to create the <see cref="ISftpSession"/> in.</param>
-        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or -1 to wait indefinitely.</param>
+        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or <c>-1</c> to wait indefinitely.</param>
         /// <param name="encoding">The encoding.</param>
         /// <param name="sftpMessageFactory">The factory to use for creating SFTP messages.</param>
         /// <returns>
@@ -87,8 +88,8 @@ namespace Renci.SshNet
         /// <returns>
         /// A <see cref="IKeyExchange"/> that was negotiated between client and server.
         /// </returns>
-        /// <exception cref="ArgumentNullException"><paramref name="clientAlgorithms"/> is <c>null</c>.</exception>
-        /// <exception cref="ArgumentNullException"><paramref name="serverAlgorithms"/> is <c>null</c>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="clientAlgorithms"/> is <see langword="null"/>.</exception>
+        /// <exception cref="ArgumentNullException"><paramref name="serverAlgorithms"/> is <see langword="null"/>.</exception>
         /// <exception cref="SshConnectionException">No key exchange algorithms are supported by both client and server.</exception>
         public IKeyExchange CreateKeyExchange(IDictionary<string, Type> clientAlgorithms, string[] serverAlgorithms)
         {
@@ -116,6 +117,30 @@ namespace Renci.SshNet
             return keyExchangeAlgorithmType.CreateInstance<IKeyExchange>();
         }
 
+        /// <summary>
+        /// Creates a new <see cref="INetConfSession"/> in a given <see cref="ISession"/>
+        /// and with the specified operation timeout.
+        /// </summary>
+        /// <param name="session">The <see cref="ISession"/> to create the <see cref="INetConfSession"/> in.</param>
+        /// <param name="operationTimeout">The number of milliseconds to wait for an operation to complete, or <c>-1</c> to wait indefinitely.</param>
+        /// <returns>
+        /// An <see cref="INetConfSession"/>.
+        /// </returns>
+        public INetConfSession CreateNetConfSession(ISession session, int operationTimeout)
+        {
+            return new NetConfSession(session, operationTimeout);
+        }
+
+        /// <summary>
+        /// Creates an <see cref="ISftpFileReader"/> for the specified file and with the specified
+        /// buffer size.
+        /// </summary>
+        /// <param name="fileName">The file to read.</param>
+        /// <param name="sftpSession">The SFTP session to use.</param>
+        /// <param name="bufferSize">The size of buffer.</param>
+        /// <returns>
+        /// An <see cref="ISftpFileReader"/>.
+        /// </returns>
         public ISftpFileReader CreateSftpFileReader(string fileName, ISftpSession sftpSession, uint bufferSize)
         {
             const int defaultMaxPendingReads = 3;
@@ -151,6 +176,12 @@ namespace Renci.SshNet
             return sftpSession.CreateFileReader(handle, sftpSession, chunkSize, maxPendingReads, fileSize);
         }
 
+        /// <summary>
+        /// Creates a new <see cref="ISftpResponseFactory"/> instance.
+        /// </summary>
+        /// <returns>
+        /// An <see cref="ISftpResponseFactory"/>.
+        /// </returns>
         public ISftpResponseFactory CreateSftpResponseFactory()
         {
             return new SftpResponseFactory();