浏览代码

Ignore data of SSH_MSG_IGNORE when its specified length is greater than the actual available bytes.
Fixes issue #41.

drieseng 9 年之前
父节点
当前提交
c141c7d4a7

+ 208 - 2
src/Renci.SshNet.Tests/Classes/Common/SshDataTest.cs

@@ -1,4 +1,5 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Renci.SshNet.Common;
 
 namespace Renci.SshNet.Tests.Classes.Common
@@ -10,7 +11,7 @@ namespace Renci.SshNet.Tests.Classes.Common
         public void Write_Boolean_False()
         {
             var sshData = new BoolSshData(false);
-            
+
             var bytes = sshData.GetBytes();
 
             Assert.AreEqual((byte) 0, bytes[0]);
@@ -26,6 +27,125 @@ namespace Renci.SshNet.Tests.Classes.Common
             Assert.AreEqual((byte) 1, bytes[0]);
         }
 
+        [TestMethod]
+        public void Load_Data()
+        {
+            const uint one = 123456u;
+            const uint two = 456789u;
+
+            var sshDataStream = new SshDataStream(8);
+            sshDataStream.Write(one);
+            sshDataStream.Write(two);
+
+            var sshData = sshDataStream.ToArray();
+
+            var request = new RequestSshData();
+            request.Load(sshData);
+
+            Assert.AreEqual(one, request.ValueOne);
+            Assert.AreEqual(two, request.ValueTwo);
+        }
+
+        [TestMethod]
+        public void Load_Data_ShouldThrowArgumentNullExceptionWhenDataIsNull()
+        {
+            const byte[] sshData = null;
+            var request = new RequestSshData();
+
+            try
+            {
+                request.Load(sshData);
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("data", ex.ParamName);
+            }
+        }
+
+        [TestMethod]
+        public void Load_DataAndOffsetAndCount()
+        {
+            const uint one = 123456u;
+            const uint two = 456789u;
+
+            var sshDataStream = new SshDataStream(11);
+            sshDataStream.WriteByte(0x05);
+            sshDataStream.WriteByte(0x07);
+            sshDataStream.WriteByte(0x0f);
+            sshDataStream.Write(one);
+            sshDataStream.Write(two);
+
+            var sshData = sshDataStream.ToArray();
+
+            var request = new RequestSshData();
+            request.Load(sshData, 3, sshData.Length - 3);
+
+            Assert.AreEqual(one, request.ValueOne);
+            Assert.AreEqual(two, request.ValueTwo);
+        }
+
+
+        [TestMethod]
+        public void OfType()
+        {
+            const uint one = 123456u;
+            const uint two = 456789u;
+
+            var sshDataStream = new SshDataStream(8);
+            sshDataStream.Write(one);
+            sshDataStream.Write(two);
+
+            var sshData = sshDataStream.ToArray();
+
+            var request = new RequestSshData();
+            request.Load(sshData);
+
+            var reply = request.OfType<ReplySshData>();
+            Assert.IsNotNull(reply);
+            Assert.AreEqual(one, reply.ValueOne);
+        }
+
+        [TestMethod]
+        public void OfType_LoadWithOffset()
+        {
+            const uint one = 123456u;
+            const uint two = 456789u;
+
+            var sshDataStream = new SshDataStream(11);
+            sshDataStream.WriteByte(0x05);
+            sshDataStream.WriteByte(0x07);
+            sshDataStream.WriteByte(0x0f);
+            sshDataStream.Write(one);
+            sshDataStream.Write(two);
+
+            var sshData = sshDataStream.ToArray();
+
+            var request = new RequestSshData();
+            request.Load(sshData, 3, sshData.Length - 3);
+            var reply = request.OfType<ReplySshData>();
+            Assert.IsNotNull(reply);
+            Assert.AreEqual(one, reply.ValueOne);
+        }
+
+        [TestMethod]
+        public void OfType_ShouldThrowArgumentNullExceptionWhenNoDataIsLoaded()
+        {
+            var request = new RequestSshData();
+
+            try
+            {
+                request.OfType<ReplySshData>();
+                Assert.Fail();
+            }
+            catch (ArgumentNullException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual("data", ex.ParamName);
+            }
+        }
+
         private class BoolSshData : SshData
         {
             private readonly bool _value;
@@ -54,5 +174,91 @@ namespace Renci.SshNet.Tests.Classes.Common
                 Write(_value);
             }
         }
