using System;
using System.Threading;
using Renci.SshNet.Messages.Connection;
using Renci.SshNet.Common;
using System.Globalization;
using System.Net;
using Renci.SshNet.Abstractions;

namespace Renci.SshNet
{
    /// <summary>
    /// Provides functionality for remote port forwarding.
    /// </summary>
    public class ForwardedPortRemote : ForwardedPort, IDisposable
    {
        // Outcome of the most recent global request; written by the success/failure handlers
        // right before _globalRequestResponse is signaled.
        private bool _requestStatus;

        // Signaled when the server answers a global request (success or failure).
        private EventWaitHandle _globalRequestResponse = new AutoResetEvent(false);

        // Number of accepted forwarded channels whose processing has not finished yet;
        // StopPort waits (up to its timeout) for this to drop to zero.
        private int _pendingRequests;

        // Cleared by StopPort while the channel-open handler may be reading it, presumably
        // from a different thread (the session's message loop), hence volatile.
        private volatile bool _isStarted;

        /// <summary>
        /// Gets a value indicating whether port forwarding is started.
        /// </summary>
        /// <value>
        /// <c>true</c> if port forwarding is started; otherwise, <c>false</c>.
        /// </value>
        public override bool IsStarted
        {
            get { return _isStarted; }
        }

        /// <summary>
        /// Gets the bound host address.
        /// </summary>
        public IPAddress BoundHostAddress { get; private set; }

        /// <summary>
        /// Gets the bound host, as the string form of <see cref="BoundHostAddress"/>.
        /// </summary>
        public string BoundHost
        {
            get { return BoundHostAddress.ToString(); }
        }

        /// <summary>
        /// Gets the bound port.
        /// </summary>
        /// <remarks>
        /// When this is zero at the time the forwarding request succeeds, it is replaced with the
        /// port reported by the server (or stays zero if the server reports none).
        /// </remarks>
        public uint BoundPort { get; private set; }

        /// <summary>
        /// Gets the forwarded host address.
        /// </summary>
        public IPAddress HostAddress { get; private set; }

        /// <summary>
        /// Gets the forwarded host, as the string form of <see cref="HostAddress"/>.
        /// </summary>
        public string Host
        {
            get { return HostAddress.ToString(); }
        }

        /// <summary>
        /// Gets the forwarded port.
        /// </summary>
        public uint Port { get; private set; }

