Przeglądaj źródła

Read the underlying buffer in SshDataStream (#1638)

SshDataStream is a MemoryStream, so we can access the buffer directly.
Also simplify some usage in PrivateKeyFile.
Rob Hague 5 miesięcy temu
rodzic
commit
7c07b10e4a

+ 24 - 0
src/Renci.SshNet/Common/Extensions.cs

@@ -19,7 +19,9 @@ namespace Renci.SshNet.Common
     /// </summary>
     internal static class Extensions
     {
+#pragma warning disable S4136 // Method overloads should be grouped together
         internal static byte[] ToArray(this ServiceName serviceName)
+#pragma warning restore S4136 // Method overloads should be grouped together
         {
             switch (serviceName)
             {
@@ -382,6 +384,28 @@ namespace Renci.SshNet.Common
             value = default;
             return false;
         }
+
+        internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index)
+        {
+            return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, arraySegment.Count - index);
+        }
+
+        internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index, int count)
+        {
+            return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, count);
+        }
+
+        internal static T[] ToArray<T>(this ArraySegment<T> arraySegment)
+        {
+            if (arraySegment.Count == 0)
+            {
+                return Array.Empty<T>();
+            }
+
+            var array = new T[arraySegment.Count];
+            Array.Copy(arraySegment.Array, arraySegment.Offset, array, 0, arraySegment.Count);
+            return array;
+        }
 #endif
     }
 }

+ 56 - 55
src/Renci.SshNet/Common/SshDataStream.cs

@@ -1,4 +1,6 @@
 using System;
+using System.Buffers.Binary;
+using System.Diagnostics;
 using System.Globalization;
 using System.IO;
 using System.Numerics;
@@ -27,7 +29,7 @@ namespace Renci.SshNet.Common
         /// <param name="buffer">The array of unsigned bytes from which to create the current stream.</param>
         /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
         public SshDataStream(byte[] buffer)
-            : base(buffer)
+            : base(buffer ?? throw new ArgumentNullException(nameof(buffer)), 0, buffer.Length, writable: true, publiclyVisible: true)
         {
         }
 
@@ -39,7 +41,7 @@ namespace Renci.SshNet.Common
         /// <param name="count">The number of bytes to load.</param>
         /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
         public SshDataStream(byte[] buffer, int offset, int count)
-            : base(buffer, offset, count)
+            : base(buffer, offset, count, writable: true, publiclyVisible: true)
         {
         }
 
@@ -58,19 +60,6 @@ namespace Renci.SshNet.Common
         }
 
 #if NETFRAMEWORK || NETSTANDARD2_0
-        private int Read(Span<byte> buffer)
-        {
-            var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);
-
-            var numRead = Read(sharedBuffer, 0, buffer.Length);
-
-            sharedBuffer.AsSpan(0, numRead).CopyTo(buffer);
-
-            System.Buffers.ArrayPool<byte>.Shared.Return(sharedBuffer);
-
-            return numRead;
-        }
-
         private void Write(ReadOnlySpan<byte> buffer)
         {
             var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);
@@ -90,7 +79,7 @@ namespace Renci.SshNet.Common
         public void Write(uint value)
         {
             Span<byte> bytes = stackalloc byte[4];
-            System.Buffers.Binary.BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
+            BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
             Write(bytes);
         }
 
@@ -101,7 +90,7 @@ namespace Renci.SshNet.Common
         public void Write(ulong value)
         {
             Span<byte> bytes = stackalloc byte[8];
-            System.Buffers.Binary.BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
+            BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
             Write(bytes);
         }
 
@@ -137,6 +126,7 @@ namespace Renci.SshNet.Common
         /// <exception cref="ArgumentNullException"><paramref name="encoding"/> is <see langword="null"/>.</exception>
         public void Write(string s, Encoding encoding)
         {
+            ThrowHelper.ThrowIfNull(s);
             ThrowHelper.ThrowIfNull(encoding);
 
 #if NETSTANDARD2_1 || NET
@@ -153,12 +143,21 @@ namespace Renci.SshNet.Common
         }
 
         /// <summary>
