ForwardedPortLocal.NET.cs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. using System;
  2. using System.Net.Sockets;
  3. using System.Net;
  4. using System.Threading;
  5. using Renci.SshNet.Abstractions;
  6. using Renci.SshNet.Common;
  namespace Renci.SshNet
{
    public partial class ForwardedPortLocal
    {
        // Socket listening for local client connections; assigned in InternalStart, closed in StopListener
        // and released (and cleared) in InternalDispose.
        private Socket _listener;

        // Counts in-flight channels. It is created with an initial count of 1 (see InitializePendingChannelCountdown);
        // InternalStop signals that initial count and then waits for the counts added by ProcessAccept to drain.
        private CountdownEvent _pendingChannelCountdown;

        /// <summary>
        /// Starts listening on <see cref="BoundHost"/>/<see cref="BoundPort"/> and begins accepting connections.
        /// </summary>
        /// <remarks>
        /// Only the first address returned for <see cref="BoundHost"/> is used. After binding, <see cref="BoundPort"/>
        /// is updated with the port the listener actually bound to, which matters when zero was requested.
        /// </remarks>
        partial void InternalStart()
        {
            var addr = DnsAbstraction.GetHostAddresses(BoundHost)[0];
            var ep = new IPEndPoint(addr, (int) BoundPort);

            _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
            _listener.Bind(ep);
            _listener.Listen(5);

            // update bound port (in case original was passed as zero)
            BoundPort = (uint) ((IPEndPoint) _listener.LocalEndPoint).Port;

            Session.ErrorOccured += Session_ErrorOccured;
            Session.Disconnected += Session_Disconnected;

            InitializePendingChannelCountdown();

            // consider port started when we're listening for inbound connections
            _status = ForwardedPortStatus.Started;

            StartAccept(null);
        }

        /// <summary>
        /// Issues an asynchronous accept on the listener socket.
        /// </summary>
        /// <param name="e">
        /// The accept context to reuse, or <c>null</c> to create a new one that is wired to <see cref="AcceptCompleted"/>.
        /// </param>
        /// <remarks>
        /// When <see cref="Socket.AcceptAsync(SocketAsyncEventArgs)"/> completes synchronously, <see cref="AcceptCompleted"/>
        /// is invoked on the calling thread. Once the port is no longer accepting connections, the context is disposed.
        /// </remarks>
        private void StartAccept(SocketAsyncEventArgs e)
        {
            if (e == null)
            {
                e = new SocketAsyncEventArgs();
                e.Completed += AcceptCompleted;
            }
            else
            {
                // clear the socket as we're reusing the context object
                e.AcceptSocket = null;
            }

            // only accept new connections while we are started; otherwise the context is no longer needed
            if (!IsStarted)
            {
                e.Dispose();
                return;
            }

            try
            {
                if (!_listener.AcceptAsync(e))
                {
                    AcceptCompleted(null, e);
                }
            }
            catch (ObjectDisposedException)
            {
                if (_status == ForwardedPortStatus.Stopping || _status == ForwardedPortStatus.Stopped)
                {
                    // ignore ObjectDisposedException while stopping or stopped: the listener was closed underneath us
                    e.Dispose();
                    return;
                }

                throw;
            }
        }

        /// <summary>
        /// Handles completion of an accept operation, either from the <see cref="SocketAsyncEventArgs.Completed"/>
        /// event or synchronously from <see cref="StartAccept"/>.
        /// </summary>
        /// <param name="sender">The event source; <c>null</c> when invoked synchronously.</param>
        /// <param name="e">The accept context holding the outcome and the accepted socket.</param>
        private void AcceptCompleted(object sender, SocketAsyncEventArgs e)
        {
            if (e.SocketError == SocketError.OperationAborted || e.SocketError == SocketError.NotSocket)
            {
                // server was stopped; the accept context will not be reused
                e.Dispose();
                return;
            }

            // capture the outcome before StartAccept clears (and possibly overwrites) the shared context
            var clientSocket = e.AcceptSocket;
            var acceptSucceeded = e.SocketError == SocketError.Success;

            // accept new connection
            StartAccept(e);

            if (acceptSucceeded)
            {
                // process connection
                ProcessAccept(clientSocket);
            }
            else
            {
                // dispose broken client socket
                CloseClientSocket(clientSocket);
            }
        }

        /// <summary>
        /// Forwards an accepted client connection over a new direct-tcpip channel to <see cref="Host"/>/<see cref="Port"/>.
        /// </summary>
        /// <param name="clientSocket">The accepted client socket.</param>
        /// <remarks>
        /// The call returns after <c>channel.Bind()</c> and <c>channel.Close()</c> have run. Exceptions are reported
        /// through <see cref="RaiseExceptionEvent"/>, and the client socket is then closed.
        /// </remarks>
        private void ProcessAccept(Socket clientSocket)
        {
            // close the client socket if we're no longer accepting new connections
            if (!IsStarted)
            {
                CloseClientSocket(clientSocket);
                return;
            }

            // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling
            // that same instance; the instance field for the countdown event is re-initialized when the port is restarted
            // and at that time there may still be pending requests
            var pendingChannelCountdown = _pendingChannelCountdown;

            // the port may have been stopped (count already at zero) or disposed since the IsStarted check above
            if (!TryAddPendingChannel(pendingChannelCountdown))
            {
                CloseClientSocket(clientSocket);
                return;
            }

            try
            {
                var originatorEndPoint = (IPEndPoint) clientSocket.RemoteEndPoint;

                RaiseRequestReceived(originatorEndPoint.Address.ToString(),
                    (uint) originatorEndPoint.Port);

                using (var channel = Session.CreateChannelDirectTcpip())
                {
                    channel.Exception += Channel_Exception;
                    channel.Open(Host, Port, this, clientSocket);
                    channel.Bind();
                    channel.Close();
                }
            }
            catch (Exception exp)
            {
                RaiseExceptionEvent(exp);
                CloseClientSocket(clientSocket);
            }
            finally
            {
                // take into account that CountdownEvent has since been disposed; when stopping the port we
                // wait for a given time for the channels to close, but once that timeout period has elapsed
                // the CountdownEvent will be disposed
                try
                {
                    pendingChannelCountdown.Signal();
                }
                catch (ObjectDisposedException)
                {
                }
            }
        }

