Session.NET.cs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. using System.Globalization;
  2. using System.Linq;
  3. using System;
  4. using System.Net.Sockets;
  5. using System.Net;
  6. using Renci.SshNet.Common;
  7. using System.Threading;
  8. using Renci.SshNet.Messages.Transport;
  9. using System.Diagnostics;
  10. using System.Collections.Generic;
  11. namespace Renci.SshNet
  12. {
  13. public partial class Session
  14. {
  // Byte values used by SocketReadLine to recognise line endings in text
  // received from the server: NUL, CR and LF.
  private const byte Null = 0;
  private const byte CarriageReturn = 13;
  private const byte LineFeed = 10;

  // Trace source written to by Log(string). DEBUG builds enable every source
  // level; other builds use the trace source's default configuration.
  #if DEBUG
  private readonly TraceSource _log = new TraceSource("SshNet.Logging", SourceLevels.All);
  #else
  private readonly TraceSource _log = new TraceSource("SshNet.Logging");
  #endif

  /// <summary>
  /// Lock object used to synchronize socket read-related checks. In this file it guards
  /// the Poll/Available liveness probe in IsSocketConnected; other parts of the partial
  /// class may lock on it too (NOTE(review): confirm).
  /// </summary>
  private readonly object _socketReadLock = new object();
  /// <summary>
  /// Determines whether the underlying socket is still connected.
  /// </summary>
  /// <param name="isConnected">Set to <c>true</c> if the socket is connected; otherwise, <c>false</c>.</param>
  /// <remarks>
  /// <para>
  /// <see cref="Socket.Connected"/> only reflects the state of the socket as of the last
  /// I/O operation, so it serves only as a cheap first check. When it reports <c>true</c>,
  /// the socket is additionally probed with <see cref="Socket.Poll(int, SelectMode)"/> in
  /// <see cref="SelectMode.SelectRead"/> mode, combined with <see cref="Socket.Available"/>.
  /// </para>
  /// <para>
  /// According to the MSDN documentation, <see cref="Socket.Poll(int, SelectMode)"/> with
  /// <see cref="SelectMode.SelectRead"/> returns <c>true</c> when data is available for reading,
  /// or when the connection has been closed, reset, or terminated; otherwise it returns <c>false</c>.
  /// Consequently, a <c>true</c> result while no data is available means the connection is gone.
  /// </para>
  /// <para>
  /// When the <see cref="Socket"/> is used from multiple threads, another thread may consume the
  /// readable bytes between the Poll call and the read of <see cref="Socket.Available"/>. To
  /// compensate, readers set <c>_bytesReadFromSocket</c> whenever they receive bytes, and this
  /// method waits briefly for that signal before declaring the socket disconnected.
  /// </para>
  /// </remarks>
  partial void IsSocketConnected(ref bool isConnected)
  {
      isConnected = _socket != null && _socket.Connected;
      if (!isConnected)
          return;

      // Hold the lock so that no other thread can reset the "bytes read" signal between
      // our Reset() and the WaitOne() below.
      lock (_socketReadLock)
      {
          _bytesReadFromSocket.Reset();

          // Poll waits up to 1000 microseconds.
          var readableOrClosed = _socket.Poll(1000, SelectMode.SelectRead);
          var looksClosed = readableOrClosed && _socket.Available == 0;
          isConnected = !looksClosed;

          if (looksClosed)
          {
              // Another thread may have drained the readable bytes right after Poll returned.
              // Its "bytes read" signal can arrive slightly later, so give it up to 500 ms.
              isConnected = _bytesReadFromSocket.WaitOne(500);
          }
      }
  }
  /// <summary>
  /// Establishes a socket connection to the specified host and port.
  /// </summary>
  /// <param name="host">The host name of the server to connect to.</param>
  /// <param name="port">The port to connect to.</param>
  /// <remarks>
  /// The socket is created with TCP no-delay enabled and with send/receive buffers of
  /// twice <c>MaximumSshPacketSize</c>. The connection attempt is bounded by
  /// <see cref="Renci.SshNet.ConnectionInfo.Timeout"/>. If it times out, the socket is closed
  /// before the exception is thrown, so the pending connection attempt is not left running.
  /// </remarks>
  /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="Renci.SshNet.ConnectionInfo.Timeout"/>.</exception>
  /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
  partial void SocketConnect(string host, int port)
  {
      const int socketBufferSize = 2 * MaximumSshPacketSize;

      var ipAddress = host.GetIPAddress();
      var timeout = ConnectionInfo.Timeout;
      var ep = new IPEndPoint(ipAddress, port);

      _socket = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
      _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
      _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.SendBuffer, socketBufferSize);
      _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveBuffer, socketBufferSize);

      // Log the endpoint actually being connected to: the host/port arguments are not
      // necessarily ConnectionInfo.Host/ConnectionInfo.Port (callers may pass another endpoint).
      Log(string.Format(CultureInfo.InvariantCulture, "Initiating connect to '{0}:{1}'.", host, port));

      var connectResult = _socket.BeginConnect(ep, null, null);
      if (!connectResult.AsyncWaitHandle.WaitOne(timeout, false))
      {
          // The connect attempt is still pending. Close the socket so neither the attempt
          // nor the OS handle outlives the failure we are about to report.
          _socket.Close();
          throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
              "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
      }

      _socket.EndConnect(connectResult);
  }
  /// <summary>
  /// Disconnects the socket from the server. The socket is disconnected with reuse
  /// enabled, so it may be reused after the current connection is closed.
  /// </summary>
  /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
  partial void SocketDisconnect()
  {
      const bool allowReuse = true;
      _socket.Disconnect(allowReuse);
  }
  /// <summary>
  /// Performs a blocking read on the socket until a line is read.
  /// </summary>
  /// <param name="response">
  /// Receives the line read from the socket, without its trailing line terminator, or
  /// <c>null</c> when the remote server shut down the connection before any byte was received.
  /// </param>
  /// <param name="timeout">
  /// The maximum time to wait for each individual byte. The limit applies to every
  /// one-byte receive, not to the line as a whole.
  /// </param>
  /// <remarks>
  /// <para>
  /// Data is read one byte at a time, so that bytes following the line stay in the socket's
  /// receive buffer for subsequent reads.
  /// </para>
  /// <para>
  /// A line ends at LF (0x0a) or NUL (0x00). How the terminator is handled:
  /// <list type="bullet">
  /// <item><description>A trailing CR LF or CR NUL pair is removed.</description></item>
  /// <item><description>A trailing lone LF is removed.</description></item>
  /// <item><description>A line made of a single NUL yields an empty string.</description></item>
  /// <item><description>Any other trailing NUL is kept.</description></item>
  /// </list>
  /// If the server closes the connection mid-line, the bytes received so far are returned unchanged.
  /// </para>
  /// </remarks>
  /// <exception cref="SshOperationTimeoutException">No byte was received within <paramref name="timeout"/>.</exception>
  /// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
  partial void SocketReadLine(ref string response, TimeSpan timeout)
  {
      var encoding = new ASCIIEncoding();
      var buffer = new List<byte>();
      var data = new byte[1];

      while (true)
      {
          var asyncResult = _socket.BeginReceive(data, 0, data.Length, SocketFlags.None, null, null);
          if (!asyncResult.AsyncWaitHandle.WaitOne(timeout))
          {
              // NOTE(review): the receive started above stays pending after this throw; if the
              // socket keeps being used, that pending receive may consume one byte.
              throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                  "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
          }

          var received = _socket.EndReceive(asyncResult);
          if (received == 0)
          {
              // the remote server shut down the socket
              break;
          }

          var current = data[0];
          buffer.Add(current);
          if (current == LineFeed || current == Null)
              break;
      }

      var count = buffer.Count;
      if (count == 0)
      {
          response = null;
          return;
      }

      var bytes = buffer.ToArray();
      var last = bytes[count - 1];
      var endsWithTerminator = last == LineFeed || last == Null;

      if (count == 1 && last == Null)
      {
          // a buffer consisting of only a NUL character yields an empty version string
          response = string.Empty;
      }
      else if (count > 1 && endsWithTerminator && bytes[count - 2] == CarriageReturn)
      {
          // strip trailing CR LF (or CR NUL); a CR is only stripped when it precedes the
          // terminator, so a line cut short by a closed connection is not truncated
          response = encoding.GetString(bytes, 0, count - 2);
      }
      else if (last == LineFeed)
      {
          // strip trailing LF (this also turns a bare "\n" line into an empty string)
          response = encoding.GetString(bytes, 0, count - 1);
      }
      else
      {
          response = encoding.GetString(bytes, 0, count);
      }
  }
  /// <summary>
  /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
  /// </summary>
  /// <param name="length">The number of bytes to read. When zero, the method returns without touching the socket.</param>
  /// <param name="buffer">The buffer to read to. Bytes are stored starting at index 0.</param>
  /// <remarks>
  /// Transient socket errors (<see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/>,
  /// <see cref="SocketError.NoBufferSpaceAvailable"/>) are retried after a 30 ms pause.
  /// Every successful receive sets <c>_bytesReadFromSocket</c>, which IsSocketConnected
  /// relies on.
  /// </remarks>
  /// <exception cref="ArgumentOutOfRangeException"><paramref name="length"/> is negative.</exception>
  /// <exception cref="SshConnectionException">
  /// The connection was shut down (by the server, or locally while disconnecting), or the read
  /// failed with a non-transient <see cref="SocketException"/>, which is attached as the inner exception.
  /// </exception>
  partial void SocketRead(int length, ref byte[] buffer)
  {
      if (length < 0)
          throw new ArgumentOutOfRangeException("length");

      var receivedTotal = 0; // how many bytes have been received so far

      // Test the condition before the first Receive. With a do/while loop, a zero-length
      // request would call Receive for 0 bytes, get 0 back, and be reported as the server
      // having aborted the connection.
      while (receivedTotal < length)
      {
          try
          {
              var receivedBytes = _socket.Receive(buffer, receivedTotal, length - receivedTotal, SocketFlags.None);
              if (receivedBytes > 0)
              {
                  // signal that bytes have been read from the socket;
                  // this is used to improve accuracy of Session.IsSocketConnected
                  _bytesReadFromSocket.Set();
                  receivedTotal += receivedBytes;
                  continue;
              }

              // Receive returned 0 although bytes were requested, so the connection was shut down.
              // Per the original author's note (2012-09-11, Kenneth_aa): this exception travels up
              // through ReceiveMessage() to MessageListener(), whose catch-all block notifies event
              // listeners via RaiseError(). RaiseError recognises ConnectionLost and then does not
              // call SendDisconnect(). Checking _isDisconnecting earlier, in ReceiveMessage(), made
              // ReceiveMessage() throw "Bad packet length {0}" instead.
              if (_isDisconnecting)
                  throw new SshConnectionException("An established connection was aborted by the software in your host machine.", DisconnectReason.ConnectionLost);
              throw new SshConnectionException("An established connection was aborted by the server.", DisconnectReason.ConnectionLost);
          }
          catch (SocketException exp)
          {
              if (exp.SocketErrorCode == SocketError.WouldBlock ||
                  exp.SocketErrorCode == SocketError.IOPending ||
                  exp.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
              {
                  // socket buffer is probably empty, wait and try again
                  Thread.Sleep(30);
              }
              else
              {
                  throw new SshConnectionException(exp.Message, DisconnectReason.ConnectionLost, exp);
              }
          }
      }
  }
  /// <summary>
  /// Writes <paramref name="length"/> bytes of <paramref name="data"/>, starting at
  /// <paramref name="offset"/>, to the server. Sends are repeated until every byte has been written.
  /// </summary>
  /// <param name="data">The data to write to the server.</param>
  /// <param name="offset">The zero-based offset in <paramref name="data"/> at which to begin taking data from.</param>
  /// <param name="length">The number of bytes of <paramref name="data"/> to write.</param>
  /// <remarks>
  /// Transient send errors (<see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/>,
  /// <see cref="SocketError.NoBufferSpaceAvailable"/>) are retried after a 30 ms pause.
  /// At least one Send call is issued, even when <paramref name="length"/> is zero.
  /// </remarks>
  /// <exception cref="SocketException">The write failed with a non-transient socket error.</exception>
  private void SocketWrite(byte[] data, int offset, int length)
  {
      var sent = 0;
      for (;;)
      {
          try
          {
              sent += _socket.Send(data, offset + sent, length - sent, SocketFlags.None);
          }
          catch (SocketException ex)
          {
              var transient = ex.SocketErrorCode == SocketError.WouldBlock
                              || ex.SocketErrorCode == SocketError.IOPending
                              || ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable;
              if (!transient)
                  throw; // a serious error occurred

              // the socket buffer is probably full; back off briefly and retry
              Thread.Sleep(30);
          }

          if (sent >= length)
              break;
      }
  }
  /// <summary>
  /// Writes <paramref name="text"/> as a verbose event to the session's trace source.
  /// Because of <see cref="ConditionalAttribute"/>, the compiler omits calls to this method
  /// unless DEBUG is defined at the call site.
  /// </summary>
  /// <param name="text">The message to trace. It is written as-is, not treated as a format string.</param>
  [Conditional("DEBUG")]
  partial void Log(string text)
  {
      const int eventId = 1;
      _log.TraceEvent(TraceEventType.Verbose, eventId, text);
  }
  251. #if ASYNC_SOCKET_READ
  /// <summary>
  /// Callback-based variant of the blocking socket read. It reads exactly
  /// <paramref name="length"/> bytes into <paramref name="buffer"/> via BeginReceive,
  /// then waits until the read completes, the connection is lost, or the read fails.
  /// </summary>
  /// <param name="length">The number of bytes to read. When zero, the method returns immediately.</param>
  /// <param name="buffer">
  /// The buffer to read to. If the read fails with <see cref="SocketError.ConnectionAborted"/>,
  /// it is replaced by a new zero-filled array of <paramref name="length"/> bytes, and the session
  /// is disconnected.
  /// </param>
  /// <remarks>
  /// NOTE(review): this method is only compiled when ASYNC_SOCKET_READ is defined. It has the same
  /// signature as the partial SocketRead implemented elsewhere in this file, so defining the symbol
  /// as-is looks like it would produce a duplicate-member error. Confirm before enabling.
  /// </remarks>
  /// <exception cref="SshConnectionException">The remote host (or this host, while disconnecting) shut down the connection.</exception>
  private void SocketRead(int length, ref byte[] buffer)
  {
      // Nothing to read. Without this guard, BeginReceive would complete with 0 bytes,
      // which SocketReceiveCallback reports as ConnectionLost.
      if (length == 0)
          return;

      var state = new SocketReadState(_socket, length, ref buffer);
      _socket.BeginReceive(buffer, 0, length, SocketFlags.None, SocketReceiveCallback, state);

      var readResult = state.Wait();
      switch (readResult)
      {
          case SocketReadResult.Complete:
              break;
          case SocketReadResult.ConnectionLost:
              if (_isDisconnecting)
                  throw new SshConnectionException(
                      "An established connection was aborted by the software in your host machine.",
                      DisconnectReason.ConnectionLost);
              throw new SshConnectionException("An established connection was aborted by the server.",
                  DisconnectReason.ConnectionLost);
          case SocketReadResult.Failed:
              var socketException = state.Exception as SocketException;
              if (socketException != null && socketException.SocketErrorCode == SocketError.ConnectionAborted)
              {
                  // hand back a zero-filled buffer and disconnect the session
                  buffer = new byte[length];
                  Disconnect();
                  return;
              }
              // NOTE(review): rethrowing the captured exception object resets its stack trace.
              throw state.Exception;
      }
  }
  /// <summary>
  /// Completion callback for the BeginReceive calls issued by the callback-based SocketRead.
  /// It accumulates received bytes and re-issues BeginReceive until all requested bytes have
  /// arrived. The outcome is reported through the <see cref="SocketReadState"/> passed as
  /// async state.
  /// </summary>
  /// <param name="ar">The completed receive. Its AsyncState is expected to be a <see cref="SocketReadState"/>.</param>
  /// <remarks>
  /// Every path ends in <see cref="SocketReadState.Complete"/>, <see cref="SocketReadState.ConnectionLost"/>,
  /// <see cref="SocketReadState.Fail"/>, or a further BeginReceive. This matters because the reading
  /// thread waits in <see cref="SocketReadState.Wait"/> without a timeout.
  /// </remarks>
  private void SocketReceiveCallback(IAsyncResult ar)
  {
      var state = (SocketReadState) ar.AsyncState;
      var socket = state.Socket;

      try
      {
          var bytesReceived = socket.EndReceive(ar);
          if (bytesReceived > 0)
          {
              _bytesReadFromSocket.Set();
              state.BytesRead += bytesReceived;
              if (state.BytesRead < state.TotalBytesToRead)
              {
                  socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
                      SocketFlags.None, SocketReceiveCallback, state);
              }
              else
              {
                  // we received all bytes that we wanted, so mark the read complete
                  state.Complete();
              }
          }
          else
          {
              // the remote host shut down the connection; this could also have been
              // triggered by a SSH_MSG_DISCONNECT sent by the client
              state.ConnectionLost();
          }
      }
      catch (SocketException ex)
      {
          if (ex.SocketErrorCode == SocketError.WouldBlock ||
              ex.SocketErrorCode == SocketError.IOPending ||
              ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable)
          {
              // socket buffer is probably empty, wait and try again
              Thread.Sleep(30);
              try
              {
                  socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
                      SocketFlags.None, SocketReceiveCallback, state);
              }
              catch (Exception retryException)
              {
                  // This runs on an I/O callback thread. An exception escaping here would be
                  // unhandled, and the reader blocked in SocketReadState.Wait would never be released.
                  state.Fail(retryException);
              }
              return;
          }

          state.Fail(ex);
      }
      catch (Exception ex)
      {
          state.Fail(ex);
      }
  }
  /// <summary>
  /// Tracks the progress of a callback-based socket read. The reading thread can block on it
  /// until the read finishes, whether successfully, by connection loss, or by failure.
  /// </summary>
  private class SocketReadState
  {
      // Outcome of the read; written before _done is signalled.
      private SocketReadResult _result;

      // Signalled once the read has finished (successfully or with failure);
      // disposed and cleared by Wait().
      private EventWaitHandle _done;

      public SocketReadState(Socket socket, int totalBytesToRead, ref byte[] buffer)
      {
          _done = new ManualResetEvent(false);
          Socket = socket;
          TotalBytesToRead = totalBytesToRead;
          Buffer = buffer;
      }

      /// <summary>
      /// Gets the <see cref="Socket"/> being read from.
      /// </summary>
      public Socket Socket { get; private set; }

      /// <summary>
      /// Gets or sets how many bytes have been read from the <see cref="Socket"/> so far.
      /// </summary>
      public int BytesRead { get; set; }

      /// <summary>
      /// Gets the total number of bytes to read from the <see cref="Socket"/>.
      /// </summary>
      public int TotalBytesToRead { get; private set; }

      /// <summary>
      /// Gets the buffer that receives the bytes read.
      /// </summary>
      public byte[] Buffer { get; private set; }

      /// <summary>
      /// Gets the exception that made the read fail, or <c>null</c> if <see cref="Fail"/> was not called.
      /// </summary>
      public Exception Exception { get; private set; }

      /// <summary>
      /// Signals that the total number of bytes has been read successfully.
      /// </summary>
      public void Complete()
      {
          Finish(SocketReadResult.Complete);
      }

      /// <summary>
      /// Signals that the socket read failed.
      /// </summary>
      /// <param name="cause">The <see cref="Exception"/> that caused the read to fail.</param>
      public void Fail(Exception cause)
      {
          Exception = cause;
          Finish(SocketReadResult.Failed);
      }

      /// <summary>
      /// Signals that the connection to the server was lost.
      /// </summary>
      public void ConnectionLost()
      {
          Finish(SocketReadResult.ConnectionLost);
      }

      /// <summary>
      /// Blocks without a timeout until the read has finished, then disposes the wait handle.
      /// This can be called only once.
      /// </summary>
      /// <returns>The outcome of the read.</returns>
      public SocketReadResult Wait()
      {
          var handle = _done;
          handle.WaitOne();
          handle.Dispose();
          _done = null;
          return _result;
      }

      // Records the outcome and releases the thread waiting in Wait().
      private void Finish(SocketReadResult outcome)
      {
          _result = outcome;
          _done.Set();
      }
  }
  /// <summary>
  /// Outcome of a callback-based socket read, as returned by SocketReadState.Wait().
  /// </summary>
  private enum SocketReadResult
  {
      /// <summary>All requested bytes were received.</summary>
      Complete = 0,

      /// <summary>A receive returned zero bytes: the remote side shut down the connection.</summary>
      ConnectionLost = 1,

      /// <summary>The read failed with an exception, available through SocketReadState.Exception.</summary>
      Failed = 2
  }
  426. #endif
  427. }
  428. }