-        /// Reads a byte array from the SSH data stream.
+        /// Reads a length-prefixed byte array from the SSH data stream.
         /// </summary>
         /// <returns>
         /// The byte array read from the SSH data stream.
         /// </returns>
         public byte[] ReadBinary()
+        {
+            return ReadBinarySegment().ToArray();
+        }
+
+        /// <summary>
+        /// Reads a length-prefixed byte array from the SSH data stream,
+        /// returned as a view over the underlying buffer.
+        /// </summary>
+        internal ArraySegment<byte> ReadBinarySegment()
         {
             var length = ReadUInt32();
 
@@ -167,7 +166,23 @@ namespace Renci.SshNet.Common
                 throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue));
             }
 
-            return ReadBytes((int)length);
+            var buffer = GetRemainingBuffer().Slice(0, (int)length);
+
+            Position += length;
+
+            return buffer;
+        }
+
+        /// <summary>
+        /// Gets a view over the remaining data in the underlying buffer.
+        /// </summary>
+        private ArraySegment<byte> GetRemainingBuffer()
+        {
+            var success = TryGetBuffer(out var buffer);
+
+            Debug.Assert(success, "Expected buffer to be publicly visible");
+
+            return buffer.Slice((int)Position);
         }
 
         /// <summary>
@@ -205,11 +220,11 @@ namespace Renci.SshNet.Common
         /// </returns>
         public BigInteger ReadBigInt()
         {
-            var data = ReadBinary();
-
 #if NETSTANDARD2_1 || NET
+            var data = ReadBinarySegment();
             return new BigInteger(data, isBigEndian: true);
 #else
+            var data = ReadBinary();
             Array.Reverse(data);
             return new BigInteger(data);
 #endif
@@ -223,9 +238,9 @@ namespace Renci.SshNet.Common
         /// </returns>
         public ushort ReadUInt16()
         {
-            Span<byte> bytes = stackalloc byte[2];
-            ReadBytes(bytes);
-            return System.Buffers.Binary.BinaryPrimitives.ReadUInt16BigEndian(bytes);
+            var ret = BinaryPrimitives.ReadUInt16BigEndian(GetRemainingBuffer());
+            Position += sizeof(ushort);
+            return ret;
         }
 
         /// <summary>
@@ -236,9 +251,9 @@ namespace Renci.SshNet.Common
         /// </returns>
         public uint ReadUInt32()
         {
-            Span<byte> span = stackalloc byte[4];
-            ReadBytes(span);
-            return System.Buffers.Binary.BinaryPrimitives.ReadUInt32BigEndian(span);
+            var ret = BinaryPrimitives.ReadUInt32BigEndian(GetRemainingBuffer());
+            Position += sizeof(uint);
+            return ret;
         }
 
         /// <summary>
@@ -249,9 +264,9 @@ namespace Renci.SshNet.Common
         /// </returns>
         public ulong ReadUInt64()
         {
-            Span<byte> span = stackalloc byte[8];
-            ReadBytes(span);
-            return System.Buffers.Binary.BinaryPrimitives.ReadUInt64BigEndian(span);
+            var ret = BinaryPrimitives.ReadUInt64BigEndian(GetRemainingBuffer());
+            Position += sizeof(ulong);
+            return ret;
         }
 
         /// <summary>
@@ -265,19 +280,13 @@ namespace Renci.SshNet.Common
         {
             encoding ??= Encoding.UTF8;
 
-            var length = ReadUInt32();
-
-            if (length > int.MaxValue)
-            {
-                throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Strings longer than {0} is not supported.", int.MaxValue));
-            }
+            var bytes = ReadBinarySegment();
 
-            var bytes = ReadBytes((int)length);
-            return encoding.GetString(bytes, 0, bytes.Length);
+            return encoding.GetString(bytes.Array, bytes.Offset, bytes.Count);
         }
 
         /// <summary>
-        /// Writes the stream contents to a byte array, regardless of the <see cref="MemoryStream.Position"/>.
+        /// Retrieves the stream contents as a byte array, regardless of the <see cref="MemoryStream.Position"/>.
         /// </summary>
         /// <returns>
         /// This method returns the contents of the <see cref="SshDataStream"/> as a byte array.
@@ -288,9 +297,15 @@ namespace Renci.SshNet.Common
         /// </remarks>
         public override byte[] ToArray()
         {
-            if (Capacity == Length)
+            var success = TryGetBuffer(out var buffer);
+
+            Debug.Assert(success, "Expected buffer to be publicly visible");
+
+            if (buffer.Offset == 0 &&
+                buffer.Count == buffer.Array.Length &&
+                buffer.Count == Length)
             {
-                return GetBuffer();
+                return buffer.Array;
             }
 
             return base.ToArray();
@@ -315,19 +330,5 @@ namespace Renci.SshNet.Common
 
             return data;
         }
-
-        /// <summary>
-        /// Reads data into the specified <paramref name="buffer" />.
-        /// </summary>
-        /// <param name="buffer">The buffer to read into.</param>
-        /// <exception cref="ArgumentOutOfRangeException"><paramref name="buffer"/> is larger than the total of bytes available.</exception>
-        private void ReadBytes(Span<byte> buffer)
-        {
-            var bytesRead = Read(buffer);
-            if (bytesRead < buffer.Length)
-            {
-                throw new ArgumentOutOfRangeException(nameof(buffer), string.Format(CultureInfo.InvariantCulture, "The requested length ({0}) is greater than the actual number of bytes read ({1}).", buffer.Length, bytesRead));
-            }
-        }
     }
 }