        /// <summary>
        /// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
        /// </summary>
        /// <param name="boundHostAddress">The bound host address.</param>
        /// <param name="boundPort">The bound port.</param>
        /// <param name="hostAddress">The host address.</param>
        /// <param name="port">The port.</param>
        /// <exception cref="ArgumentNullException"><paramref name="boundHostAddress"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="hostAddress"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort"/> is greater than <see cref="IPEndPoint.MaxPort"/>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="port"/> is greater than <see cref="IPEndPoint.MaxPort"/>.</exception>
        public ForwardedPortRemote(IPAddress boundHostAddress, uint boundPort, IPAddress hostAddress, uint port)
        {
            if (boundHostAddress == null)
                throw new ArgumentNullException("boundHostAddress");
            if (hostAddress == null)
                throw new ArgumentNullException("hostAddress");

            boundPort.ValidatePort("boundPort");
            port.ValidatePort("port");

            BoundHostAddress = boundHostAddress;
            BoundPort = boundPort;
            HostAddress = hostAddress;
            Port = port;
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class, using an
        /// empty string as the bound host.
        /// </summary>
        /// <param name="boundPort">The bound port.</param>
        /// <param name="host">The host.</param>
        /// <param name="port">The port.</param>
        /// <remarks>
        /// The bound address is the first address returned when resolving <see cref="string.Empty"/>
        /// through <c>DnsAbstraction.GetHostAddresses</c> — NOTE(review): presumably a local address; confirm.
        /// </remarks>
        public ForwardedPortRemote(uint boundPort, string host, uint port)
            : this(string.Empty, boundPort, host, port)
        {
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
        /// </summary>
        /// <param name="boundHost">The bound host; the first resolved address is used.</param>
        /// <param name="boundPort">The bound port.</param>
        /// <param name="host">The host; the first resolved address is used.</param>
        /// <param name="port">The port.</param>
        public ForwardedPortRemote(string boundHost, uint boundPort, string host, uint port)
            : this(DnsAbstraction.GetHostAddresses(boundHost)[0],
                   boundPort,
                   DnsAbstraction.GetHostAddresses(host)[0],
                   port)
        {
        }

        /// <summary>
        /// Starts remote port forwarding.
        /// </summary>
        /// <remarks>
        /// Subscribes to the session's global request and channel open events, sends a
        /// <c>TcpIpForward</c> global request for <see cref="BoundHost"/> and <see cref="BoundPort"/>,
        /// and blocks until the server answers. The subscriptions are removed again when the
        /// request is rejected or when sending/waiting throws.
        /// </remarks>
        /// <exception cref="SshException">The server rejected the port forwarding request.</exception>
        protected override void StartPort()
        {
            Session.RegisterMessage("SSH_MSG_REQUEST_FAILURE");
            Session.RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
            Session.RegisterMessage("SSH_MSG_CHANNEL_OPEN");

            Session.RequestSuccessReceived += Session_RequestSuccess;
            Session.RequestFailureReceived += Session_RequestFailure;
            Session.ChannelOpenReceived += Session_ChannelOpening;

            try
            {
                // send global request to start remote (TcpIpForward) forwarding
                Session.SendMessage(new GlobalRequestMessage(GlobalRequestName.TcpIpForward, true, BoundHost, BoundPort));
                // wait for the server's response to that global request
                Session.WaitOnHandle(_globalRequestResponse);
            }
            catch
            {
                // don't leave our handlers attached to the session when the request never completed
                UnsubscribeFromSessionEvents();
                throw;
            }

            if (!_requestStatus)
            {
                // when the request to start port forward was rejected, then we're no longer
                // interested in these events
                UnsubscribeFromSessionEvents();

                throw new SshException(string.Format(CultureInfo.CurrentCulture, "Port forwarding for '{0}' port '{1}' failed to start.", Host, Port));
            }

            _isStarted = true;
        }

        /// <summary>
        /// Stops remote port forwarding.
        /// </summary>
        /// <param name="timeout">The maximum amount of time to wait for pending requests to finish processing.</param>
        /// <remarks>
        /// Marks the port stopped (new channels are then rejected), sends a <c>CancelTcpIpForward</c>
        /// global request, waits for its response or for the session's message listener to complete,
        /// unsubscribes from session events, and finally polls every 50 ms until all pending
        /// channels are done or <paramref name="timeout"/> elapses.
        /// </remarks>
        protected override void StopPort(TimeSpan timeout)
        {
            // if the port not started, then there's nothing to stop
            if (!IsStarted)
                return;

            // mark forwarded port stopped, this also causes open of new channels to be rejected
            _isStarted = false;

            base.StopPort(timeout);

            // send global request to cancel remote (TcpIpForward) forwarding
            Session.SendMessage(new GlobalRequestMessage(GlobalRequestName.CancelTcpIpForward, true, BoundHost, BoundPort));

            // wait for response on the cancel request or completion of the message listener loop
            // (in which case the response on the global request can never be received)
            WaitHandle.WaitAny(new[] { _globalRequestResponse, Session.MessageListenerCompleted }, timeout);

            // unsubscribe from session events as either the forward is cancelled at the
            // server, or our session message loop has completed
            UnsubscribeFromSessionEvents();

            // UtcNow: elapsed-time measurement must not be affected by DST/time-zone changes
            var startWaiting = DateTime.UtcNow;

            while (true)
            {
                // break out of loop when all pending requests have been processed
                if (Interlocked.CompareExchange(ref _pendingRequests, 0, 0) == 0)
                    break;

                var elapsed = DateTime.UtcNow - startWaiting;

                // break out of loop when specified timeout has elapsed
                if (elapsed >= timeout && timeout != SshNet.Session.InfiniteTimeSpan)
                    break;

                // give channels time to process pending requests
                ThreadAbstraction.Sleep(50);
            }
        }

        /// <summary>
        /// Ensures the current instance is not disposed.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The current instance is disposed.</exception>
        protected override void CheckDisposed()
        {
            if (_isDisposed)
                throw new ObjectDisposedException(GetType().FullName);
        }

        /// <summary>
        /// Detaches this instance's global request and channel open handlers from the session.
        /// </summary>
        private void UnsubscribeFromSessionEvents()
        {
            Session.RequestSuccessReceived -= Session_RequestSuccess;
            Session.RequestFailureReceived -= Session_RequestFailure;
            Session.ChannelOpenReceived -= Session_ChannelOpening;
        }

        /// <summary>
        /// Handles a channel open request from the server. Forwarded TCP/IP channels addressed to
        /// <see cref="BoundHost"/>:<see cref="BoundPort"/> are rejected when the port is stopped,
        /// otherwise bound to <see cref="HostAddress"/>:<see cref="Port"/> on a separate thread.
        /// </summary>
        private void Session_ChannelOpening(object sender, MessageEventArgs e)
        {
            var channelOpenMessage = e.Message;
            var info = channelOpenMessage.Info as ForwardedTcpipChannelInfo;
            if (info == null)
                return;

            // Ensure this is the corresponding request
            if (info.ConnectedAddress != BoundHost || info.ConnectedPort != BoundPort)
                return;

            // Count the request as pending BEFORE checking _isStarted. StopPort clears _isStarted
            // and only then reads _pendingRequests; with the interlocked increment ordered before the
            // read here, either StopPort sees this request and waits for it, or we see the port stopped.
            Interlocked.Increment(ref _pendingRequests);

            if (!_isStarted)
            {
                Interlocked.Decrement(ref _pendingRequests);
                Session.SendMessage(new ChannelOpenFailureMessage(channelOpenMessage.LocalChannelNumber, "", ChannelOpenFailureMessage.AdministrativelyProhibited));
                return;
            }

            try
            {
                ThreadAbstraction.ExecuteThread(() =>
                {
                    try
                    {
                        RaiseRequestReceived(info.OriginatorAddress, info.OriginatorPort);

                        using (var channel = Session.CreateChannelForwardedTcpip(channelOpenMessage.LocalChannelNumber, channelOpenMessage.InitialWindowSize, channelOpenMessage.MaximumPacketSize))
                        {
                            channel.Exception += Channel_Exception;
                            channel.Bind(new IPEndPoint(HostAddress, (int) Port), this);
                            channel.Close();
                        }
                    }
                    catch (Exception exp)
                    {
                        RaiseExceptionEvent(exp);
                    }
                    finally
                    {
                        Interlocked.Decrement(ref _pendingRequests);
                    }
                });
            }
            catch
            {
                // NOTE(review): assumes ExecuteThread only throws when the worker did not start,
                // in which case nothing else will release the pending count taken above.
                Interlocked.Decrement(ref _pendingRequests);
                throw;
            }
        }

        /// <summary>
        /// Forwards an exception raised by a forwarded channel to this port's exception event.
        /// </summary>
        private void Channel_Exception(object sender, ExceptionEventArgs exceptionEventArgs)
        {
            RaiseExceptionEvent(exceptionEventArgs.Exception);
        }

        /// <summary>
        /// Records a failed global request and wakes up the waiting start/stop call.
        /// </summary>
        private void Session_RequestFailure(object sender, EventArgs e)
        {
            _requestStatus = false;
            _globalRequestResponse.Set();
        }

        /// <summary>
        /// Records a successful global request, adopts the server-assigned port when
        /// <see cref="BoundPort"/> is zero, and wakes up the waiting start/stop call.
        /// </summary>
        private void Session_RequestSuccess(object sender, MessageEventArgs e)
        {
            _requestStatus = true;
            if (BoundPort == 0)
            {
                BoundPort = e.Message.BoundPort ?? 0;
            }

            _globalRequestResponse.Set();
        }

        #region IDisposable Members

        private bool _isDisposed;

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// </summary>
        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
        protected override void Dispose(bool disposing)
        {
            if (_isDisposed)
                return;

            base.Dispose(disposing);

            if (disposing)
            {
                if (Session != null)
                {
                    UnsubscribeFromSessionEvents();
                    Session = null;
                }

                if (_globalRequestResponse != null)
                {
                    _globalRequestResponse.Dispose();
                    _globalRequestResponse = null;
                }
            }

            _isDisposed = true;
        }

        /// <summary>
        /// Releases unmanaged resources and performs other cleanup operations before the
        /// <see cref="ForwardedPortRemote"/> is reclaimed by garbage collection.
        /// </summary>
        ~ForwardedPortRemote()
        {
            // Do not re-create Dispose clean-up code here.
            // Calling Dispose(false) is optimal in terms of
            // readability and maintainability.
            Dispose(false);
        }

        #endregion
    }
}