Session.NET.cs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. using System.Globalization;
  2. using System.Linq;
  3. using System;
  4. using System.Net.Sockets;
  5. using System.Net;
  6. using Renci.SshNet.Common;
  7. using Renci.SshNet.Messages.Transport;
  8. using System.Diagnostics;
  9. using System.Collections.Generic;
  10. using Renci.SshNet.Abstractions;
  11. namespace Renci.SshNet
  12. {
  13. public partial class Session
  14. {
  15. private const byte Null = 0x00;
  16. private const byte CarriageReturn = 0x0d;
  17. private const byte LineFeed = 0x0a;
  18. #if FEATURE_DIAGNOSTICS_TRACESOURCE
  19. private readonly TraceSource _log =
  20. #if DEBUG
  21. new TraceSource("SshNet.Logging", SourceLevels.All);
  22. #else
  23. new TraceSource("SshNet.Logging");
  24. #endif // DEBUG
  25. #endif // FEATURE_DIAGNOSTICS_TRACESOURCE
  26. /// <summary>
  27. /// Holds the lock object to ensure read access to the socket is synchronized.
  28. /// </summary>
  29. private readonly object _socketReadLock = new object();
  30. /// <summary>
  31. /// Gets a value indicating whether the socket is connected.
  32. /// </summary>
  33. /// <param name="isConnected"><c>true</c> if the socket is connected; otherwise, <c>false</c></param>
  34. /// <remarks>
  35. /// <para>
  36. /// As a first check we verify whether <see cref="Socket.Connected"/> is
  37. /// <c>true</c>. However, this only returns the state of the socket as of
  38. /// the last I/O operation. Therefore we use the combination of Socket.Poll
  39. /// with mode SelectRead and Socket.Available to verify if the socket is
  40. /// still connected.
  41. /// </para>
  42. /// <para>
  43. /// The MSDN doc mention the following on the return value of <see cref="Socket.Poll(int, SelectMode)"/>
  44. /// with mode <see cref="SelectMode.SelectRead"/>:
  45. /// <list type="bullet">
  46. /// <item>
  47. /// <description><c>true</c> if data is available for reading;</description>
  48. /// </item>
  49. /// <item>
  50. /// <description><c>true</c> if the connection has been closed, reset, or terminated; otherwise, returns <c>false</c>.</description>
  51. /// </item>
  52. /// </list>
  53. /// </para>
  54. /// <para>
  55. /// <c>Conclusion:</c> when the return value is <c>true</c> - but no data is available for reading - then
  56. /// the socket is no longer connected.
  57. /// </para>
  58. /// <para>
  59. /// When a <see cref="Socket"/> is used from multiple threads, there's a race condition
  60. /// between the invocation of <see cref="Socket.Poll(int, SelectMode)"/> and the moment
  61. /// when the value of <see cref="Socket.Available"/> is obtained. As a workaround, we signal
  62. /// when bytes are read from the <see cref="Socket"/>.
  63. /// </para>
  64. /// </remarks>
  65. partial void IsSocketConnected(ref bool isConnected)
  66. {
  67. isConnected = (_socket != null && _socket.Connected);
  68. if (isConnected)
  69. {
  70. // synchronize this to ensure thread B does not reset the wait handle before
  71. // thread A was able to check whether "bytes read from socket" signal was
  72. // actually received
  73. lock (_socketReadLock)
  74. {
  75. _bytesReadFromSocket.Reset();
  76. var connectionClosedOrDataAvailable = _socket.Poll(1000, SelectMode.SelectRead);
  77. isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
  78. if (!isConnected)
  79. {
  80. // the race condition is between the Socket.Poll call and
  81. // Socket.Available, but the event handler - where we signal that
  82. // bytes have been received from the socket - is sometimes invoked
  83. // shortly after
  84. isConnected = _bytesReadFromSocket.WaitOne(500);
  85. }
  86. }
  87. }
  88. }
  89. /// <summary>
  90. /// Establishes a socket connection to the specified host and port.
  91. /// </summary>
  92. /// <param name="host">The host name of the server to connect to.</param>
  93. /// <param name="port">The port to connect to.</param>
  94. /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="Renci.SshNet.ConnectionInfo.Timeout"/>.</exception>
  95. /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
  96. partial void SocketConnect(string host, int port)
  97. {
  98. const int socketBufferSize = 2 * MaximumSshPacketSize;
  99. var ipAddress = host.GetIPAddress();
  100. var timeout = ConnectionInfo.Timeout;
  101. var ep = new IPEndPoint(ipAddress, port);
  102. _socket = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
  103. _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
  104. _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.SendBuffer, socketBufferSize);
  105. _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveBuffer, socketBufferSize);
  106. Log(string.Format("Initiating connect to '{0}:{1}'.", ConnectionInfo.Host, ConnectionInfo.Port));
  107. #if FEATURE_SOCKET_EAP
  108. if (!_socket.ConnectAsync(ep).Wait(timeout))
  109. throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
  110. "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
  111. #else
  112. var connectResult = _socket.BeginConnect(ep, null, null);
  113. if (!connectResult.AsyncWaitHandle.WaitOne(timeout, false))
  114. throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
  115. "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
  116. _socket.EndConnect(connectResult);
  117. #endif // FEATURE_SOCKET_ASYNC_TPL
  118. }
  119. /// <summary>
  120. /// Closes the socket and allows the socket to be reused after the current connection is closed.
  121. /// </summary>
  122. /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
  123. partial void SocketDisconnect()
  124. {
  125. _socket.Dispose();
  126. }
  127. /// <summary>
  128. /// Performs a blocking read on the socket until a line is read.
  129. /// </summary>
  130. /// <param name="response">The line read from the socket, or <c>null</c> when the remote server has shutdown and all data has been received.</param>
  131. /// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until a line is read.</param>
  132. /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
  133. /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
  134. partial void SocketReadLine(ref string response, TimeSpan timeout)
  135. {
  136. var buffer = new List<byte>();
  137. var data = new byte[1];
  138. // read data one byte at a time to find end of line and leave any unhandled information in the buffer
  139. // to be processed by subsequent invocations
  140. do
  141. {
  142. #if FEATURE_SOCKET_TAP
  143. var receiveTask = _socket.ReceiveAsync(new ArraySegment<byte>(data, 0, data.Length), SocketFlags.None);
  144. if (!receiveTask.Wait(timeout))
  145. throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
  146. "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
  147. var received = receiveTask.Result;
  148. #else
  149. var asyncResult = _socket.BeginReceive(data, 0, data.Length, SocketFlags.None, null, null);
  150. if (!asyncResult.AsyncWaitHandle.WaitOne(timeout))
  151. throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
  152. "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
  153. var received = _socket.EndReceive(asyncResult);
  154. #endif // FEATURE_SOCKET_TAP
  155. if (received == 0)
  156. // the remote server shut down the socket
  157. break;
  158. buffer.Add(data[0]);
  159. }
  160. while (!(buffer.Count > 0 && (buffer[buffer.Count - 1] == LineFeed || buffer[buffer.Count - 1] == Null)));
  161. if (buffer.Count == 0)
  162. response = null;
  163. else if (buffer.Count == 1 && buffer[buffer.Count - 1] == 0x00)
  164. // return an empty version string if the buffer consists of only a 0x00 character
  165. response = string.Empty;
  166. else if (buffer.Count > 1 && buffer[buffer.Count - 2] == CarriageReturn)
  167. // strip trailing CRLF
  168. response = SshData.Ascii.GetString(buffer.Take(buffer.Count - 2).ToArray());
  169. else if (buffer.Count > 1 && buffer[buffer.Count - 1] == LineFeed)
  170. // strip trailing LF
  171. response = SshData.Ascii.GetString(buffer.Take(buffer.Count - 1).ToArray());
  172. else
  173. response = SshData.Ascii.GetString(buffer.ToArray());
  174. }
  175. /// <summary>
  176. /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
  177. /// </summary>
  178. /// <param name="length">The number of bytes to read.</param>
  179. /// <param name="buffer">The buffer to read to.</param>
  180. /// <exception cref="SshConnectionException">The socket is closed.</exception>
  181. /// <exception cref="SocketException">The read failed.</exception>
  182. partial void SocketRead(int length, ref byte[] buffer)
  183. {
  184. var receivedTotal = 0; // how many bytes is already received
  185. do
  186. {
  187. try
  188. {
  189. var receivedBytes = _socket.Receive(buffer, receivedTotal, length - receivedTotal, SocketFlags.None);
  190. if (receivedBytes > 0)
  191. {
  192. // signal that bytes have been read from the socket
  193. // this is used to improve accuracy of Session.IsSocketConnected
  194. _bytesReadFromSocket.Set();
  195. receivedTotal += receivedBytes;
  196. continue;
  197. }
  198. // 2012-09-11: Kenneth_aa
  199. // When Disconnect or Dispose is called, this throws SshConnectionException(), which...
  200. // 1 - goes up to ReceiveMessage()
  201. // 2 - up again to MessageListener()
  202. // which is where there is a catch-all exception block so it can notify event listeners.
  203. // 3 - MessageListener then again calls RaiseError().
  204. // There the exception is checked for the exception thrown here (ConnectionLost), and if it matches it will not call Session.SendDisconnect().
  205. //
  206. // Adding a check for _isDisconnecting causes ReceiveMessage() to throw SshConnectionException: "Bad packet length {0}".
  207. //
  208. if (_isDisconnecting)
  209. throw new SshConnectionException("An established connection was aborted by the software in your host machine.", DisconnectReason.ConnectionLost);
  210. throw new SshConnectionException("An established connection was aborted by the server.", DisconnectReason.ConnectionLost);
  211. }
  212. catch (SocketException exp)
  213. {
  214. if (exp.SocketErrorCode == SocketError.WouldBlock ||
  215. exp.SocketErrorCode == SocketError.IOPending ||
  216. exp.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
  217. {
  218. // socket buffer is probably empty, wait and try again
  219. ThreadAbstraction.Sleep(30);
  220. }
  221. else
  222. {
  223. throw new SshConnectionException(exp.Message, DisconnectReason.ConnectionLost, exp);
  224. }
  225. }
  226. } while (receivedTotal < length);
  227. }
  228. /// <summary>
  229. /// Writes the specified data to the server.
  230. /// </summary>
  231. /// <param name="data">The data to write to the server.</param>
  232. /// <param name="offset">The zero-based offset in <paramref name="data"/> at which to begin taking data from.</param>
  233. /// <param name="length">The number of bytes of <paramref name="data"/> to write.</param>
  234. /// <exception cref="SshOperationTimeoutException">The write has timed-out.</exception>
  235. /// <exception cref="SocketException">The write failed.</exception>
  236. private void SocketWrite(byte[] data, int offset, int length)
  237. {
  238. var totalBytesSent = 0; // how many bytes are already sent
  239. var totalBytesToSend = length;
  240. do
  241. {
  242. try
  243. {
  244. totalBytesSent += _socket.Send(data, offset + totalBytesSent, totalBytesToSend - totalBytesSent,
  245. SocketFlags.None);
  246. }
  247. catch (SocketException ex)
  248. {
  249. if (ex.SocketErrorCode == SocketError.WouldBlock ||
  250. ex.SocketErrorCode == SocketError.IOPending ||
  251. ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
  252. {
  253. // socket buffer is probably full, wait and try again
  254. ThreadAbstraction.Sleep(30);
  255. }
  256. else
  257. throw; // any serious error occurr
  258. }
  259. } while (totalBytesSent < totalBytesToSend);
  260. }
  261. [Conditional("DEBUG")]
  262. partial void Log(string text)
  263. {
  264. #if FEATURE_DIAGNOSTICS_TRACESOURCE
  265. _log.TraceEvent(TraceEventType.Verbose, 1, text);
  266. #endif // FEATURE_DIAGNOSTICS_TRACESOURCE
  267. }
  268. #if ASYNC_SOCKET_READ
  269. private void SocketRead(int length, ref byte[] buffer)
  270. {
  271. var state = new SocketReadState(_socket, length, ref buffer);
  272. _socket.BeginReceive(buffer, 0, length, SocketFlags.None, SocketReceiveCallback, state);
  273. var readResult = state.Wait();
  274. switch (readResult)
  275. {
  276. case SocketReadResult.Complete:
  277. break;
  278. case SocketReadResult.ConnectionLost:
  279. if (_isDisconnecting)
  280. throw new SshConnectionException(
  281. "An established connection was aborted by the software in your host machine.",
  282. DisconnectReason.ConnectionLost);
  283. throw new SshConnectionException("An established connection was aborted by the server.",
  284. DisconnectReason.ConnectionLost);
  285. case SocketReadResult.Failed:
  286. var socketException = state.Exception as SocketException;
  287. if (socketException != null)
  288. {
  289. if (socketException.SocketErrorCode == SocketError.ConnectionAborted)
  290. {
  291. buffer = new byte[length];
  292. Disconnect();
  293. return;
  294. }
  295. }
  296. throw state.Exception;
  297. }
  298. }
  299. private void SocketReceiveCallback(IAsyncResult ar)
  300. {
  301. var state = ar.AsyncState as SocketReadState;
  302. var socket = state.Socket;
  303. try
  304. {
  305. var bytesReceived = socket.EndReceive(ar);
  306. if (bytesReceived > 0)
  307. {
  308. _bytesReadFromSocket.Set();
  309. state.BytesRead += bytesReceived;
  310. if (state.BytesRead < state.TotalBytesToRead)
  311. {
  312. socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
  313. SocketFlags.None, SocketReceiveCallback, state);
  314. }
  315. else
  316. {
  317. // we received all bytes that we wanted, so lets mark the read
  318. // complete
  319. state.Complete();
  320. }
  321. }
  322. else
  323. {
  324. // the remote host shut down the connection; this could also have been
  325. // triggered by a SSH_MSG_DISCONNECT sent by the client
  326. state.ConnectionLost();
  327. }
  328. }
  329. catch (SocketException ex)
  330. {
  331. if (ex.SocketErrorCode != SocketError.ConnectionAborted)
  332. {
  333. if (ex.SocketErrorCode == SocketError.WouldBlock ||
  334. ex.SocketErrorCode == SocketError.IOPending ||
  335. ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
  336. {
  337. // socket buffer is probably empty, wait and try again
  338. Thread.Sleep(30);
  339. socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
  340. SocketFlags.None, SocketReceiveCallback, state);
  341. return;
  342. }
  343. }
  344. state.Fail(ex);
  345. }
  346. catch (Exception ex)
  347. {
  348. state.Fail(ex);
  349. }
  350. }
  351. private class SocketReadState
  352. {
  353. private SocketReadResult _result;
  354. /// <summary>
  355. /// WaitHandle to signal that read from socket has completed (either successfully
  356. /// or with failure)
  357. /// </summary>
  358. private EventWaitHandle _socketReadComplete;
  359. public SocketReadState(Socket socket, int totalBytesToRead, ref byte[] buffer)
  360. {
  361. Socket = socket;
  362. TotalBytesToRead = totalBytesToRead;
  363. Buffer = buffer;
  364. _socketReadComplete = new ManualResetEvent(false);
  365. }
  366. /// <summary>
  367. /// Gets the <see cref="Socket"/> to read from.
  368. /// </summary>
  369. /// <value>
  370. /// The <see cref="Socket"/> to read from.
  371. /// </value>
  372. public Socket Socket { get; private set; }
  373. /// <summary>
  374. /// Gets or sets the number of bytes that have been read from the <see cref="Socket"/>.
  375. /// </summary>
  376. /// <value>
  377. /// The number of bytes that have been read from the <see cref="Socket"/>.
  378. /// </value>
  379. public int BytesRead { get; set; }
  380. /// <summary>
  381. /// Gets the total number of bytes to read from the <see cref="Socket"/>.
  382. /// </summary>
  383. /// <value>
  384. /// The total number of bytes to read from the <see cref="Socket"/>.
  385. /// </value>
  386. public int TotalBytesToRead { get; private set; }
  387. /// <summary>
  388. /// Gets or sets the buffer to hold the bytes that have been read.
  389. /// </summary>
  390. /// <value>
  391. /// The buffer to hold the bytes that have been read.
  392. /// </value>
  393. public byte[] Buffer { get; private set; }
  394. /// <summary>
  395. /// Gets or sets the exception that was thrown while reading from the
  396. /// <see cref="Socket"/>.
  397. /// </summary>
  398. /// <value>
  399. /// The exception that was thrown while reading from the <see cref="Socket"/>,
  400. /// or <c>null</c> if no exception was thrown.
  401. /// </value>
  402. public Exception Exception { get; private set; }
  403. /// <summary>
  404. /// Signals that the total number of bytes has been read successfully.
  405. /// </summary>
  406. public void Complete()
  407. {
  408. _result = SocketReadResult.Complete;
  409. _socketReadComplete.Set();
  410. }
  411. /// <summary>
  412. /// Signals that the socket read failed.
  413. /// </summary>
  414. /// <param name="cause">The <see cref="Exception"/> that caused the read to fail.</param>
  415. public void Fail(Exception cause)
  416. {
  417. Exception = cause;
  418. _result = SocketReadResult.Failed;
  419. _socketReadComplete.Set();
  420. }
  421. /// <summary>
  422. /// Signals that the connection to the server was lost.
  423. /// </summary>
  424. public void ConnectionLost()
  425. {
  426. _result = SocketReadResult.ConnectionLost;
  427. _socketReadComplete.Set();
  428. }
  429. public SocketReadResult Wait()
  430. {
  431. _socketReadComplete.WaitOne();
  432. _socketReadComplete.Dispose();
  433. _socketReadComplete = null;
  434. return _result;
  435. }
  436. }
  437. private enum SocketReadResult
  438. {
  439. Complete,
  440. ConnectionLost,
  441. Failed
  442. }
  443. #endif
  444. }
  445. }