+ 19 - 20
src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs

@@ -32,7 +32,7 @@ namespace Renci.SshNet
             /// </summary>
             public Key Parse()
             {
-                var keyReader = new SshDataReader(_data);
+                var keyReader = new SshDataStream(_data);
 
                 // check magic header
                 var authMagic = "openssh-key-v1\0"u8;
@@ -171,7 +171,7 @@ namespace Renci.SshNet
                 // now parse the data we called the private key, it actually contains the public key again
                 // so we need to parse through it to get the private key bytes, plus there's some
                 // validation we need to do.
-                var privateKeyReader = new SshDataReader(privateKeyBytes);
+                var privateKeyReader = new SshDataStream(privateKeyBytes);
 
                 // check ints should match, they wouldn't match for example if the wrong passphrase was supplied
                 var checkInt1 = (int)privateKeyReader.ReadUInt32();
@@ -196,33 +196,29 @@ namespace Renci.SshNet
                         // https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-11#section-3.2.3
 
                         // ENC(A)
-                        _ = privateKeyReader.ReadBignum2();
+                        _ = privateKeyReader.ReadBinarySegment();
 
                         // k || ENC(A)
-                        unencryptedPrivateKey = privateKeyReader.ReadBignum2();
+                        unencryptedPrivateKey = privateKeyReader.ReadBinary();
                         parsedKey = new ED25519Key(unencryptedPrivateKey);
                         break;
                     case "ecdsa-sha2-nistp256":
                     case "ecdsa-sha2-nistp384":
                     case "ecdsa-sha2-nistp521":
-                        // curve
-                        var len = (int)privateKeyReader.ReadUInt32();
-                        var curve = Encoding.ASCII.GetString(privateKeyReader.ReadBytes(len));
+                        var curve = privateKeyReader.ReadString(Encoding.ASCII);
 
-                        // public key
-                        publicKey = privateKeyReader.ReadBignum2();
+                        publicKey = privateKeyReader.ReadBinary();
 
-                        // private key
-                        unencryptedPrivateKey = privateKeyReader.ReadBignum2();
+                        unencryptedPrivateKey = privateKeyReader.ReadBinary();
                         parsedKey = new EcdsaKey(curve, publicKey, unencryptedPrivateKey.TrimLeadingZeros());
                         break;
                     case "ssh-rsa":
-                        var modulus = privateKeyReader.ReadBignum(); // n
-                        var exponent = privateKeyReader.ReadBignum(); // e
-                        var d = privateKeyReader.ReadBignum(); // d
-                        var inverseQ = privateKeyReader.ReadBignum(); // iqmp
-                        var p = privateKeyReader.ReadBignum(); // p
-                        var q = privateKeyReader.ReadBignum(); // q
+                        var modulus = privateKeyReader.ReadBigInt();
+                        var exponent = privateKeyReader.ReadBigInt();
+                        var d = privateKeyReader.ReadBigInt();
+                        var inverseQ = privateKeyReader.ReadBigInt();
+                        var p = privateKeyReader.ReadBigInt();
+                        var q = privateKeyReader.ReadBigInt();
                         parsedKey = new RsaKey(modulus, exponent, d, p, q, inverseQ);
                         break;
                     default:
@@ -233,14 +229,17 @@ namespace Renci.SshNet
 
                 // The list of privatekey/comment pairs is padded with the bytes 1, 2, 3, ...
                 // until the total length is a multiple of the cipher block size.
-                var padding = privateKeyReader.ReadBytes();
-                for (var i = 0; i < padding.Length; i++)
+                int b, i = 0;
+
+                while ((b = privateKeyReader.ReadByte()) != -1)
                 {
-                    if ((int)padding[i] != i + 1)
+                    if (b != i + 1)
                     {
                         throw new SshException("Padding of openssh key format contained wrong byte at position: " +
                                                i.ToString(CultureInfo.InvariantCulture));
                     }
+
+                    i++;
                 }
 
                 return parsedKey;

+ 11 - 11
src/Renci.SshNet/PrivateKeyFile.PuTTY.cs

@@ -163,34 +163,34 @@ namespace Renci.SshNet
                     throw new SshException("MAC verification failed for PuTTY key file");
                 }
 
