Procházet zdrojové kódy

Add support for Zlib compression (.NET 6.0 onward only) (#1326)

* Integrate `ZLibStream` from .NET 6.0+ with SSH.NET.

* OpenSSH server does not support zlib (pre-auth); OpenSSH client still supports zlib (pre-auth)

* Correct compression algorithm name; Update README.md

* Integrate `ZLibStream` from .NET 6.0+ with SSH.NET.

* OpenSSH server does not support zlib (pre-auth); OpenSSH client still supports zlib (pre-auth)

* Correct compression algorithm name; Update README.md

* Test the compression by upload/download file

* Refactor compression.

* Move delayed compression logic to base class.

* seal Zlib

* update unit test

* update unit test to see if it can trigger integration test

* Flush zlibStream

* Fix integration test

* update test

* Update ConnectionInfo.cs

Co-authored-by: Rob Hague <rob.hague00@gmail.com>

* Update README.md

---------

Co-authored-by: Rob Hague <rob.hague00@gmail.com>
Scott Xu před 1 rokem
rodič
revize
b553f81f8a

+ 8 - 1
README.md

@@ -1,4 +1,4 @@
- ![Logo](https://raw.githubusercontent.com/sshnet/SSH.NET/develop/images/logo/png/SS-NET-icon-h50.png) SSH.NET
+ ![Logo](https://raw.githubusercontent.com/sshnet/SSH.NET/develop/images/logo/png/SS-NET-icon-h50.png) SSH.NET
 =======
 SSH.NET is a Secure Shell (SSH-2) library for .NET, optimized for parallelism.
 
@@ -123,6 +123,13 @@ Private keys can be encrypted using one of the following cipher methods:
 * hmac-sha2-256-etm<span></span>@openssh.com
 * hmac-sha2-512-etm<span></span>@openssh.com
 
+
+## Compression
+
+**SSH.NET** supports the following compression algorithms:
+* none (default)
+* zlib<span></span>@openssh.com (.NET 6 and higher)
+
 ## Framework Support
 
 **SSH.NET** supports the following target frameworks:

+ 0 - 18
src/Renci.SshNet/Compression/CompressionMode.cs

@@ -1,18 +0,0 @@
-namespace Renci.SshNet.Compression
-{
-    /// <summary>
-    /// Specifies compression modes.
-    /// </summary>
-    public enum CompressionMode
-    {
-        /// <summary>
-        /// Specifies that content should be compressed.
-        /// </summary>
-        Compress = 0,
-
-        /// <summary>
-        /// Specifies that content should be decompressed.
-        /// </summary>
-        Decompress = 1,
-    }
-}

+ 53 - 51
src/Renci.SshNet/Compression/Compressor.cs

@@ -1,6 +1,6 @@
 using System;
-using System.IO;
 
+using Renci.SshNet.Messages.Authentication;
 using Renci.SshNet.Security;
 
 namespace Renci.SshNet.Compression
@@ -10,35 +10,23 @@ namespace Renci.SshNet.Compression
     /// </summary>
     public abstract class Compressor : Algorithm, IDisposable
     {
-        private readonly ZlibStream _compressor;
-        private readonly ZlibStream _decompressor;
-        private MemoryStream _compressorStream;
-        private MemoryStream _decompressorStream;
-        private bool _isDisposed;
-
-        /// <summary>
-        /// Gets or sets a value indicating whether compression is active.
-        /// </summary>
-        /// <value>
-        /// <see langword="true"/> if compression is active; otherwise, <see langword="false"/>.
-        /// </value>
-        protected bool IsActive { get; set; }
+        private readonly bool _delayedCompression;
 
-        /// <summary>
-        /// Gets the session.
-        /// </summary>
-        protected Session Session { get; private set; }
+        private bool _isActive;
+        private Session _session;
+        private bool _isDisposed;
 
         /// <summary>
         /// Initializes a new instance of the <see cref="Compressor"/> class.
         /// </summary>
-        protected Compressor()
+        /// <param name="delayedCompression">
+        /// <see langword="false"/> to start compression after receiving SSH_MSG_NEWKEYS.
+        /// <see langword="true"/> to delay compression util receiving SSH_MSG_USERAUTH_SUCCESS.
+        /// <see href="https://www.openssh.com/txt/draft-miller-secsh-compression-delayed-00.txt"/>.
+        /// </param>
+        protected Compressor(bool delayedCompression)
         {
-            _compressorStream = new MemoryStream();
-            _decompressorStream = new MemoryStream();
-
-            _compressor = new ZlibStream(_compressorStream, CompressionMode.Compress);
-            _decompressor = new ZlibStream(_decompressorStream, CompressionMode.Decompress);
+            _delayedCompression = delayedCompression;
         }
 
         /// <summary>
@@ -47,7 +35,15 @@ namespace Renci.SshNet.Compression
         /// <param name="session">The session.</param>
         public virtual void Init(Session session)
         {
-            Session = session;
+            if (_delayedCompression)
+            {
+                _session = session;
+                _session.UserAuthenticationSuccessReceived += Session_UserAuthenticationSuccessReceived;
+            }
+            else
+            {
+                _isActive = true;
+            }
         }
 
         /// <summary>
@@ -57,7 +53,7 @@ namespace Renci.SshNet.Compression
         /// <returns>
         /// The compressed data.
         /// </returns>
-        public virtual byte[] Compress(byte[] data)
+        public byte[] Compress(byte[] data)
         {
             return Compress(data, 0, data.Length);
         }
@@ -73,7 +69,7 @@ namespace Renci.SshNet.Compression
         /// </returns>
         public virtual byte[] Compress(byte[] data, int offset, int length)
         {
-            if (!IsActive)
+            if (!_isActive)
             {
                 if (offset == 0 && length == data.Length)
                 {
@@ -85,13 +81,20 @@ namespace Renci.SshNet.Compression
                 return buffer;
             }
 
-            _compressorStream.SetLength(0);
-
-            _compressor.Write(data, offset, length);
-
-            return _compressorStream.ToArray();
+            return CompressCore(data, offset, length);
         }
 
+        /// <summary>
+        /// Compresses the specified data.
+        /// </summary>
+        /// <param name="data">Data to compress.</param>
+        /// <param name="offset">The zero-based byte offset in <paramref name="data"/> at which to begin reading the data to compress. </param>
+        /// <param name="length">The number of bytes to be compressed. </param>
+        /// <returns>
+        /// The compressed data.
+        /// </returns>
+        protected abstract byte[] CompressCore(byte[] data, int offset, int length);
+
         /// <summary>
         /// Decompresses the specified data.
         /// </summary>
@@ -99,7 +102,7 @@ namespace Renci.SshNet.Compression
         /// <returns>
         /// The decompressed data.
         /// </returns>
-        public virtual byte[] Decompress(byte[] data)
+        public byte[] Decompress(byte[] data)
         {
             return Decompress(data, 0, data.Length);
         }
@@ -115,7 +118,7 @@ namespace Renci.SshNet.Compression
         /// </returns>
         public virtual byte[] Decompress(byte[] data, int offset, int length)
         {
-            if (!IsActive)
+            if (!_isActive)
             {
                 if (offset == 0 && length == data.Length)
                 {
@@ -127,11 +130,24 @@ namespace Renci.SshNet.Compression
                 return buffer;
             }
 
-            _decompressorStream.SetLength(0);
+            return DecompressCore(data, offset, length);
+        }
 
-            _decompressor.Write(data, offset, length);
+        /// <summary>
+        /// Decompresses the specified data.
+        /// </summary>
+        /// <param name="data">Compressed data.</param>
+        /// <param name="offset">The zero-based byte offset in <paramref name="data"/> at which to begin reading the data to decompress. </param>
+        /// <param name="length">The number of bytes to be read from the compressed data. </param>
+        /// <returns>
+        /// The decompressed data.
+        /// </returns>
+        protected abstract byte[] DecompressCore(byte[] data, int offset, int length);
 
-            return _decompressorStream.ToArray();
+        private void Session_UserAuthenticationSuccessReceived(object sender, MessageEventArgs<SuccessMessage> e)
+        {
+            _isActive = true;
+            _session.UserAuthenticationSuccessReceived -= Session_UserAuthenticationSuccessReceived;
         }
 
         /// <summary>
@@ -156,20 +172,6 @@ namespace Renci.SshNet.Compression
 
             if (disposing)
             {
-                var compressorStream = _compressorStream;
-                if (compressorStream != null)
-                {
-                    compressorStream.Dispose();
-                    _compressorStream = null;
-                }
-
-                var decompressorStream = _decompressorStream;
-                if (decompressorStream != null)
-                {
-                    decompressorStream.Dispose();
-                    _decompressorStream = null;
-                }
-
                 _isDisposed = true;
             }
         }

+ 78 - 7
src/Renci.SshNet/Compression/Zlib.cs

@@ -1,10 +1,35 @@
-namespace Renci.SshNet.Compression
+#if NET6_0_OR_GREATER
+using System.IO;
+using System.IO.Compression;
+
+namespace Renci.SshNet.Compression
 {
     /// <summary>
     /// Represents "zlib" compression implementation.
     /// </summary>
-    internal sealed class Zlib : Compressor
+    internal class Zlib : Compressor
     {
+        private readonly ZLibStream _compressor;
+        private readonly ZLibStream _decompressor;
+        private MemoryStream _compressorStream;
+        private MemoryStream _decompressorStream;
+        private bool _isDisposed;
+
+        public Zlib()
+            : this(delayedCompression: false)
+        {
+        }
+
+        protected Zlib(bool delayedCompression)
+            : base(delayedCompression)
+        {
+            _compressorStream = new MemoryStream();
+            _decompressorStream = new MemoryStream();
+
+            _compressor = new ZLibStream(_compressorStream, CompressionMode.Compress);
+            _decompressor = new ZLibStream(_decompressorStream, CompressionMode.Decompress);
+        }
+
         /// <summary>
         /// Gets algorithm name.
         /// </summary>
@@ -13,15 +38,61 @@
             get { return "zlib"; }
         }
 
+        protected override byte[] CompressCore(byte[] data, int offset, int length)
+        {
+            _compressorStream.SetLength(0);
+
+            _compressor.Write(data, offset, length);
+            _compressor.Flush();
+
+            return _compressorStream.ToArray();
+        }
+
+        protected override byte[] DecompressCore(byte[] data, int offset, int length)
+        {
+            _decompressorStream.Write(data, offset, length);
+            _decompressorStream.Position = 0;
+
+            using var outputStream = new MemoryStream();
+            _decompressor.CopyTo(outputStream);
+
+            _decompressorStream.SetLength(0);
+
+            return outputStream.ToArray();
+        }
+
         /// <summary>
-        /// Initializes the algorithm.
+        /// Releases unmanaged and - optionally - managed resources.
         /// </summary>
-        /// <param name="session">The session.</param>
-        public override void Init(Session session)
+        /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
+        protected override void Dispose(bool disposing)
         {
-            base.Init(session);
+            base.Dispose(disposing);
+
+            if (_isDisposed)
+            {
+                return;
+            }
+
+            if (disposing)
+            {
+                var compressorStream = _compressorStream;
+                if (compressorStream != null)
+                {
+                    compressorStream.Dispose();
+                    _compressorStream = null;
+                }
+
+                var decompressorStream = _decompressorStream;
+                if (decompressorStream != null)
+                {
+                    decompressorStream.Dispose();
+                    _decompressorStream = null;
+                }
 
-            IsActive = true;
+                _isDisposed = true;
+            }
         }
     }
 }
+#endif

+ 7 - 25
src/Renci.SshNet/Compression/ZlibOpenSsh.cs

@@ -1,35 +1,17 @@
-using Renci.SshNet.Messages.Authentication;
-
+#if NET6_0_OR_GREATER
 namespace Renci.SshNet.Compression
 {
-    /// <summary>
-    /// Represents "zlib@openssh.org" compression implementation.
-    /// </summary>
-    public class ZlibOpenSsh : Compressor
+    internal sealed class ZlibOpenSsh : Zlib
     {
-        /// <summary>
-        /// Gets algorithm name.
-        /// </summary>
-        public override string Name
+        public ZlibOpenSsh()
+            : base(delayedCompression: true)
         {
-            get { return "zlib@openssh.org"; }
         }
 
-        /// <summary>
-        /// Initializes the algorithm.
-        /// </summary>
-        /// <param name="session">The session.</param>
-        public override void Init(Session session)
-        {
-            base.Init(session);
-
-            session.UserAuthenticationSuccessReceived += Session_UserAuthenticationSuccessReceived;
-        }
-
-        private void Session_UserAuthenticationSuccessReceived(object sender, MessageEventArgs<SuccessMessage> e)
+        public override string Name
         {
-            IsActive = true;
-            Session.UserAuthenticationSuccessReceived -= Session_UserAuthenticationSuccessReceived;
+            get { return "zlib@openssh.com"; }
         }
     }
 }
+#endif

+ 0 - 57
src/Renci.SshNet/Compression/ZlibStream.cs

@@ -1,57 +0,0 @@
-using System.IO;
-
-#pragma warning disable S125 // Sections of code should not be commented out
-#pragma warning disable SA1005 // Single line comments should begin with single space
-
-namespace Renci.SshNet.Compression
-{
-    /// <summary>
-    /// Implements Zlib compression algorithm.
-    /// </summary>
-#pragma warning disable CA1711 // Identifiers should not have incorrect suffix
-    public class ZlibStream
-#pragma warning restore CA1711 // Identifiers should not have incorrect suffix
-    {
-        //private readonly Ionic.Zlib.ZlibStream _baseStream;
-
-        /// <summary>
-        /// Initializes a new instance of the <see cref="ZlibStream" /> class.
-        /// </summary>
-        /// <param name="stream">The stream.</param>
-        /// <param name="mode">The mode.</param>
-#pragma warning disable IDE0060 // Remove unused parameter
-        public ZlibStream(Stream stream, CompressionMode mode)
-#pragma warning restore IDE0060 // Remove unused parameter
-        {
-            //switch (mode)
-            //{
-            //    case CompressionMode.Compress:
-            //        this._baseStream = new Ionic.Zlib.ZlibStream(stream, Ionic.Zlib.CompressionMode.Compress, Ionic.Zlib.CompressionLevel.Default);
-            //        break;
-            //    case CompressionMode.Decompress:
-            //        this._baseStream = new Ionic.Zlib.ZlibStream(stream, Ionic.Zlib.CompressionMode.Decompress, Ionic.Zlib.CompressionLevel.Default);
-            //        break;
-            //    default:
-            //        break;
-            //}
-
-            //this._baseStream.FlushMode = Ionic.Zlib.FlushType.Partial;
-        }
-
-        /// <summary>
-        /// Writes the specified buffer.
-        /// </summary>
-        /// <param name="buffer">The buffer.</param>
-        /// <param name="offset">The offset.</param>
-        /// <param name="count">The count.</param>
-#pragma warning disable IDE0060 // Remove unused parameter
-        public void Write(byte[] buffer, int offset, int count)
-#pragma warning restore IDE0060 // Remove unused parameter
-        {
-            //this._baseStream.Write(buffer, offset, count);
-        }
-#pragma warning restore SA1005 // Single line comments should begin with single space
-    }
-}
-
-#pragma warning restore S125 // Sections of code should not be commented out

+ 3 - 0
src/Renci.SshNet/ConnectionInfo.cs

@@ -439,6 +439,9 @@ namespace Renci.SshNet
             CompressionAlgorithms = new Dictionary<string, Func<Compressor>>
                 {
                     { "none", null },
+#if NET6_0_OR_GREATER
+                    { "zlib@openssh.com", () => new ZlibOpenSsh() },
+#endif
                 };
 
             ChannelRequests = new Dictionary<string, RequestInfo>

+ 2 - 2
src/Renci.SshNet/Security/KeyExchange.cs

@@ -124,7 +124,7 @@ namespace Renci.SshNet.Security
             var compressionAlgorithmName = (from b in session.ConnectionInfo.CompressionAlgorithms.Keys
                                             from a in message.CompressionAlgorithmsClientToServer
                                             where a == b
-                                            select a).LastOrDefault();
+                                            select a).FirstOrDefault();
             if (string.IsNullOrEmpty(compressionAlgorithmName))
             {
                 throw new SshConnectionException("Compression algorithm not found", DisconnectReason.KeyExchangeFailed);
@@ -136,7 +136,7 @@ namespace Renci.SshNet.Security
             var decompressionAlgorithmName = (from b in session.ConnectionInfo.CompressionAlgorithms.Keys
                                               from a in message.CompressionAlgorithmsServerToClient
                                               where a == b
-                                              select a).LastOrDefault();
+                                              select a).FirstOrDefault();
             if (string.IsNullOrEmpty(decompressionAlgorithmName))
             {
                 throw new SshConnectionException("Decompression algorithm not found", DisconnectReason.KeyExchangeFailed);

+ 14 - 0
src/Renci.SshNet/Session.cs

@@ -2132,6 +2132,20 @@ namespace Renci.SshNet
                     _clientMac = null;
                 }
 
+                var serverDecompression = _serverDecompression;
+                if (serverDecompression != null)
+                {
+                    serverDecompression.Dispose();
+                    _serverDecompression = null;
+                }
+
+                var clientCompression = _clientCompression;
+                if (clientCompression != null)
+                {
+                    clientCompression.Dispose();
+                    _clientCompression = null;
+                }
+
                 var keyExchange = _keyExchange;
                 if (keyExchange != null)
                 {

+ 64 - 0
test/Renci.SshNet.IntegrationTests/CompressionTests.cs

@@ -0,0 +1,64 @@
+using Renci.SshNet.Compression;
+
+namespace Renci.SshNet.IntegrationTests
+{
+    [TestClass]
+    public class CompressionTests : IntegrationTestBase
+    {
+        private IConnectionInfoFactory _connectionInfoFactory;
+
+        [TestInitialize]
+        public void SetUp()
+        {
+            _connectionInfoFactory = new LinuxVMConnectionFactory(SshServerHostName, SshServerPort);
+        }
+
+        [TestMethod]
+        public void None()
+        {
+            DoTest(new KeyValuePair<string, Func<Compressor>>("none", null));
+        }
+
+#if NET6_0_OR_GREATER
+        [TestMethod]
+        public void ZlibOpenSsh()
+        {
+            DoTest(new KeyValuePair<string, Func<Compressor>>("zlib@openssh.com", () => new ZlibOpenSsh()));
+        }
+#endif
+
+        private void DoTest(KeyValuePair<string, Func<Compressor>> compressor)
+        {
+            using (var scpClient = new ScpClient(_connectionInfoFactory.Create()))
+            {
+                scpClient.ConnectionInfo.CompressionAlgorithms.Clear();
+                scpClient.ConnectionInfo.CompressionAlgorithms.Add(compressor);
+
+                scpClient.Connect();
+
+                Assert.AreEqual(compressor.Key, scpClient.ConnectionInfo.CurrentClientCompressionAlgorithm);
+                Assert.AreEqual(compressor.Key, scpClient.ConnectionInfo.CurrentServerCompressionAlgorithm);
+
+                var file = $"/tmp/{Guid.NewGuid()}.txt";
+
+                var sb = new StringBuilder();
+                for (var i = 0; i < 100; i++)
+                {
+                    _ = sb.Append("Repeating");
+                }
+
+                var fileContent = sb.ToString();
+
+                using var uploadStream = new MemoryStream(Encoding.UTF8.GetBytes(fileContent));
+                scpClient.Upload(uploadStream, file);
+
+                using var downloadStream = new MemoryStream();
+                scpClient.Download(file, downloadStream);
+
+                var result = Encoding.UTF8.GetString(downloadStream.ToArray());
+
+                Assert.AreEqual(fileContent, result);
+            }
+        }
+    }
+}