using System;
using System.Threading;
using Renci.SshNet.Messages.Connection;
using Renci.SshNet.Common;
using System.Globalization;
using System.Net;
using Renci.SshNet.Abstractions;
namespace Renci.SshNet
{
/// <summary>
/// Provides functionality for remote port forwarding.
/// </summary>
public class ForwardedPortRemote : ForwardedPort, IDisposable
{
// Outcome of the most recent global request: set to true by Session_RequestSuccess and to
// false by Session_RequestFailure; read by StartPort once the response has been signaled.
private bool _requestStatus;
// Signaled by Session_RequestSuccess and Session_RequestFailure; StartPort and StopPort wait on it.
private EventWaitHandle _globalRequestResponse = new AutoResetEvent(false);
// Count of channel-open requests currently being handled; incremented and decremented with
// Interlocked inside the delegate started by Session_ChannelOpening, and polled by StopPort.
private int _pendingRequests;
// Backing field for IsStarted; set to true at the end of StartPort and to false by StopPort.
private bool _isStarted;
/// <summary>
/// Gets a value indicating whether port forwarding is started.
/// </summary>
/// <value>
/// <c>true</c> if port forwarding is started; otherwise, <c>false</c>.
/// </value>
public override bool IsStarted
{
get { return _isStarted; }
}
/// <summary>
/// Gets the IP address of the bound host.
/// </summary>
public IPAddress BoundHostAddress { get; private set; }
/// <summary>
/// Gets the bound host as a string.
/// </summary>
/// <value>
/// The string representation of <see cref="BoundHostAddress"/>.
/// </value>
public string BoundHost
{
    get { return BoundHostAddress.ToString(); }
}
/// <summary>
/// Gets the bound port.
/// </summary>
/// <remarks>
/// When this value is zero at the time a request-success message is handled, it is replaced
/// by the bound port carried in that message (or left at zero when the message has none).
/// </remarks>
public uint BoundPort { get; private set; }
/// <summary>
/// Gets the IP address of the forwarded host.
/// </summary>
public IPAddress HostAddress { get; private set; }
/// <summary>
/// Gets the forwarded host as a string.
/// </summary>
/// <value>
/// The string representation of <see cref="HostAddress"/>.
/// </value>
public string Host
{
    get { return HostAddress.ToString(); }
}
/// <summary>
/// Gets the forwarded port.
/// </summary>
public uint Port { get; private set; }
/// <summary>
/// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
/// </summary>
/// <param name="boundHostAddress">The bound host address.</param>
/// <param name="boundPort">The bound port.</param>
/// <param name="hostAddress">The host address.</param>
/// <param name="port">The port.</param>
/// <exception cref="ArgumentNullException"><paramref name="boundHostAddress"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="hostAddress"/> is <c>null</c>.</exception>
/// <remarks>
/// Both port values are passed to <c>ValidatePort</c>, which is defined outside this file;
/// presumably it rejects ports above the allowed maximum - confirm against that helper.
/// </remarks>
public ForwardedPortRemote(IPAddress boundHostAddress, uint boundPort, IPAddress hostAddress, uint port)
{
    // Addresses are checked first, in parameter order, so the first missing one is reported.
    if (boundHostAddress == null)
    {
        throw new ArgumentNullException("boundHostAddress");
    }

    if (hostAddress == null)
    {
        throw new ArgumentNullException("hostAddress");
    }

    // Port checks run only after both addresses are known to be present.
    boundPort.ValidatePort("boundPort");
    port.ValidatePort("port");

    // Remote (bound) end point.
    BoundHostAddress = boundHostAddress;
    BoundPort = boundPort;

    // Forwarding target.
    HostAddress = hostAddress;
    Port = port;
}
/// <summary>
/// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
/// </summary>
/// <param name="boundPort">The bound port.</param>
/// <param name="host">The host.</param>
/// <param name="port">The port.</param>
/// <remarks>
/// Delegates to the overload that takes a bound host name, passing <c>string.Empty</c> as the
/// bound host.
/// </remarks>
public ForwardedPortRemote(uint boundPort, string host, uint port)
: this(string.Empty, boundPort, host, port)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
/// </summary>
/// <param name="boundHost">The bound host.</param>
/// <param name="boundPort">The bound port.</param>
/// <param name="host">The host.</param>
/// <param name="port">The port.</param>
/// <remarks>
/// Each host name is resolved with <c>DnsAbstraction.GetHostAddresses</c> and the first returned
/// address is passed to the <see cref="IPAddress"/>-based constructor.
/// NOTE(review): indexing with <c>[0]</c> assumes at least one address is returned - confirm.
/// </remarks>
public ForwardedPortRemote(string boundHost, uint boundPort, string host, uint port)
: this(DnsAbstraction.GetHostAddresses(boundHost)[0],
boundPort,
DnsAbstraction.GetHostAddresses(host)[0],
port)
{
}
/// <summary>
/// Starts remote port forwarding.
/// </summary>
/// <exception cref="SshException">The request to start port forwarding was answered with a failure.</exception>
/// <remarks>
/// The session event handlers are removed again whenever starting does not complete, whether
/// the request was rejected or sending / waiting threw, so a failed start leaves no
/// subscriptions behind on the session.
/// </remarks>
protected override void StartPort()
{
    Session.RegisterMessage("SSH_MSG_REQUEST_FAILURE");
    Session.RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
    Session.RegisterMessage("SSH_MSG_CHANNEL_OPEN");

    try
    {
        Session.RequestSuccessReceived += Session_RequestSuccess;
        Session.RequestFailureReceived += Session_RequestFailure;
        Session.ChannelOpenReceived += Session_ChannelOpening;

        // send global request to start tcpip forwarding
        Session.SendMessage(new GlobalRequestMessage(GlobalRequestName.TcpIpForward, true, BoundHost, BoundPort));

        // wait for response on global request to start tcpip forwarding
        Session.WaitOnHandle(_globalRequestResponse);

        if (!_requestStatus)
        {
            throw new SshException(string.Format(CultureInfo.CurrentCulture, "Port forwarding for '{0}' port '{1}' failed to start.", Host, Port));
        }
    }
    catch
    {
        // when the request to start port forward was rejected or could not be completed,
        // then we're no longer interested in these events; removing a handler that was
        // never added is harmless
        Session.RequestSuccessReceived -= Session_RequestSuccess;
        Session.RequestFailureReceived -= Session_RequestFailure;
        Session.ChannelOpenReceived -= Session_ChannelOpening;

        // rethrow without resetting the stack trace
        throw;
    }

    _isStarted = true;
}
/// <summary>
/// Stops remote port forwarding.
/// </summary>
/// <param name="timeout">The maximum amount of time to wait for pending requests to finish processing.</param>
/// <remarks>
/// Does nothing when the port is not started. Elapsed time is measured in UTC so that a
/// daylight-saving or time-zone change during the wait cannot shorten or lengthen the timeout.
/// </remarks>
protected override void StopPort(TimeSpan timeout)
{
    // if the port not started, then there's nothing to stop
    if (!IsStarted)
    {
        return;
    }

    // mark forwarded port stopped, this also causes open of new channels to be rejected
    _isStarted = false;

    base.StopPort(timeout);

    try
    {
        // send global request to cancel tcpip forwarding
        Session.SendMessage(new GlobalRequestMessage(GlobalRequestName.CancelTcpIpForward, true, BoundHost, BoundPort));

        // wait for response on global request to cancel tcpip forwarding or completion of message
        // listener loop (in which case response on global request can never be received)
        WaitHandle.WaitAny(new[] { _globalRequestResponse, Session.MessageListenerCompleted }, timeout);
    }
    finally
    {
        // unsubscribe from session events in every case: the tcpip forward is cancelled at
        // the server, our session message loop has completed, or the cancel request failed;
        // the port is already marked stopped, so nothing else would remove these handlers
        Session.RequestSuccessReceived -= Session_RequestSuccess;
        Session.RequestFailureReceived -= Session_RequestFailure;
        Session.ChannelOpenReceived -= Session_ChannelOpening;
    }

    // UTC is not affected by daylight-saving or time-zone transitions
    var startWaiting = DateTime.UtcNow;

    // keep polling until all pending requests have been processed
    while (Interlocked.CompareExchange(ref _pendingRequests, 0, 0) != 0)
    {
        // determine time elapsed since waiting for pending requests to finish
        var elapsed = DateTime.UtcNow - startWaiting;

        // break out of loop when specified timeout has elapsed
        if (elapsed >= timeout && timeout != SshNet.Session.InfiniteTimeSpan)
        {
            break;
        }

        // give channels time to process pending requests
        ThreadAbstraction.Sleep(50);
    }
}
/// <summary>
/// Ensures the current instance is not disposed.
/// </summary>
/// <exception cref="ObjectDisposedException">The current instance is disposed.</exception>
protected override void CheckDisposed()
{
    // Nothing to report while the instance is still usable.
    if (!_isDisposed)
    {
        return;
    }

    var typeName = GetType().FullName;
    throw new ObjectDisposedException(typeName);
}
/// <summary>
/// Handles the session's <c>ChannelOpenReceived</c> event: channel-open requests that target
/// this forwarded port are either refused (when the port is not started) or served on a
/// thread started through <c>ThreadAbstraction.ExecuteThread</c>.
/// </summary>
/// <param name="sender">The event source.</param>
/// <param name="e">The event data carrying the channel-open message.</param>
private void Session_ChannelOpening(object sender, MessageEventArgs e)
{
    var openRequest = e.Message;
    var forwardedInfo = openRequest.Info as ForwardedTcpipChannelInfo;

    // Only forwarded-tcpip channel infos are of interest here.
    if (forwardedInfo == null)
    {
        return;
    }

    // Ensure this is the corresponding request
    var targetsThisPort = forwardedInfo.ConnectedAddress == BoundHost && forwardedInfo.ConnectedPort == BoundPort;
    if (!targetsThisPort)
    {
        return;
    }

    // A port that is not started refuses the channel.
    if (!_isStarted)
    {
        Session.SendMessage(new ChannelOpenFailureMessage(openRequest.LocalChannelNumber, "", ChannelOpenFailureMessage.AdministrativelyProhibited));
        return;
    }

    ThreadAbstraction.ExecuteThread(() =>
    {
        // Counted as pending for as long as the request is being handled; StopPort polls this counter.
        Interlocked.Increment(ref _pendingRequests);

        try
        {
            RaiseRequestReceived(forwardedInfo.OriginatorAddress, forwardedInfo.OriginatorPort);

            using (var forwardedChannel = Session.CreateChannelForwardedTcpip(openRequest.LocalChannelNumber, openRequest.InitialWindowSize, openRequest.MaximumPacketSize))
            {
                forwardedChannel.Exception += Channel_Exception;
                forwardedChannel.Bind(new IPEndPoint(HostAddress, (int) Port), this);
                forwardedChannel.Close();
            }
        }
        catch (Exception ex)
        {
            // Failures are reported through the exception event rather than escaping the delegate.
            RaiseExceptionEvent(ex);
        }
        finally
        {
            Interlocked.Decrement(ref _pendingRequests);
        }
    });
}
/// <summary>
/// Forwards an exception reported by a channel to this port's exception event.
/// </summary>
/// <param name="sender">The event source.</param>
/// <param name="exceptionEventArgs">The event data carrying the exception.</param>
private void Channel_Exception(object sender, ExceptionEventArgs exceptionEventArgs)
{
    var channelFailure = exceptionEventArgs.Exception;
    RaiseExceptionEvent(channelFailure);
}
// Handles the session's RequestFailureReceived event: records the failure and then signals
// the wait handle that StartPort / StopPort wait on.
private void Session_RequestFailure(object sender, EventArgs e)
{
// The status is written before signaling so that the waiting code reads the new value.
_requestStatus = false;
_globalRequestResponse.Set();
}
/// <summary>
/// Handles the session's <c>RequestSuccessReceived</c> event: records the success, adopts the
/// bound port carried in the message when no bound port has been set (zero), and signals the
/// wait handle that <c>StartPort</c> / <c>StopPort</c> wait on.
/// </summary>
/// <param name="sender">The event source.</param>
/// <param name="e">The event data carrying the request-success message.</param>
private void Session_RequestSuccess(object sender, MessageEventArgs e)
{
    _requestStatus = true;

    if (BoundPort == 0)
    {
        // Falls back to zero when the message carries no bound port.
        var reportedPort = e.Message.BoundPort;
        BoundPort = (reportedPort != null) ? reportedPort.Value : 0;
    }

    // Signal only after the status and port have been updated.
    _globalRequestResponse.Set();
}
#region IDisposable Members
// Set to true at the end of Dispose(bool); checked by CheckDisposed and by Dispose(bool) itself.
private bool _isDisposed;
/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
public void Dispose()
{
Dispose(true);
// The finalizer would only call Dispose(false) again, so it is suppressed.
GC.SuppressFinalize(this);
}
/// <summary>
/// Releases unmanaged and - optionally - managed resources.
/// </summary>
/// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
protected override void Dispose(bool disposing)
{
    // Repeated calls are no-ops.
    if (_isDisposed)
    {
        return;
    }

    // The base class gets to clean up first.
    base.Dispose(disposing);

    if (disposing)
    {
        // Detach from the session and drop the reference to it.
        if (Session != null)
        {
            Session.RequestSuccessReceived -= Session_RequestSuccess;
            Session.RequestFailureReceived -= Session_RequestFailure;
            Session.ChannelOpenReceived -= Session_ChannelOpening;
            Session = null;
        }

        // Release the wait handle used for global request responses.
        if (_globalRequestResponse != null)
        {
            _globalRequestResponse.Dispose();
            _globalRequestResponse = null;
        }
    }

    _isDisposed = true;
}
/// <summary>
/// Releases unmanaged resources and performs other cleanup operations before the
/// <see cref="ForwardedPortRemote"/> is reclaimed by garbage collection.
/// </summary>
~ForwardedPortRemote()
{
// Do not re-create Dispose clean-up code here.
// Calling Dispose(false) is optimal in terms of
// readability and maintainability.
Dispose(false);
}
#endregion
}
}