-                var publicKeyReader = new SshDataReader(_publicKey);
+                var publicKeyReader = new SshDataStream(_publicKey);
                 var keyType = publicKeyReader.ReadString(Encoding.UTF8);
                 Debug.Assert(keyType == _algorithmName, $"{nameof(keyType)} is not the same as {nameof(_algorithmName)}");
 
-                var privateKeyReader = new SshDataReader(privateKey);
+                var privateKeyReader = new SshDataStream(privateKey);
 
                 Key parsedKey;
 
                 switch (keyType)
                 {
                     case "ssh-ed25519":
-                        parsedKey = new ED25519Key(privateKeyReader.ReadBignum2());
+                        parsedKey = new ED25519Key(privateKeyReader.ReadBinary());
                         break;
                     case "ecdsa-sha2-nistp256":
                     case "ecdsa-sha2-nistp384":
                     case "ecdsa-sha2-nistp521":
                         var curve = publicKeyReader.ReadString(Encoding.ASCII);
-                        var pub = publicKeyReader.ReadBignum2();
-                        var prv = privateKeyReader.ReadBignum2();
+                        var pub = publicKeyReader.ReadBinary();
+                        var prv = privateKeyReader.ReadBinary();
                         parsedKey = new EcdsaKey(curve, pub, prv);
                         break;
                     case "ssh-rsa":
-                        var exponent = publicKeyReader.ReadBignum(); // e
-                        var modulus = publicKeyReader.ReadBignum(); // n
-                        var d = privateKeyReader.ReadBignum(); // d
-                        var p = privateKeyReader.ReadBignum(); // p
-                        var q = privateKeyReader.ReadBignum(); // q
-                        var inverseQ = privateKeyReader.ReadBignum(); // iqmp
+                        var exponent = publicKeyReader.ReadBigInt();
+                        var modulus = publicKeyReader.ReadBigInt();
+                        var d = privateKeyReader.ReadBigInt();
+                        var p = privateKeyReader.ReadBigInt();
+                        var q = privateKeyReader.ReadBigInt();
+                        var inverseQ = privateKeyReader.ReadBigInt();
                         parsedKey = new RsaKey(modulus, exponent, d, p, q, inverseQ);
                         break;
                     default:

+ 19 - 12
src/Renci.SshNet/PrivateKeyFile.SSHCOM.cs

@@ -1,6 +1,7 @@
 #nullable enable
 using System;
 using System.Collections.Generic;
+using System.Numerics;
 using System.Security.Cryptography;
 using System.Text;
 
@@ -27,7 +28,7 @@ namespace Renci.SshNet
 
             public Key Parse()
             {
-                var reader = new SshDataReader(_data);
+                var reader = new SshDataStream(_data);
                 var magicNumber = reader.ReadUInt32();
                 if (magicNumber != 0x3f6ff9eb)
                 {
@@ -60,11 +61,7 @@ namespace Renci.SshNet
                     throw new SshException(string.Format("Cipher method '{0}' is not supported.", ssh2CipherName));
                 }
 
-                /*
-                 * TODO: Create two specific data types to avoid using SshDataReader class.
-                 */
-
-                reader = new SshDataReader(keyData);
+                reader = new SshDataStream(keyData);
 
                 var decryptedLength = reader.ReadUInt32();
 
@@ -75,16 +72,26 @@ namespace Renci.SshNet
 
                 if (keyType.Contains("rsa"))
                 {
-                    var exponent = reader.ReadBigIntWithBits(); // e
-                    var d = reader.ReadBigIntWithBits(); // d
-                    var modulus = reader.ReadBigIntWithBits(); // n
-                    var inverseQ = reader.ReadBigIntWithBits(); // u
-                    var q = reader.ReadBigIntWithBits(); // p
-                    var p = reader.ReadBigIntWithBits(); // q
+                    var exponent = ReadBigIntWithBits(reader);
+                    var d = ReadBigIntWithBits(reader);
+                    var modulus = ReadBigIntWithBits(reader);
+                    var inverseQ = ReadBigIntWithBits(reader);
+                    var q = ReadBigIntWithBits(reader);
+                    var p = ReadBigIntWithBits(reader);
                     return new RsaKey(modulus, exponent, d, p, q, inverseQ);
                 }
 
                 throw new NotSupportedException(string.Format("Key type '{0}' is not supported.", keyType));
+
+                // Reads next mpint where length is specified in bits.
+                static BigInteger ReadBigIntWithBits(SshDataStream reader)
+                {
+                    var numBits = (int)reader.ReadUInt32();
+
+                    var numBytes = (numBits + 7) / 8;
+
+                    return reader.ReadBytes(numBytes).ToBigInteger2();
+                }
             }
 
             private static byte[] GetCipherKey(string passphrase, int length)

+ 0 - 61
src/Renci.SshNet/PrivateKeyFile.cs

@@ -6,9 +6,7 @@ using System.Diagnostics.CodeAnalysis;
 using System.Globalization;
 using System.IO;
 using System.Linq;
-using System.Numerics;
 using System.Security.Cryptography;
-using System.Text;
 using System.Text.RegularExpressions;
 
 using Renci.SshNet.Common;
@@ -470,65 +468,6 @@ namespace Renci.SshNet
             }
         }
 