+
+        private class RequestSshData : SshData
+        {
+            private uint _valueOne;
+            private uint _valueTwo;
+
+            public RequestSshData()
+            {
+            }
+
+            public RequestSshData(uint one, uint two)
+            {
+                _valueOne = one;
+                _valueTwo = two;
+            }
+
+            protected override int BufferCapacity
+            {
+                get
+                {
+                    var capacity = base.BufferCapacity;
+                    capacity += 4; // ValueOne
+                    capacity += 4; // ValueTwo
+                    return capacity;
+                }
+            }
+
+            public uint ValueOne
+            {
+                get { return _valueOne; }
+                set { _valueOne = value; }
+            }
+
+            public uint ValueTwo
+            {
+                get { return _valueTwo; }
+                set { _valueTwo = value; }
+            }
+
+            protected override void LoadData()
+            {
+                _valueOne = ReadUInt32();
+                _valueTwo = ReadUInt32();
+            }
+
+            protected override void SaveData()
+            {
+                Write(ValueOne);
+                Write(ValueTwo);
+            }
+        }
+
+        private class ReplySshData : SshData
+        {
+            private uint _valueOne;
+
+            public ReplySshData()
+            {
+            }
+
+            protected override int BufferCapacity
+            {
+                get
+                {
+                    var capacity = base.BufferCapacity;
+                    capacity += 4; // ValueOne
+                    return capacity;
+                }
+            }
+
+            public uint ValueOne
+            {
+                get { return _valueOne; }
+                set { _valueOne = value; }
+            }
+
+            protected override void LoadData()
+            {
+                _valueOne = ReadUInt32();
+            }
+
+            protected override void SaveData()
+            {
+                Write(ValueOne);
+            }
+        }
     }
 }

+ 42 - 2
src/Renci.SshNet.Tests/Classes/Messages/Transport/IgnoreMessageTest.cs

@@ -1,4 +1,5 @@
 using System;
+using System.Globalization;
 using System.Linq;
 using Renci.SshNet.Common;
 using Renci.SshNet.Messages.Transport;
@@ -38,7 +39,7 @@ namespace Renci.SshNet.Tests.Classes.Messages.Transport
         [TestMethod]
         public void Constructor_Data_ShouldThrowArgumentNullExceptionWhenDataIsNull()
         {
-            byte[] data = null;
+            const byte[] data = null;
 
             try
             {
@@ -85,11 +86,50 @@ namespace Renci.SshNet.Tests.Classes.Messages.Transport
             var bytes = ignoreMessage.GetBytes();
             var target = new IgnoreMessage();
 
-            target.Load(bytes);
+            target.Load(bytes, 1, bytes.Length - 1);
 
             Assert.IsNotNull(target.Data);
             Assert.AreEqual(_data.Length, target.Data.Length);
             Assert.IsTrue(target.Data.SequenceEqual(_data));
         }
+
+        [TestMethod]
+        public void Load_ShouldIgnoreDataWhenItsLengthIsGreatherThanItsActualBytes()
+        {
+            var ssh = new SshDataStream(1);
+            ssh.WriteByte(2); // Type
+            ssh.Write(5u); // Data length
+            ssh.Write(new byte[3]); // Data
+
+            var ignoreMessageBytes = ssh.ToArray();
+
+            var ignoreMessage = new IgnoreMessage();
+            ignoreMessage.Load(ignoreMessageBytes, 1, ignoreMessageBytes.Length - 1);
+            Assert.IsNotNull(ignoreMessage.Data);
+            Assert.AreEqual(0, ignoreMessage.Data.Length);
+        }
+
+        [TestMethod]
+        public void Load_ShouldThrowNotSupportedExceptionWhenDataLengthIsGreaterThanInt32MaxValue()
+        {
+            var ssh = new SshDataStream(1);
+            ssh.WriteByte(2); // Type
+            ssh.Write(uint.MaxValue); // Data length
+            ssh.Write(new byte[3]);
+
+            var ignoreMessageBytes = ssh.ToArray();
+            var ignoreMessage = new IgnoreMessage();
+
+            try
+            {
+                ignoreMessage.Load(ignoreMessageBytes, 1, ignoreMessageBytes.Length - 1);
+                Assert.Fail();
+            }
+            catch (NotSupportedException ex)
+            {
+                Assert.IsNull(ex.InnerException);
+                Assert.AreEqual(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue), ex.Message);
+            }
+        }
     }
 }

+ 17 - 1
src/Renci.SshNet/Messages/Transport/IgnoreMessage.cs

@@ -1,4 +1,6 @@
 using System;
+using System.Globalization;
+using Renci.SshNet.Abstractions;
 using Renci.SshNet.Common;
 
 namespace Renci.SshNet.Messages.Transport
@@ -58,7 +60,21 @@ namespace Renci.SshNet.Messages.Transport
         /// </summary>
         protected override void LoadData()
         {
-            Data = ReadBinary();
+            var dataLength = ReadUInt32();
+            if (dataLength > int.MaxValue)
+            {
+                throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue));
+            }
+
+            if (dataLength > (DataStream.Length - DataStream.Position))
+            {
+                DiagnosticAbstraction.Log("SSH_MSG_IGNORE: Length exceeds data bytes, data ignored.");
+                Data = Array<byte>.Empty;
+            }
+            else
+            {
+                Data = ReadBytes((int) dataLength);
+            }
         }
 
         /// <summary>