2
0

ForwardedPortLocal.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. using System;
  2. using System.Net;
  3. using System.Net.Sockets;
  4. using System.Threading;
  5. using Microsoft.Extensions.Logging;
  6. using Renci.SshNet.Common;
  namespace Renci.SshNet
  {
      /// <summary>
      /// Provides functionality for local port forwarding.
      /// </summary>
      public partial class ForwardedPortLocal : ForwardedPort
      {
          private readonly ILogger _logger;
          private ForwardedPortStatus _status;
          private bool _isDisposed;
          private Socket _listener;
          private CountdownEvent _pendingChannelCountdown;

          /// <summary>
          /// Gets the bound host.
          /// </summary>
          public string BoundHost { get; private set; }

          /// <summary>
          /// Gets the bound port.
          /// </summary>
          /// <remarks>
          /// When forwarding starts, this is updated to the port the listener socket actually bound to,
          /// which matters when the instance was created with a bound port of zero.
          /// </remarks>
          public uint BoundPort { get; private set; }

          /// <summary>
          /// Gets the forwarded host.
          /// </summary>
          public string Host { get; private set; }

          /// <summary>
          /// Gets the forwarded port.
          /// </summary>
          public uint Port { get; private set; }

          /// <summary>
          /// Gets a value indicating whether port forwarding is started.
          /// </summary>
          /// <value>
          /// <see langword="true"/> if port forwarding is started; otherwise, <see langword="false"/>.
          /// </value>
          public override bool IsStarted
          {
              get { return _status == ForwardedPortStatus.Started; }
          }

          /// <summary>
          /// Initializes a new instance of the <see cref="ForwardedPortLocal"/> class.
          /// </summary>
          /// <param name="boundPort">The bound port.</param>
          /// <param name="host">The host.</param>
          /// <param name="port">The port.</param>
          /// <remarks>
          /// <see cref="BoundHost"/> is set to <see cref="string.Empty"/>.
          /// </remarks>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
          /// <exception cref="ArgumentNullException"><paramref name="host"/> is <see langword="null"/>.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
          public ForwardedPortLocal(uint boundPort, string host, uint port)
              : this(string.Empty, boundPort, host, port)
          {
          }

          /// <summary>
          /// Initializes a new instance of the <see cref="ForwardedPortLocal"/> class.
          /// </summary>
          /// <param name="boundHost">The bound host.</param>
          /// <param name="host">The host.</param>
          /// <param name="port">The port.</param>
          /// <remarks>
          /// The bound port is initially zero; <see cref="BoundPort"/> is updated to the port actually
          /// bound once forwarding has started.
          /// </remarks>
          /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> is <see langword="null"/>.</exception>
          /// <exception cref="ArgumentNullException"><paramref name="host"/> is <see langword="null"/>.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
          public ForwardedPortLocal(string boundHost, string host, uint port)
              : this(boundHost, 0, host, port)
          {
          }

          /// <summary>
          /// Initializes a new instance of the <see cref="ForwardedPortLocal"/> class.
          /// </summary>
          /// <param name="boundHost">The bound host.</param>
          /// <param name="boundPort">The bound port.</param>
          /// <param name="host">The host.</param>
          /// <param name="port">The port.</param>
          /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> is <see langword="null"/>.</exception>
          /// <exception cref="ArgumentNullException"><paramref name="host"/> is <see langword="null"/>.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="port" /> is greater than <see cref="IPEndPoint.MaxPort" />.</exception>
          public ForwardedPortLocal(string boundHost, uint boundPort, string host, uint port)
          {
              ThrowHelper.ThrowIfNull(boundHost);
              ThrowHelper.ThrowIfNull(host);

              boundPort.ValidatePort();
              port.ValidatePort();

              BoundHost = boundHost;
              BoundPort = boundPort;
              Host = host;
              Port = port;
              _status = ForwardedPortStatus.Stopped;
              _logger = SshNetLoggingConfiguration.LoggerFactory.CreateLogger<ForwardedPortLocal>();
          }

          /// <summary>
          /// Starts local port forwarding.
          /// </summary>
          /// <remarks>
          /// If starting fails, the port is marked stopped again. Any listener socket and session event
          /// subscriptions acquired before the failure are released, and the exception is rethrown.
          /// </remarks>
          protected override void StartPort()
          {
              if (!ForwardedPortStatus.ToStarting(ref _status))
              {
                  return;
              }

              try
              {
                  InternalStart();
              }
              catch (Exception)
              {
                  // mark the port stopped first so that no accept is (re)issued on the listener
                  // we're about to close
                  _status = ForwardedPortStatus.Stopped;

                  // don't keep the local endpoint bound (or the session events subscribed) after a failed start;
                  // unsubscribing handlers that were never subscribed is a no-op
                  StopListener();
                  throw;
              }
          }

          /// <summary>
          /// Stops local port forwarding, and waits for the specified timeout until all pending
          /// requests are processed.
          /// </summary>
          /// <param name="timeout">The maximum amount of time to wait for pending requests to finish processing.</param>
          protected override void StopPort(TimeSpan timeout)
          {
              timeout.EnsureValidTimeout();

              if (!ForwardedPortStatus.ToStopping(ref _status))
              {
                  return;
              }

              // signal existing channels that the port is closing
              base.StopPort(timeout);

              // prevent new requests from getting processed
              StopListener();

              // wait for open channels to close
              InternalStop(timeout);

              // mark port stopped
              _status = ForwardedPortStatus.Stopped;
          }

          /// <summary>
          /// Ensures the current instance is not disposed.
          /// </summary>
          /// <exception cref="ObjectDisposedException">The current instance is disposed.</exception>
          protected override void CheckDisposed()
          {
              ThrowHelper.ThrowObjectDisposedIf(_isDisposed, this);
          }

          /// <summary>
          /// Releases unmanaged and - optionally - managed resources.
          /// </summary>
          /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
          protected override void Dispose(bool disposing)
          {
              if (_isDisposed)
              {
                  return;
              }

              base.Dispose(disposing);

              InternalDispose(disposing);
              _isDisposed = true;
          }

          /// <summary>
          /// Finalizes an instance of the <see cref="ForwardedPortLocal"/> class.
          /// </summary>
          ~ForwardedPortLocal()
          {
              Dispose(disposing: false);
          }

          /// <summary>
          /// Creates and binds the listener socket, subscribes to session events, and starts accepting
          /// inbound connections.
          /// </summary>
          private void InternalStart()
          {
              // uses the first address the bound host resolves to
              var addr = Dns.GetHostAddresses(BoundHost)[0];
              var ep = new IPEndPoint(addr, (int)BoundPort);

              // assign the field before binding so that StartPort can release the socket if Bind/Listen fail
              _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
              _listener.Bind(ep);
              _listener.Listen(5);

              // update bound port (in case original was passed as zero)
              BoundPort = (uint)((IPEndPoint)_listener.LocalEndPoint).Port;

              Session.ErrorOccured += Session_ErrorOccured;
              Session.Disconnected += Session_Disconnected;

              InitializePendingChannelCountdown();

              // consider port started when we're listening for inbound connections
              _status = ForwardedPortStatus.Started;

              StartAccept(e: null);
          }

          /// <summary>
          /// Issues an asynchronous accept on the listener socket while the port is started.
          /// </summary>
          /// <param name="e">
          /// The context object to reuse for the accept, or <see langword="null"/> to create a new one.
          /// When the accept loop ends (port no longer started, or listener closed), this instance is disposed.
          /// </param>
          private void StartAccept(SocketAsyncEventArgs e)
          {
              if (e is null)
              {
  #pragma warning disable CA2000 // Dispose objects before losing scope
                  e = new SocketAsyncEventArgs();
  #pragma warning restore CA2000 // Dispose objects before losing scope
                  e.Completed += AcceptCompleted;
              }
              else
              {
                  // clear the socket as we're reusing the context object
                  e.AcceptSocket = null;
              }

              var listener = _listener;

              // only accept new connections while we are started
              if (!IsStarted || listener is null)
              {
                  // the accept loop ends here, so release the context object
                  e.Dispose();
                  return;
              }

              bool pending;
              try
              {
                  pending = listener.AcceptAsync(e);
              }
              catch (ObjectDisposedException)
              {
                  // the listener was closed before the accept could be issued; the context object is
                  // not in use by any operation, so it can be released
                  e.Dispose();

                  if (_status == ForwardedPortStatus.Stopping || _status == ForwardedPortStatus.Stopped)
                  {
                      // ignore ObjectDisposedException while stopping or stopped
                      return;
                  }

                  throw;
              }

              if (!pending)
              {
                  // the accept completed synchronously, in which case the Completed event is not raised
                  AcceptCompleted(sender: null, e);
              }
          }

          /// <summary>
          /// Handles completion of an accept operation: re-arms the accept loop and processes the accepted
          /// client socket.
          /// </summary>
          /// <param name="sender">The source of the event (<see langword="null"/> on synchronous completion).</param>
          /// <param name="e">The context object of the completed accept.</param>
          private void AcceptCompleted(object sender, SocketAsyncEventArgs e)
          {
              if (e.SocketError is SocketError.OperationAborted or SocketError.NotSocket)
              {
                  // server was stopped; release any socket the runtime may have pre-created for the
                  // accept, and the context object itself, as the accept loop ends here
                  CloseClientSocket(e.AcceptSocket);
                  e.Dispose();
                  return;
              }

              // capture the outcome before the context object is reused by StartAccept
              var clientSocket = e.AcceptSocket;
              var succeeded = e.SocketError == SocketError.Success;

              // accept new connection
              StartAccept(e);

              if (!succeeded)
              {
                  // dispose broken client socket (if any)
                  CloseClientSocket(clientSocket);
                  return;
              }

              // process connection
              ProcessAccept(clientSocket);
          }

          /// <summary>
          /// Forwards an accepted client connection through a new direct-tcpip channel.
          /// </summary>
          /// <param name="clientSocket">The accepted client socket.</param>
          /// <remarks>
          /// This may run from the <see cref="SocketAsyncEventArgs.Completed"/> callback, so it must not let
          /// exceptions escape; failures are reported through the port's exception event instead.
          /// </remarks>
          private void ProcessAccept(Socket clientSocket)
          {
              // close the client socket if we're no longer accepting new connections
              if (!IsStarted)
              {
                  CloseClientSocket(clientSocket);
                  return;
              }

              // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling
              // that same instance; the instance field for the countdown event is re-initialized when the port is restarted
              // and at that time there may still be pending requests
              var pendingChannelCountdown = _pendingChannelCountdown;

              // The port may have been stopped since the IsStarted check above: the countdown then has already
              // reached zero (AddCount would throw InvalidOperationException) or has been disposed. Treat both
              // like a connection that arrived after the port stopped.
              bool registered;
              try
              {
                  registered = pendingChannelCountdown is not null && pendingChannelCountdown.TryAddCount();
              }
              catch (ObjectDisposedException)
              {
                  registered = false;
              }

              if (!registered)
              {
                  CloseClientSocket(clientSocket);
                  return;
              }

              try
              {
                  var originatorEndPoint = (IPEndPoint)clientSocket.RemoteEndPoint;

                  RaiseRequestReceived(originatorEndPoint.Address.ToString(),
                      (uint)originatorEndPoint.Port);

                  using (var channel = Session.CreateChannelDirectTcpip())
                  {
                      channel.Exception += Channel_Exception;
                      channel.Open(Host, Port, this, clientSocket);
                      channel.Bind();
                  }
              }
              catch (Exception exp)
              {
                  RaiseExceptionEvent(exp);
                  CloseClientSocket(clientSocket);
              }
              finally
              {
                  // take into account that CountdownEvent has since been disposed; when stopping the port we
                  // wait for a given time for the channels to close, but once that timeout period has elapsed
                  // the CountdownEvent will be disposed
                  try
                  {
                      _ = pendingChannelCountdown.Signal();
                  }
                  catch (ObjectDisposedException)
                  {
                      // Ignore any ObjectDisposedException
                  }
              }
          }

          /// <summary>
          /// Initializes the <see cref="CountdownEvent"/>.
          /// </summary>
          /// <remarks>
          /// <para>
          /// When the port is started for the first time, a <see cref="CountdownEvent"/> is created with an initial count
          /// of <c>1</c>.
          /// </para>
          /// <para>
          /// On subsequent (re)starts, we'll dispose the current <see cref="CountdownEvent"/> and create a new one with
          /// initial count of <c>1</c>.
          /// </para>
          /// </remarks>
          private void InitializePendingChannelCountdown()
          {
              var original = Interlocked.Exchange(ref _pendingChannelCountdown, new CountdownEvent(1));
              original?.Dispose();
          }

          /// <summary>
          /// Shuts down the sending side of a client socket (if connected) and disposes it.
          /// </summary>
          /// <param name="clientSocket">The socket to close; <see langword="null"/> is ignored.</param>
          private static void CloseClientSocket(Socket clientSocket)
          {
              if (clientSocket is null)
              {
                  // a failed accept does not necessarily produce a socket
                  return;
              }

              if (clientSocket.Connected)
              {
                  try
                  {
                      clientSocket.Shutdown(SocketShutdown.Send);
                  }
                  catch (Exception)
                  {
                      // ignore exception when client socket was already closed
                  }
              }

              clientSocket.Dispose();
          }

          /// <summary>
          /// Interrupts the listener, and unsubscribes from <see cref="Session"/> events.
          /// </summary>
          private void StopListener()
          {
              // close listener socket
              _listener?.Dispose();

              // unsubscribe from session events
              var session = Session;
              if (session != null)
              {
                  session.ErrorOccured -= Session_ErrorOccured;
                  session.Disconnected -= Session_Disconnected;
              }
          }

          /// <summary>
          /// Waits for pending channels to close.
          /// </summary>
          /// <param name="timeout">The maximum time to wait for the pending channels to close.</param>
          private void InternalStop(TimeSpan timeout)
          {
              var pendingChannelCountdown = _pendingChannelCountdown;
              if (pendingChannelCountdown is null)
              {
                  // already disposed; there is nothing left to wait for
                  return;
              }

              // remove the initial count added when the port was started, so the event is set once
              // every pending channel has signaled
              _ = pendingChannelCountdown.Signal();

              if (!pendingChannelCountdown.Wait(timeout))
              {
                  _logger.LogInformation("Timeout waiting for pending channels in local forwarded port to close.");
              }
          }

          /// <summary>
          /// Releases the listener socket and the pending channel countdown.
          /// </summary>
          /// <param name="disposing"><see langword="true"/> when called from <see cref="Dispose(bool)"/> with managed resources to release.</param>
          private void InternalDispose(bool disposing)
          {
              if (!disposing)
              {
                  return;
              }

              Interlocked.Exchange(ref _listener, null)?.Dispose();
              Interlocked.Exchange(ref _pendingChannelCountdown, null)?.Dispose();
          }

          private void Session_Disconnected(object sender, EventArgs e)
          {
              var session = Session;
              if (session is not null)
              {
                  StopPort(session.ConnectionInfo.Timeout);
              }
          }

          private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
          {
              var session = Session;
              if (session is not null)
              {
                  StopPort(session.ConnectionInfo.Timeout);
              }
          }

          private void Channel_Exception(object sender, ExceptionEventArgs e)
          {
              RaiseExceptionEvent(e.Exception);
          }
      }
  }