-        private sealed class SshDataReader : SshData
-        {
-            public SshDataReader(byte[] data)
-            {
-                Load(data);
-            }
-
-            public new uint ReadUInt32()
-            {
-                return base.ReadUInt32();
-            }
-
-            public new string ReadString(Encoding encoding)
-            {
-                return base.ReadString(encoding);
-            }
-
-            public new byte[] ReadBytes(int length)
-            {
-                return base.ReadBytes(length);
-            }
-
-            public new byte[] ReadBytes()
-            {
-                return base.ReadBytes();
-            }
-
-            /// <summary>
-            /// Reads next mpint data type from internal buffer where length specified in bits.
-            /// </summary>
-            /// <returns>mpint read.</returns>
-            public BigInteger ReadBigIntWithBits()
-            {
-                var length = (int)base.ReadUInt32();
-
-                length = (length + 7) / 8;
-
-                return base.ReadBytes(length).ToBigInteger2();
-            }
-
-            public BigInteger ReadBignum()
-            {
-                return DataStream.ReadBigInt();
-            }
-
-            public byte[] ReadBignum2()
-            {
-                return ReadBinary();
-            }
-
-            protected override void LoadData()
-            {
-            }
-
-            protected override void SaveData()
-            {
-            }
-        }
-
         /// <summary>
         /// Represents private key parser.
         /// </summary>