        /// <summary>
        /// Attempts to register one more pending channel with <paramref name="countdown"/>.
        /// </summary>
        /// <param name="countdown">The countdown event captured by the caller; may be <c>null</c> after disposal.</param>
        /// <returns>
        /// <c>true</c> if a count was added; <c>false</c> if <paramref name="countdown"/> is <c>null</c> or refused the count.
        /// </returns>
        private static bool TryAddPendingChannel(CountdownEvent countdown)
        {
            if (countdown == null)
            {
                return false;
            }

            try
            {
                countdown.AddCount();
                return true;
            }
            catch (InvalidOperationException)
            {
                // AddCount throws InvalidOperationException when the event is already set (count reached zero, i.e.
                // the port is stopping) or the count would overflow; ObjectDisposedException derives from it and
                // is thrown once the event has been disposed
                return false;
            }
        }

        /// <summary>
        /// Initializes the <see cref="CountdownEvent"/>.
        /// </summary>
        /// <remarks>
        /// <para>
        /// When the port is started for the first time, a <see cref="CountdownEvent"/> is created with an initial count
        /// of <c>1</c>.
        /// </para>
        /// <para>
        /// On subsequent (re)starts, we'll dispose the current <see cref="CountdownEvent"/> and create a new one with
        /// initial count of <c>1</c>.
        /// </para>
        /// </remarks>
        private void InitializePendingChannelCountdown()
        {
            var original = Interlocked.Exchange(ref _pendingChannelCountdown, new CountdownEvent(1));
            if (original != null)
            {
                original.Dispose();
            }
        }

        /// <summary>
        /// Shuts down the send side of a client socket (best effort) and disposes it.
        /// </summary>
        /// <param name="clientSocket">The socket to close; <c>null</c> is ignored.</param>
        private static void CloseClientSocket(Socket clientSocket)
        {
            // a failed accept may not have produced a socket
            if (clientSocket == null)
            {
                return;
            }

            if (clientSocket.Connected)
            {
                try
                {
                    clientSocket.Shutdown(SocketShutdown.Send);
                }
                catch (Exception)
                {
                    // ignore exception when client socket was already closed
                }
            }

            clientSocket.Dispose();
        }

        /// <summary>
        /// Interrupts the listener, and unsubscribes from <see cref="Session"/> events.
        /// </summary>
        partial void StopListener()
        {
            // close listener socket; a pending accept then completes with an error or the next AcceptAsync
            // raises ObjectDisposedException, both of which end the accept loop
            var listener = _listener;
            if (listener != null)
            {
                listener.Dispose();
            }

            // unsubscribe from session events
            var session = Session;
            if (session != null)
            {
                session.ErrorOccured -= Session_ErrorOccured;
                session.Disconnected -= Session_Disconnected;
            }
        }

        /// <summary>
        /// Waits for pending channels to close.
        /// </summary>
        /// <param name="timeout">The maximum time to wait for the pending channels to close.</param>
        partial void InternalStop(TimeSpan timeout)
        {
            var pendingChannelCountdown = _pendingChannelCountdown;

            // release the initial count of 1 so the event becomes set once every in-flight channel has signaled
            pendingChannelCountdown.Signal();
            pendingChannelCountdown.Wait(timeout);
        }

        /// <summary>
        /// Releases the listener socket and the pending-channel countdown.
        /// </summary>
        /// <param name="disposing"><c>true</c> to release these resources; <c>false</c> does nothing here.</param>
        partial void InternalDispose(bool disposing)
        {
            if (disposing)
            {
                var listener = _listener;
                if (listener != null)
                {
                    _listener = null;
                    listener.Dispose();
                }

                var pendingRequestsCountdown = _pendingChannelCountdown;
                if (pendingRequestsCountdown != null)
                {
                    _pendingChannelCountdown = null;
                    pendingRequestsCountdown.Dispose();
                }
            }
        }

        /// <summary>
        /// Stops the port when the session disconnects, passing the session's connection timeout to <c>StopPort</c>.
        /// </summary>
        private void Session_Disconnected(object sender, EventArgs e)
        {
            var session = Session;
            if (session != null)
            {
                StopPort(session.ConnectionInfo.Timeout);
            }
        }

        /// <summary>
        /// Stops the port when the session reports an error, passing the session's connection timeout to <c>StopPort</c>.
        /// </summary>
        private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
        {
            var session = Session;
            if (session != null)
            {
                StopPort(session.ConnectionInfo.Timeout);
            }
        }

        /// <summary>
        /// Re-raises an exception reported by a forwarding channel as an exception event of this port.
        /// </summary>
        private void Channel_Exception(object sender, ExceptionEventArgs e)
        {
            RaiseExceptionEvent(e.Exception);
        }
    }
}