SocketAbstraction.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. using System;
  2. using System.Globalization;
  3. using System.Net;
  4. using System.Net.Sockets;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  7. using Renci.SshNet.Common;
  8. using Renci.SshNet.Messages.Transport;
  namespace Renci.SshNet.Abstractions
  {
      /// <summary>
      /// Helper methods for connecting, reading from and writing to a <see cref="Socket"/> with time-outs and
      /// retry handling for transient socket errors.
      /// </summary>
      internal static class SocketAbstraction
      {
          /// <summary>
          /// Returns a value indicating whether data can be read from the specified <see cref="Socket"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to check.</param>
          /// <returns>
          /// <see langword="true"/> if <paramref name="socket"/> is connected, becomes readable and has data
          /// available; otherwise, <see langword="false"/>.
          /// </returns>
          /// <remarks>
          /// The socket is polled with an infinite (<c>-1</c>) time-out, so this call blocks until
          /// <paramref name="socket"/> becomes readable.
          /// </remarks>
          public static bool CanRead(Socket socket)
          {
              // Mirrors CanWrite: a null socket is reported as not readable instead of throwing.
              if (socket != null && socket.Connected)
              {
                  return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
              }

              return false;
          }

          /// <summary>
          /// Returns a value indicating whether the specified <see cref="Socket"/> can be used
          /// to send data.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to check.</param>
          /// <returns>
          /// <see langword="true"/> if <paramref name="socket"/> can be written to; otherwise, <see langword="false"/>.
          /// </returns>
          public static bool CanWrite(Socket socket)
          {
              if (socket != null && socket.Connected)
              {
                  return socket.Poll(-1, SelectMode.SelectWrite);
              }

              return false;
          }

          /// <summary>
          /// Creates a TCP stream socket with <see cref="Socket.NoDelay"/> enabled and connects it to
          /// <paramref name="remoteEndpoint"/>.
          /// </summary>
          /// <param name="remoteEndpoint">The endpoint to connect to.</param>
          /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
          /// <returns>The connected <see cref="Socket"/>.</returns>
          /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
          /// <exception cref="SocketException">The connection attempt failed.</exception>
          /// <remarks>
          /// The created socket is disposed when the connection attempt fails or times out.
          /// </remarks>
          public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
          {
              var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
              ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true);
              return socket;
          }

          /// <summary>
          /// Connects an existing <see cref="Socket"/> to <paramref name="remoteEndpoint"/>.
          /// </summary>
          /// <param name="socket">The socket to connect. It remains owned by the caller.</param>
          /// <param name="remoteEndpoint">The endpoint to connect to.</param>
          /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
          /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
          /// <exception cref="SocketException">The connection attempt failed.</exception>
          /// <remarks>
          /// <paramref name="socket"/> is not disposed when the attempt fails or times out, and this method does not
          /// cancel a connect attempt that is still pending when the time-out elapses.
          /// </remarks>
          public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
          {
              ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
          }

          /// <summary>
          /// Asynchronously connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>.
          /// </summary>
          /// <param name="socket">The socket to connect.</param>
          /// <param name="remoteEndpoint">The endpoint to connect to.</param>
          /// <param name="cancellationToken">The token used to cancel the connect operation.</param>
          /// <returns>A task that completes when the connection has been established.</returns>
          public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
          {
              await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
          }

          /// <summary>
          /// Connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>, giving up after
          /// <paramref name="connectTimeout"/>.
          /// </summary>
          /// <param name="socket">The socket to connect.</param>
          /// <param name="remoteEndpoint">The endpoint to connect to.</param>
          /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
          /// <param name="ownsSocket">
          /// <see langword="true"/> to dispose <paramref name="socket"/> when the connect attempt does not succeed.
          /// </param>
          private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
          {
  #if FEATURE_SOCKET_EAP
              var connectCompleted = new ManualResetEvent(initialState: false);
              var args = new SocketAsyncEventArgs
                  {
                      UserToken = connectCompleted,
                      RemoteEndPoint = remoteEndpoint
                  };
              args.Completed += ConnectCompleted;

              var connected = false;

              try
              {
                  // ConnectAsync returns false when the operation completed synchronously; in that case the
                  // Completed event is not raised and the outcome is already available in args.
                  if (socket.ConnectAsync(args) && !connectCompleted.WaitOne(connectTimeout))
                  {
                      // avoid ObjectDisposedException in ConnectCompleted
                      args.Completed -= ConnectCompleted;
                      throw CreateConnectTimeoutException(connectTimeout);
                  }

                  if (args.SocketError != SocketError.Success)
                  {
                      throw new SocketException((int) args.SocketError);
                  }

                  connected = true;
              }
              finally
              {
                  // Runs on every exit path, including a synchronous exception from ConnectAsync, so neither the
                  // wait handle, the event args, nor an owned socket can leak.
                  if (!connected && ownsSocket)
                  {
                      socket.Dispose();
                  }

                  connectCompleted.Dispose();
                  args.Dispose();
              }
  #elif FEATURE_SOCKET_APM
              var connected = false;

              try
              {
                  var connectResult = socket.BeginConnect(remoteEndpoint, null, null);
                  if (!connectResult.AsyncWaitHandle.WaitOne(connectTimeout, false))
                  {
                      throw CreateConnectTimeoutException(connectTimeout);
                  }

                  socket.EndConnect(connectResult);
                  connected = true;
              }
              finally
              {
                  // A socket created by Connect(IPEndPoint, TimeSpan) never reaches the caller on failure.
                  if (!connected && ownsSocket)
                  {
                      socket.Dispose();
                  }
              }
  #elif FEATURE_SOCKET_TAP
              var connected = false;

              try
              {
                  var connectTask = socket.ConnectAsync(remoteEndpoint);

                  // Wait on the task's handle instead of Task.Wait so that a failed connect surfaces as the original
                  // SocketException (rethrown by GetResult) rather than wrapped in an AggregateException, matching
                  // the other connect implementations.
                  if (!((IAsyncResult) connectTask).AsyncWaitHandle.WaitOne(connectTimeout))
                  {
                      throw CreateConnectTimeoutException(connectTimeout);
                  }

                  connectTask.GetAwaiter().GetResult();
                  connected = true;
              }
              finally
              {
                  // A socket created by Connect(IPEndPoint, TimeSpan) never reaches the caller on failure.
                  if (!connected && ownsSocket)
                  {
                      socket.Dispose();
                  }
              }
  #else
  #error Connecting to a remote endpoint is not implemented.
  #endif
          }

          /// <summary>
          /// Discards any data that is pending on <paramref name="socket"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> whose pending data to discard.</param>
          /// <remarks>
          /// Data is read in chunks of up to 256 bytes until no data arrives for 500 milliseconds or the peer
          /// closes the connection. <see cref="Socket.ReceiveTimeout"/> is left set to 500 milliseconds.
          /// </remarks>
          public static void ClearReadBuffer(Socket socket)
          {
              var timeout = TimeSpan.FromMilliseconds(500);
              var buffer = new byte[256];

              try
              {
                  while (ReadPartial(socket, buffer, 0, buffer.Length, timeout) > 0)
                  {
                      // discard the bytes that were read
                  }
              }
              catch (SshOperationTimeoutException)
              {
                  // Nothing arrived within the time-out: the read buffer has been drained.
              }
          }

          /// <summary>
          /// Performs a single receive of at most <paramref name="size"/> bytes.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="buffer">The buffer to store the received data in.</param>
          /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
          /// <param name="size">The maximum number of bytes to receive.</param>
          /// <param name="timeout">The maximum time to wait for data. This value is stored in <see cref="Socket.ReceiveTimeout"/>.</param>
          /// <returns>The number of bytes received, which may be less than <paramref name="size"/>.</returns>
          /// <exception cref="SshOperationTimeoutException">No data was received within <paramref name="timeout"/>.</exception>
          /// <exception cref="SocketException">The read failed.</exception>
          public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
          {
              socket.ReceiveTimeout = ToSocketTimeout(timeout);

              try
              {
                  return socket.Receive(buffer, offset, size, SocketFlags.None);
              }
              catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
              {
                  throw CreateReadTimeoutException(timeout);
              }
          }

          /// <summary>
          /// Reads from <paramref name="socket"/> without a time-out until the connection is closed, passing each
          /// chunk that is received to <paramref name="processReceivedBytesAction"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="buffer">The buffer that receives data; it is reused for every chunk.</param>
          /// <param name="offset">The position in <paramref name="buffer"/> to store received data.</param>
          /// <param name="size">The maximum number of bytes to receive per chunk.</param>
          /// <param name="processReceivedBytesAction">
          /// Invoked with <paramref name="buffer"/>, <paramref name="offset"/> and the number of bytes received.
          /// </param>
          /// <remarks>
          /// Returns when the socket is no longer connected, when a receive returns no data, or when the receive
          /// fails with <see cref="SocketError.ConnectionAborted"/>, <see cref="SocketError.ConnectionReset"/> or
          /// <see cref="SocketError.Interrupted"/>. Transient errors are retried after a short delay; any other
          /// <see cref="SocketException"/> is rethrown.
          /// </remarks>
          public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
          {
              // do not time-out receive
              socket.ReceiveTimeout = 0;

              while (socket.Connected)
              {
                  try
                  {
                      var bytesRead = socket.Receive(buffer, offset, size, SocketFlags.None);
                      if (bytesRead == 0)
                      {
                          break;
                      }

                      processReceivedBytesAction(buffer, offset, bytesRead);
                  }
                  catch (SocketException ex)
                  {
                      if (IsErrorResumable(ex.SocketErrorCode))
                      {
                          // Back off like Read and Send do, instead of spinning on a socket that has no data yet.
                          ThreadAbstraction.Sleep(30);
                          continue;
                      }

  #pragma warning disable IDE0010 // Add missing cases
                      switch (ex.SocketErrorCode)
                      {
                          case SocketError.ConnectionAborted:
                          case SocketError.ConnectionReset:
                              // connection was closed
                              return;
                          case SocketError.Interrupted:
                              // connection was closed because FIN/ACK was not received in time after
                              // shutting down the (send part of the) socket
                              return;
                          default:
                              throw; // throw any other error
                      }
  #pragma warning restore IDE0010 // Add missing cases
                  }
              }
          }

          /// <summary>
          /// Reads a byte from the specified <see cref="Socket"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
          /// <returns>
          /// The byte read, or <c>-1</c> if the socket was closed.
          /// </returns>
          /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
          /// <exception cref="SocketException">The read failed.</exception>
          public static int ReadByte(Socket socket, TimeSpan timeout)
          {
              var buffer = new byte[1];
              if (Read(socket, buffer, 0, 1, timeout) == 0)
              {
                  return -1;
              }

              return buffer[0];
          }

          /// <summary>
          /// Sends a byte using the specified <see cref="Socket"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to write to.</param>
          /// <param name="value">The value to send.</param>
          /// <exception cref="SocketException">The write failed.</exception>
          public static void SendByte(Socket socket, byte value)
          {
              var buffer = new[] { value };
              Send(socket, buffer, 0, 1);
          }

          /// <summary>
          /// Receives data from a bound <see cref="Socket"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="size">The number of bytes to receive.</param>
          /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
          /// <returns>
          /// The bytes received. If the connection is closed before <paramref name="size"/> bytes arrive, the
          /// remaining bytes of the returned array are left as zero.
          /// </returns>
          /// <remarks>
          /// If no data is available for reading, the <see cref="Read(Socket, int, TimeSpan)"/> method will
          /// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the
          /// <see cref="Read(Socket, int, TimeSpan)"/> call will throw a <see cref="SshOperationTimeoutException"/>.
          /// If you are in non-blocking mode, and there is no data available in the protocol stack buffer, the
          /// <see cref="Read(Socket, int, TimeSpan)"/> method will complete immediately and throw a <see cref="SocketException"/>.
          /// </remarks>
          public static byte[] Read(Socket socket, int size, TimeSpan timeout)
          {
              var buffer = new byte[size];
              _ = Read(socket, buffer, 0, size, timeout);
              return buffer;
          }

          /// <summary>
          /// Receives data from a bound <see cref="Socket"/> into a receive buffer.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data. </param>
          /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
          /// <param name="size">The number of bytes to receive.</param>
          /// <param name="readTimeout">The maximum time to wait until <paramref name="size"/> bytes have been received.</param>
          /// <returns>
          /// The number of bytes received (<paramref name="size"/>), or <c>0</c> if the connection was closed before
          /// all bytes arrived or if <paramref name="size"/> is zero.
          /// </returns>
          /// <remarks>
          /// <para>
          /// If no data is available for reading, the <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> method will
          /// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the
          /// <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> call will throw a <see cref="SshOperationTimeoutException"/>.
          /// </para>
          /// <para>
          /// If you are in non-blocking mode, and there is no data available in the protocol stack buffer, the
          /// <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> method will complete immediately and throw a <see cref="SocketException"/>.
          /// </para>
          /// </remarks>
          public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout)
          {
              var totalBytesRead = 0;

              socket.ReceiveTimeout = ToSocketTimeout(readTimeout);

              // A while loop (rather than do/while) lets a zero-length read complete without issuing a
              // zero-byte Receive call.
              while (totalBytesRead < size)
              {
                  try
                  {
                      var bytesRead = socket.Receive(buffer, offset + totalBytesRead, size - totalBytesRead, SocketFlags.None);
                      if (bytesRead == 0)
                      {
                          // the peer closed the connection
                          return 0;
                      }

                      totalBytesRead += bytesRead;
                  }
                  catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
                  {
                      // transient condition; wait briefly and retry
                      ThreadAbstraction.Sleep(30);
                  }
                  catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
                  {
                      throw CreateReadTimeoutException(readTimeout);
                  }
              }

              return totalBytesRead;
          }

          /// <summary>
          /// Asynchronously receives up to <paramref name="length"/> bytes into <paramref name="buffer"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="buffer">The buffer to store the received data in.</param>
          /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
          /// <param name="length">The maximum number of bytes to receive.</param>
          /// <param name="cancellationToken">The token used to cancel the receive.</param>
          /// <returns>
          /// The task returned by the <c>ReceiveAsync</c> extension. NOTE(review): presumably its result is the number
          /// of bytes received — confirm against the extension's implementation.
          /// </returns>
          public static Task<int> ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
          {
              return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
          }

          /// <summary>
          /// Sends the whole of <paramref name="data"/> using the specified <see cref="Socket"/>.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to write to.</param>
          /// <param name="data">The bytes to send.</param>
          /// <exception cref="SshConnectionException">The socket reported that no bytes could be sent.</exception>
          /// <exception cref="SocketException">The write failed.</exception>
          public static void Send(Socket socket, byte[] data)
          {
              Send(socket, data, 0, data.Length);
          }

          /// <summary>
          /// Sends <paramref name="size"/> bytes of <paramref name="data"/>, starting at <paramref name="offset"/>,
          /// looping until every byte has been written.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to write to.</param>
          /// <param name="data">The buffer containing the bytes to send.</param>
          /// <param name="offset">The position in <paramref name="data"/> of the first byte to send.</param>
          /// <param name="size">The number of bytes to send. A value of zero sends nothing.</param>
          /// <exception cref="SshConnectionException">The socket reported that no bytes could be sent.</exception>
          /// <exception cref="SocketException">The write failed with a non-transient error.</exception>
          /// <remarks>
          /// Transient errors (see <see cref="IsErrorResumable(SocketError)"/>) are retried after a 30 millisecond delay.
          /// </remarks>
          public static void Send(Socket socket, byte[] data, int offset, int size)
          {
              var totalBytesSent = 0;

              // A while loop (rather than do/while) makes a zero-length send a no-op; otherwise Socket.Send reports
              // 0 bytes sent, which would be misinterpreted as an aborted connection.
              while (totalBytesSent < size)
              {
                  try
                  {
                      var bytesSent = socket.Send(data, offset + totalBytesSent, size - totalBytesSent, SocketFlags.None);
                      if (bytesSent == 0)
                      {
                          throw new SshConnectionException("An established connection was aborted by the server.",
                                                           DisconnectReason.ConnectionLost);
                      }

                      totalBytesSent += bytesSent;
                  }
                  catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
                  {
                      // socket buffer is probably full, wait and try again
                      ThreadAbstraction.Sleep(30);
                  }
              }
          }

          /// <summary>
          /// Returns a value indicating whether a socket operation that failed with <paramref name="socketError"/>
          /// may be retried.
          /// </summary>
          /// <param name="socketError">The error reported by the failed operation.</param>
          /// <returns>
          /// <see langword="true"/> for <see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/> and
          /// <see cref="SocketError.NoBufferSpaceAvailable"/>; otherwise, <see langword="false"/>.
          /// </returns>
          public static bool IsErrorResumable(SocketError socketError)
          {
  #pragma warning disable IDE0010 // Add missing cases
              switch (socketError)
              {
                  case SocketError.WouldBlock:
                  case SocketError.IOPending:
                  case SocketError.NoBufferSpaceAvailable:
                      return true;
                  default:
                      return false;
              }
  #pragma warning restore IDE0010 // Add missing cases
          }

          /// <summary>
          /// Converts <paramref name="timeout"/> to milliseconds for <see cref="Socket.ReceiveTimeout"/>.
          /// </summary>
          /// <param name="timeout">The time-out to convert.</param>
          /// <returns>The time-out in milliseconds, clamped to the range of <see cref="int"/>.</returns>
          /// <remarks>
          /// Clamping avoids an out-of-range double-to-int conversion, whose result is unspecified, for very large
          /// values such as <see cref="TimeSpan.MaxValue"/>. In-range values, including <c>-1</c> (infinite), are
          /// passed through unchanged.
          /// </remarks>
          private static int ToSocketTimeout(TimeSpan timeout)
          {
              return (int) Math.Min(Math.Max(timeout.TotalMilliseconds, int.MinValue), int.MaxValue);
          }

          /// <summary>
          /// Creates the exception that reports a connect attempt exceeding <paramref name="connectTimeout"/>.
          /// </summary>
          /// <param name="connectTimeout">The time-out that was exceeded.</param>
          /// <returns>The exception to throw.</returns>
          private static SshOperationTimeoutException CreateConnectTimeoutException(TimeSpan connectTimeout)
          {
              return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                                                                    "Connection failed to establish within {0:F0} milliseconds.",
                                                                    connectTimeout.TotalMilliseconds));
          }

          /// <summary>
          /// Creates the exception that reports a read exceeding <paramref name="timeout"/>.
          /// </summary>
          /// <param name="timeout">The time-out that was exceeded.</param>
          /// <returns>The exception to throw.</returns>
          private static SshOperationTimeoutException CreateReadTimeoutException(TimeSpan timeout)
          {
              return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                                                                    "Socket read operation has timed out after {0:F0} milliseconds.",
                                                                    timeout.TotalMilliseconds));
          }

  #if FEATURE_SOCKET_EAP
          /// <summary>
          /// Signals the <see cref="ManualResetEvent"/> stored in <see cref="SocketAsyncEventArgs.UserToken"/> when an
          /// asynchronous connect completes.
          /// </summary>
          /// <param name="sender">The socket that raised the event.</param>
          /// <param name="e">The event args of the completed connect operation.</param>
          private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
          {
              var eventWaitHandle = (ManualResetEvent) e.UserToken;

              try
              {
                  _ = eventWaitHandle?.Set();
              }
              catch (ObjectDisposedException)
              {
                  // ConnectCore timed out and disposed the handle while this callback was already running;
                  // nobody is waiting for the signal anymore, and throwing here would escape on an I/O thread.
              }
          }
  #endif // FEATURE_SOCKET_EAP
      }
  }