Extensions.cs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Globalization;
  4. #if !NET
  5. using System.IO;
  6. #endif
  7. using System.Net;
  8. using System.Net.Sockets;
  9. using System.Numerics;
  10. using System.Runtime.CompilerServices;
  11. using System.Threading;
  12. using Renci.SshNet.Abstractions;
  13. using Renci.SshNet.Messages;
  14. namespace Renci.SshNet.Common
  15. {
  16. /// <summary>
  17. /// Collection of different extension methods.
  18. /// </summary>
  19. internal static class Extensions
  20. {
  21. #pragma warning disable S4136 // Method overloads should be grouped together
  22. internal static byte[] ToArray(this ServiceName serviceName)
  23. #pragma warning restore S4136 // Method overloads should be grouped together
  24. {
  25. switch (serviceName)
  26. {
  27. case ServiceName.UserAuthentication:
  28. return SshData.Ascii.GetBytes("ssh-userauth");
  29. case ServiceName.Connection:
  30. return SshData.Ascii.GetBytes("ssh-connection");
  31. default:
  32. throw new NotSupportedException(string.Format("Service name '{0}' is not supported.", serviceName));
  33. }
  34. }
  35. internal static ServiceName ToServiceName(this byte[] data)
  36. {
  37. var sshServiceName = SshData.Ascii.GetString(data, 0, data.Length);
  38. switch (sshServiceName)
  39. {
  40. case "ssh-userauth":
  41. return ServiceName.UserAuthentication;
  42. case "ssh-connection":
  43. return ServiceName.Connection;
  44. default:
  45. throw new NotSupportedException(string.Format("Service name '{0}' is not supported.", sshServiceName));
  46. }
  47. }
  48. internal static BigInteger ToBigInteger(this ReadOnlySpan<byte> data)
  49. {
  50. #if NET
  51. return new BigInteger(data, isBigEndian: true);
  52. #else
  53. var reversed = data.ToArray();
  54. Array.Reverse(reversed);
  55. return new BigInteger(reversed);
  56. #endif
  57. }
  58. internal static BigInteger ToBigInteger(this byte[] data)
  59. {
  60. #if NET
  61. return new BigInteger(data, isBigEndian: true);
  62. #else
  63. var reversed = new byte[data.Length];
  64. Buffer.BlockCopy(data, 0, reversed, 0, data.Length);
  65. Array.Reverse(reversed);
  66. return new BigInteger(reversed);
  67. #endif
  68. }
  69. /// <summary>
  70. /// Initializes a new instance of the <see cref="BigInteger"/> structure using the SSH BigNum2 Format.
  71. /// </summary>
  72. public static BigInteger ToBigInteger2(this byte[] data)
  73. {
  74. #if NET
  75. return new BigInteger(data, isBigEndian: true, isUnsigned: true);
  76. #else
  77. if ((data[0] & (1 << 7)) != 0)
  78. {
  79. var buf = new byte[data.Length + 1];
  80. Buffer.BlockCopy(data, 0, buf, 1, data.Length);
  81. Array.Reverse(buf);
  82. return new BigInteger(buf);
  83. }
  84. return data.ToBigInteger();
  85. #endif
  86. }
  87. #if !NET
  88. public static byte[] ToByteArray(this BigInteger bigInt, bool isUnsigned = false, bool isBigEndian = false)
  89. {
  90. var data = bigInt.ToByteArray();
  91. if (isUnsigned && data[data.Length - 1] == 0)
  92. {
  93. data = data.Take(data.Length - 1);
  94. }
  95. if (isBigEndian)
  96. {
  97. Array.Reverse(data);
  98. }
  99. return data;
  100. }
  101. #endif
  102. #if !NET
  103. public static long GetBitLength(this BigInteger bigint)
  104. {
  105. // Taken from https://github.com/dotnet/runtime/issues/31308
  106. return (long)Math.Ceiling(BigInteger.Log(bigint.Sign < 0 ? -bigint : bigint + 1, 2));
  107. }
  108. #endif
  109. // See https://github.com/dotnet/runtime/blob/9b57a265c7efd3732b035bade005561a04767128/src/libraries/Common/src/System/Security/Cryptography/KeyBlobHelpers.cs#L51
  110. public static byte[] ExportKeyParameter(this BigInteger value, int length)
  111. {
  112. var target = value.ToByteArray(isUnsigned: true, isBigEndian: true);
  113. // The BCL crypto is expecting exactly-sized byte arrays (sized to "length").
  114. // If our byte array is smaller than required, then size it up.
  115. // Otherwise, just return as is: if it is too large, we'll let the BCL throw the error.
  116. if (target.Length < length)
  117. {
  118. var correctlySized = new byte[length];
  119. Buffer.BlockCopy(target, 0, correctlySized, length - target.Length, target.Length);
  120. return correctlySized;
  121. }
  122. return target;
  123. }
  124. /// <summary>
  125. /// Sets a wait handle, swallowing any resulting <see cref="ObjectDisposedException"/>.
  126. /// Used in cases where set and dispose may race.
  127. /// </summary>
  128. /// <param name="waitHandle">The wait handle to set.</param>
  129. public static void SetIgnoringObjectDisposed(this EventWaitHandle waitHandle)
  130. {
  131. try
  132. {
  133. _ = waitHandle.Set();
  134. }
  135. catch (ObjectDisposedException)
  136. {
  137. // ODE intentionally ignored.
  138. }
  139. }
  140. internal static void ValidatePort(this uint value, [CallerArgumentExpression(nameof(value))] string argument = null)
  141. {
  142. if (value > IPEndPoint.MaxPort)
  143. {
  144. throw new ArgumentOutOfRangeException(argument,
  145. string.Format(CultureInfo.InvariantCulture, "Specified value cannot be greater than {0}.", IPEndPoint.MaxPort));
  146. }
  147. }
  148. internal static void ValidatePort(this int value, [CallerArgumentExpression(nameof(value))] string argument = null)
  149. {
  150. if (value < IPEndPoint.MinPort)
  151. {
  152. throw new ArgumentOutOfRangeException(argument, string.Format(CultureInfo.InvariantCulture, "Specified value cannot be less than {0}.", IPEndPoint.MinPort));
  153. }
  154. if (value > IPEndPoint.MaxPort)
  155. {
  156. throw new ArgumentOutOfRangeException(argument, string.Format(CultureInfo.InvariantCulture, "Specified value cannot be greater than {0}.", IPEndPoint.MaxPort));
  157. }
  158. }
  159. /// <summary>
  160. /// Returns a specified number of contiguous bytes from a given offset.
  161. /// </summary>
  162. /// <param name="value">The array to return a number of bytes from.</param>
  163. /// <param name="offset">The zero-based offset in <paramref name="value"/> at which to begin taking bytes.</param>
  164. /// <param name="count">The number of bytes to take from <paramref name="value"/>.</param>
  165. /// <returns>
  166. /// A <see cref="byte"/> array that contains the specified number of bytes at the specified offset
  167. /// of the input array.
  168. /// </returns>
  169. /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
  170. /// <remarks>
  171. /// When <paramref name="offset"/> is zero and <paramref name="count"/> equals the length of <paramref name="value"/>,
  172. /// then <paramref name="value"/> is returned.
  173. /// </remarks>
  174. public static byte[] Take(this byte[] value, int offset, int count)
  175. {
  176. ThrowHelper.ThrowIfNull(value);
  177. if (count == 0)
  178. {
  179. return Array.Empty<byte>();
  180. }
  181. if (offset == 0 && value.Length == count)
  182. {
  183. return value;
  184. }
  185. var taken = new byte[count];
  186. Buffer.BlockCopy(value, offset, taken, 0, count);
  187. return taken;
  188. }
  189. /// <summary>
  190. /// Returns a specified number of contiguous bytes from the start of the specified byte array.
  191. /// </summary>
  192. /// <param name="value">The array to return a number of bytes from.</param>
  193. /// <param name="count">The number of bytes to take from <paramref name="value"/>.</param>
  194. /// <returns>
  195. /// A <see cref="byte"/> array that contains the specified number of bytes at the start of the input array.
  196. /// </returns>
  197. /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
  198. /// <remarks>
  199. /// When <paramref name="count"/> equals the length of <paramref name="value"/>, then <paramref name="value"/>
  200. /// is returned.
  201. /// </remarks>
  202. public static byte[] Take(this byte[] value, int count)
  203. {
  204. ThrowHelper.ThrowIfNull(value);
  205. if (count == 0)
  206. {
  207. return Array.Empty<byte>();
  208. }
  209. if (value.Length == count)
  210. {
  211. return value;
  212. }
  213. var taken = new byte[count];
  214. Buffer.BlockCopy(value, 0, taken, 0, count);
  215. return taken;
  216. }
  217. public static bool IsEqualTo(this byte[] left, byte[] right)
  218. {
  219. ThrowHelper.ThrowIfNull(left);
  220. ThrowHelper.ThrowIfNull(right);
  221. return left.AsSpan().SequenceEqual(right);
  222. }
  223. /// <summary>
  224. /// Trims the leading zero from a byte array.
  225. /// </summary>
  226. /// <param name="value">The value.</param>
  227. /// <returns>
  228. /// <paramref name="value"/> without leading zeros.
  229. /// </returns>
  230. public static byte[] TrimLeadingZeros(this byte[] value)
  231. {
  232. ThrowHelper.ThrowIfNull(value);
  233. for (var i = 0; i < value.Length; i++)
  234. {
  235. if (value[i] == 0)
  236. {
  237. continue;
  238. }
  239. // if the first byte is non-zero, then we return the byte array as is
  240. if (i == 0)
  241. {
  242. return value;
  243. }
  244. var remainingBytes = value.Length - i;
  245. var cleaned = new byte[remainingBytes];
  246. Buffer.BlockCopy(value, i, cleaned, 0, remainingBytes);
  247. return cleaned;
  248. }
  249. return value;
  250. }
  251. /// <summary>
  252. /// Pads with leading zeros if needed.
  253. /// </summary>
  254. /// <param name="data">The data.</param>
  255. /// <param name="length">The length to pad to.</param>
  256. public static byte[] Pad(this byte[] data, int length)
  257. {
  258. if (length <= data.Length)
  259. {
  260. return data;
  261. }
  262. var newData = new byte[length];
  263. Buffer.BlockCopy(data, 0, newData, newData.Length - data.Length, data.Length);
  264. return newData;
  265. }
  266. public static byte[] Concat(this byte[] first, byte[] second)
  267. {
  268. if (first is null || first.Length == 0)
  269. {
  270. return second;
  271. }
  272. if (second is null || second.Length == 0)
  273. {
  274. return first;
  275. }
  276. var concat = new byte[first.Length + second.Length];
  277. Buffer.BlockCopy(first, 0, concat, 0, first.Length);
  278. Buffer.BlockCopy(second, 0, concat, first.Length, second.Length);
  279. return concat;
  280. }
  281. internal static bool CanRead(this Socket socket)
  282. {
  283. return SocketAbstraction.CanRead(socket);
  284. }
  285. internal static bool CanWrite(this Socket socket)
  286. {
  287. return SocketAbstraction.CanWrite(socket);
  288. }
  289. internal static bool IsConnected(this Socket socket)
  290. {
  291. if (socket is null)
  292. {
  293. return false;
  294. }
  295. return socket.Connected;
  296. }
  297. internal static string Join(this IEnumerable<string> values, string separator)
  298. {
  299. // Used to avoid analyzers asking to "use an overload with a char parameter"
  300. // which is not available on all targets.
  301. return string.Join(separator, values);
  302. }
  303. #if !NET
  304. internal static bool TryAdd<TKey, TValue>(this Dictionary<TKey, TValue> dictionary, TKey key, TValue value)
  305. {
  306. if (!dictionary.ContainsKey(key))
  307. {
  308. dictionary.Add(key, value);
  309. return true;
  310. }
  311. return false;
  312. }
  313. internal static bool Remove<TKey, TValue>(this Dictionary<TKey, TValue> dictionary, TKey key, out TValue value)
  314. {
  315. if (dictionary.TryGetValue(key, out value))
  316. {
  317. _ = dictionary.Remove(key);
  318. return true;
  319. }
  320. value = default;
  321. return false;
  322. }
  323. internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index)
  324. {
  325. return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, arraySegment.Count - index);
  326. }
  327. internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index, int count)
  328. {
  329. return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, count);
  330. }
  331. internal static T[] ToArray<T>(this ArraySegment<T> arraySegment)
  332. {
  333. if (arraySegment.Count == 0)
  334. {
  335. return Array.Empty<T>();
  336. }
  337. var array = new T[arraySegment.Count];
  338. Array.Copy(arraySegment.Array, arraySegment.Offset, array, 0, arraySegment.Count);
  339. return array;
  340. }
  341. #pragma warning disable CA1859 // Use concrete types for improved performance
  342. internal static void ReadExactly(this Stream stream, byte[] buffer, int offset, int count)
  343. #pragma warning restore CA1859
  344. {
  345. var totalRead = 0;
  346. while (totalRead < count)
  347. {
  348. var read = stream.Read(buffer, offset + totalRead, count - totalRead);
  349. if (read == 0)
  350. {
  351. throw new EndOfStreamException();
  352. }
  353. totalRead += read;
  354. }
  355. }
  356. #endif
  357. }
  358. }