using System;
using System.Diagnostics;
using System.Net.Sockets;
using System.Net;
using System.Threading;
using Renci.SshNet.Common;
namespace Renci.SshNet
{
/// <summary>
/// Provides functionality for local port forwarding
/// </summary>
public partial class ForwardedPortLocal
{
    /// <summary>
    /// Socket that listens on the bound endpoint for inbound client connections.
    /// </summary>
    private Socket _listener;

    /// <summary>
    /// Number of accepted connections that are currently being serviced by
    /// <see cref="AcceptCallback"/>.
    /// </summary>
    private int _pendingRequests;

    /// <summary>
    /// Executes <paramref name="action"/>; used here to host the accept loop.
    /// </summary>
    /// <param name="action">The action to execute.</param>
    /// <remarks>
    /// Implemented in another part of this partial class - presumably on a separate
    /// thread; confirm against that implementation.
    /// </remarks>
    partial void ExecuteThread(Action action);

    /// <summary>
    /// Binds the listener socket to the configured host and port, subscribes to the
    /// session events, and starts the loop that accepts inbound connections.
    /// </summary>
    partial void InternalStart()
    {
        var addr = BoundHost.GetIPAddress();
        var ep = new IPEndPoint(addr, (int) BoundPort);

        _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {Blocking = true};
        try
        {
            // NoDelay is a TCP-level option and must be set with SocketOptionLevel.Tcp;
            // at the Socket level the same numeric value selects a different option
            _listener.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
            _listener.Bind(ep);
            // allow a single connection to wait in the accept queue
            _listener.Listen(1);
        }
        catch
        {
            // do not leak the socket handle when binding or listening fails
            // (for example because the port is already in use)
            _listener.Close();
            throw;
        }

        // update bound port (in case original was passed as zero)
        BoundPort = (uint) ((IPEndPoint) _listener.LocalEndPoint).Port;

        Session.ErrorOccured += Session_ErrorOccured;
        Session.Disconnected += Session_Disconnected;

        _listenerTaskCompleted = new ManualResetEvent(false);

        ExecuteThread(() =>
        {
            try
            {
                while (true)
                {
                    // accept new inbound connection
                    var asyncResult = _listener.BeginAccept(AcceptCallback, _listener);
                    // wait for the connection to be established
                    asyncResult.AsyncWaitHandle.WaitOne();
                }
            }
            catch (ObjectDisposedException)
            {
                // BeginAccept will throw an ObjectDisposedException when the
                // socket is closed
            }
            catch (Exception ex)
            {
                RaiseExceptionEvent(ex);
            }
            finally
            {
                // mark listener stopped
                _listenerTaskCompleted.Set();
            }
        });
    }

    /// <summary>
    /// Completes an accept operation and forwards the accepted connection through a
    /// direct-tcpip channel until that channel is closed.
    /// </summary>
    /// <param name="ar">
    /// The result of the asynchronous accept; its state object is the listener socket.
    /// </param>
    private void AcceptCallback(IAsyncResult ar)
    {
        // Get the socket that handles the client request
        var serverSocket = (Socket) ar.AsyncState;

        Socket clientSocket;
        try
        {
            clientSocket = serverSocket.EndAccept(ar);
        }
        catch (ObjectDisposedException)
        {
            // when the socket is closed, an ObjectDisposedException is thrown
            // by Socket.EndAccept(IAsyncResult)
            return;
        }
        catch (SocketException ex)
        {
            // surface accept failures through the exception event instead of
            // letting them escape the asynchronous callback
            RaiseExceptionEvent(ex);
            return;
        }

        // counted so that InternalStop can wait for this request to finish
        Interlocked.Increment(ref _pendingRequests);

        try
        {
            clientSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.DontLinger, true);
            // NoDelay is a TCP-level option (see InternalStart)
            clientSocket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);

            var originatorEndPoint = (IPEndPoint) clientSocket.RemoteEndPoint;

            RaiseRequestReceived(originatorEndPoint.Address.ToString(),
                (uint) originatorEndPoint.Port);

            using (var channel = Session.CreateChannelDirectTcpip())
            {
                channel.Exception += Channel_Exception;
                channel.Open(Host, Port, this, clientSocket);
                channel.Bind();
                channel.Close();
            }
        }
        catch (Exception exp)
        {
            RaiseExceptionEvent(exp);
            CloseSocket(clientSocket);
        }
        finally
        {
            Interlocked.Decrement(ref _pendingRequests);
        }
    }

    /// <summary>
    /// Shuts down the specified socket when it is still connected, and always closes it.
    /// </summary>
    /// <param name="socket">The socket to close.</param>
    private static void CloseSocket(Socket socket)
    {
        try
        {
            if (socket.Connected)
            {
                socket.Shutdown(SocketShutdown.Both);
            }
        }
        catch (SocketException)
        {
            // the shutdown is best-effort: the peer may already have dropped the
            // connection, and the socket is closed below regardless
        }
        finally
        {
            // always release the handle, also when the socket is no longer connected
            socket.Close();
        }
    }

    /// <summary>
    /// Waits for pending requests to be processed, until either none are left or the
    /// specified timeout has elapsed.
    /// </summary>
    /// <param name="timeout">
    /// The maximum time to wait; <see cref="TimeSpan.Zero"/> skips the wait, and the
    /// session's infinite timeout value waits until all pending requests are processed.
    /// </param>
    partial void InternalStop(TimeSpan timeout)
    {
        if (timeout == TimeSpan.Zero)
        {
            return;
        }

        var stopWatch = Stopwatch.StartNew();

        while (true)
        {
            // break out of loop when all pending requests have been processed
            // (the compare-exchange with identical values is an atomic read)
            if (Interlocked.CompareExchange(ref _pendingRequests, 0, 0) == 0)
            {
                break;
            }

            // break out of loop when specified timeout has elapsed
            if (stopWatch.Elapsed >= timeout && timeout != SshNet.Session.InfiniteTimeSpan)
            {
                break;
            }

            // give channels time to process pending requests
            Thread.Sleep(50);
        }

        stopWatch.Stop();
    }

    /// <summary>
    /// Interrupts the listener, and waits for the listener loop to finish.
    /// </summary>
    /// <remarks>
    /// When the forwarded port is stopped, then any further action is skipped.
    /// </remarks>
    partial void StopListener()
    {
        if (!IsStarted)
        {
            return;
        }

        Session.Disconnected -= Session_Disconnected;
        Session.ErrorOccured -= Session_ErrorOccured;

        // close listener socket
        _listener.Close();
        // wait for listener loop to finish
        _listenerTaskCompleted.WaitOne();
    }

    /// <summary>
    /// Stops the listener when the session reports an error.
    /// </summary>
    private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
    {
        StopListener();
    }

    /// <summary>
    /// Stops the listener when the session is disconnected.
    /// </summary>
    private void Session_Disconnected(object sender, EventArgs e)
    {
        StopListener();
    }

    /// <summary>
    /// Relays an exception raised by a channel through this forwarded port's exception event.
    /// </summary>
    private void Channel_Exception(object sender, ExceptionEventArgs e)
    {
        RaiseExceptionEvent(e.Exception);
    }
}
}