using System;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet.Abstractions
{
    /// <summary>
    /// Helpers for connecting, reading from and writing to a <see cref="Socket"/>. The implementation used for
    /// each operation is selected at compile time through the <c>FEATURE_SOCKET_*</c> conditional symbols.
    /// </summary>
    internal static class SocketAbstraction
    {
        /// <summary>
        /// Determines whether data can be read from the specified socket.
        /// </summary>
        /// <param name="socket">The socket to check.</param>
        /// <returns>
        /// <c>true</c> if the socket is connected, polls as readable and has data available; otherwise, <c>false</c>.
        /// Always <c>false</c> when <c>FEATURE_SOCKET_POLL</c> is not defined.
        /// </returns>
        /// <remarks>
        /// The poll is performed with a timeout of <c>-1</c>, so the call waits until the socket polls as readable.
        /// </remarks>
        public static bool CanRead(Socket socket)
        {
            if (socket.Connected)
            {
#if FEATURE_SOCKET_POLL
                return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
#endif // FEATURE_SOCKET_POLL
            }

            return false;
        }

        /// <summary>
        /// Determines whether data can be written to the specified socket.
        /// </summary>
        /// <param name="socket">The socket to check.</param>
        /// <returns>
        /// <c>true</c> if the socket is connected and polls as writable; otherwise, <c>false</c>.
        /// Always <c>false</c> when <c>FEATURE_SOCKET_POLL</c> is not defined.
        /// </returns>
        /// <remarks>
        /// The poll is performed with a timeout of <c>-1</c>, so the call waits until the socket polls as writable.
        /// </remarks>
        public static bool CanWrite(Socket socket)
        {
            if (socket.Connected)
            {
#if FEATURE_SOCKET_POLL
                return socket.Poll(-1, SelectMode.SelectWrite);
#endif // FEATURE_SOCKET_POLL
            }

            return false;
        }

        /// <summary>
        /// Creates a TCP socket with <see cref="Socket.NoDelay"/> enabled and connects it to the specified endpoint.
        /// </summary>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <returns>The connected <see cref="Socket"/>.</returns>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        /// <exception cref="SocketException">The connection attempt failed.</exception>
        /// <remarks>
        /// When the connection cannot be established, the socket is disposed before the exception is propagated.
        /// </remarks>
        public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
            var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {NoDelay = true};

            try
            {
                ConnectCore(socket, remoteEndpoint, connectTimeout);
                return socket;
            }
            catch
            {
                // the caller never receives the socket when connecting fails, so release it here instead of leaking it;
                // the explicit IDisposable cast keeps this portable to frameworks where Socket.Dispose() is not public
                ((IDisposable) socket).Dispose();
                throw;
            }
        }

        /// <summary>
        /// Connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>, waiting at most
        /// <paramref name="connectTimeout"/> for the attempt to complete.
        /// </summary>
        /// <exception cref="SshOperationTimeoutException">The connection was not established in time.</exception>
        /// <exception cref="SocketException">The connection attempt failed.</exception>
        private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
#if FEATURE_SOCKET_EAP
            var connectCompleted = new ManualResetEvent(false);
            var args = new SocketAsyncEventArgs
                {
                    UserToken = connectCompleted,
                    RemoteEndPoint = remoteEndpoint
                };
            args.Completed += ConnectCompleted;

            try
            {
                // ConnectAsync returns false when the operation completed synchronously; the Completed event is
                // not raised in that case and the outcome is already available in args
                if (socket.ConnectAsync(args) && !connectCompleted.WaitOne(connectTimeout))
                {
                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                        "Connection failed to establish within {0:F0} milliseconds.",
                        connectTimeout.TotalMilliseconds));
                }

                if (args.SocketError != SocketError.Success)
                {
                    throw new SocketException((int) args.SocketError);
                }
            }
            finally
            {
                // clear the token so that a late completion does not try to signal the disposed wait handle
                args.UserToken = null;
                args.Dispose();
                connectCompleted.Dispose();
            }
#elif FEATURE_SOCKET_APM
            var connectResult = socket.BeginConnect(remoteEndpoint, null, null);
            if (!connectResult.AsyncWaitHandle.WaitOne(connectTimeout, false))
            {
                throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                    "Connection failed to establish within {0:F0} milliseconds.",
                    connectTimeout.TotalMilliseconds));
            }

            socket.EndConnect(connectResult);
#elif FEATURE_SOCKET_TAP
            var connectTask = socket.ConnectAsync(remoteEndpoint);

            try
            {
                if (!connectTask.Wait(connectTimeout))
                {
                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                        "Connection failed to establish within {0:F0} milliseconds.",
                        connectTimeout.TotalMilliseconds));
                }
            }
            catch (AggregateException)
            {
                // Task.Wait wraps failures in an AggregateException; GetResult() on the completed task rethrows
                // the original exception (e.g. SocketException), matching the other connect implementations
                connectTask.GetAwaiter().GetResult();
                throw;
            }
#else
#error Connecting to a remote endpoint is not implemented.
#endif
        }

        /// <summary>
        /// Discards data waiting on the socket by reading (in chunks of 256 bytes) until a read returns
        /// no data or throws, for example because no data arrived within two seconds.
        /// </summary>
        /// <param name="socket">The socket to drain.</param>
        /// <remarks>
        /// This is a best-effort operation: every exception is swallowed.
        /// </remarks>
        public static void ClearReadBuffer(Socket socket)
        {
            try
            {
                var buffer = new byte[256];
                int bytesReceived;

                do
                {
                    bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, TimeSpan.FromSeconds(2));
                }
                while (bytesReceived > 0);
            }
            catch
            {
                // ignore any exceptions; a read timeout is the normal way this loop ends
            }
        }

        /// <summary>
        /// Receives up to <paramref name="size"/> bytes from a connected <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The socket to read from.</param>
        /// <param name="buffer">The storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> at which to store the received data.</param>
        /// <param name="size">The maximum number of bytes to receive.</param>
        /// <param name="timeout">The time to wait for data; not applied when <c>FEATURE_SOCKET_SYNC</c> is defined.</param>
        /// <returns>The number of bytes received by a single receive operation.</returns>
        /// <exception cref="SshOperationTimeoutException">No data was received within <paramref name="timeout"/>.</exception>
        /// <exception cref="SocketException">The receive failed with a non-resumable error.</exception>
        public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
        {
#if FEATURE_SOCKET_SYNC
            return socket.Receive(buffer, offset, size, SocketFlags.None);
#elif FEATURE_SOCKET_EAP
            var receiveCompleted = new ManualResetEvent(false);
            var sendReceiveToken = new PartialSendReceiveToken(socket, receiveCompleted);
            var args = new SocketAsyncEventArgs
                {
                    RemoteEndPoint = socket.RemoteEndPoint,
                    UserToken = sendReceiveToken
                };
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            try
            {
                if (!socket.ReceiveAsync(args))
                {
                    // completed synchronously: Completed is not raised, so let the token handle the result here
                    sendReceiveToken.Process(args);
                }

                if (!receiveCompleted.WaitOne(timeout))
                {
                    throw new SshOperationTimeoutException(
                        string.Format(
                            CultureInfo.InvariantCulture,
                            "Socket read operation has timed out after {0:F0} milliseconds.",
                            timeout.TotalMilliseconds));
                }

                if (args.SocketError != SocketError.Success)
                {
                    throw new SocketException((int) args.SocketError);
                }

                return args.BytesTransferred;
            }
            finally
            {
                // clear the token to avoid the wait handle getting used after it's disposed
                args.UserToken = null;
                args.Dispose();
                receiveCompleted.Dispose();
            }
#else
#error Receiving data from a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Receives <paramref name="size"/> bytes from a connected <see cref="Socket"/> into a receive buffer.
        /// </summary>
        /// <param name="socket">The socket to read from.</param>
        /// <param name="buffer">The storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> at which to store the received data.</param>
        /// <param name="size">The number of bytes to receive.</param>
        /// <param name="timeout">The time to wait for the data; not applied when <c>FEATURE_SOCKET_SYNC</c> is defined.</param>
        /// <returns>
        /// The number of bytes received. This equals <paramref name="size"/> unless the remote party closed the
        /// connection first. In that case the synchronous implementation returns <c>0</c>. The event-based
        /// implementation returns the number of bytes received before the connection was closed.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">
        /// The data was not received within <paramref name="timeout"/> (event-based implementation only).
        /// </exception>
        /// <exception cref="SocketException">The receive failed with a non-resumable error.</exception>
        /// <remarks>
        /// Resumable socket errors (see <see cref="IsErrorResumable"/>) are retried after a 30 millisecond pause
        /// instead of being reported to the caller.
        /// </remarks>
        public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
        {
#if FEATURE_SOCKET_SYNC
            var totalBytesRead = 0;
            var totalBytesToRead = size;

            do
            {
                try
                {
                    var bytesRead = socket.Receive(buffer, offset + totalBytesRead, totalBytesToRead - totalBytesRead, SocketFlags.None);
                    if (bytesRead == 0)
                    {
                        return 0;
                    }

                    totalBytesRead += bytesRead;
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        ThreadAbstraction.Sleep(30);
                        continue;
                    }

                    throw;
                }
            }
            while (totalBytesRead < totalBytesToRead);

            return totalBytesRead;
#elif FEATURE_SOCKET_EAP
            var receiveCompleted = new ManualResetEvent(false);
            var sendReceiveToken = new BlockingSendReceiveToken(socket, buffer, offset, size, receiveCompleted);
            var args = new SocketAsyncEventArgs
                {
                    UserToken = sendReceiveToken,
                    RemoteEndPoint = socket.RemoteEndPoint
                };
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            try
            {
                if (!socket.ReceiveAsync(args))
                {
                    // completed synchronously: Completed is not raised, so the token must account for these bytes here
                    sendReceiveToken.Process(args);
                }

                if (!receiveCompleted.WaitOne(timeout))
                {
                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                        "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
                }

                if (args.SocketError != SocketError.Success)
                {
                    throw new SocketException((int) args.SocketError);
                }

                return sendReceiveToken.TotalBytesTransferred;
            }
            finally
            {
                // clear the token to avoid the wait handle getting used after it's disposed
                args.UserToken = null;
                args.Dispose();
                receiveCompleted.Dispose();
            }
#else
#error Receiving data from a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Sends all bytes of <paramref name="data"/> to a connected <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The socket to write to.</param>
        /// <param name="data">The data to send.</param>
        public static void Send(Socket socket, byte[] data)
        {
            Send(socket, data, 0, data.Length);
        }

        /// <summary>
        /// Sends <paramref name="size"/> bytes of <paramref name="data"/>, starting at <paramref name="offset"/>,
        /// to a connected <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The socket to write to.</param>
        /// <param name="data">The buffer containing the data to send.</param>
        /// <param name="offset">The position in <paramref name="data"/> at which to start sending.</param>
        /// <param name="size">The number of bytes to send.</param>
        /// <exception cref="SshConnectionException">A send operation transferred no bytes.</exception>
        /// <exception cref="SocketException">The send failed with a non-resumable error.</exception>
        /// <remarks>
        /// Resumable socket errors (see <see cref="IsErrorResumable"/>) are retried after a 30 millisecond pause.
        /// No timeout is applied.
        /// </remarks>
        public static void Send(Socket socket, byte[] data, int offset, int size)
        {
#if FEATURE_SOCKET_SYNC
            var totalBytesSent = 0; // how many bytes are already sent
            var totalBytesToSend = size;

            do
            {
                try
                {
                    var bytesSent = socket.Send(data, offset + totalBytesSent, totalBytesToSend - totalBytesSent, SocketFlags.None);
                    if (bytesSent == 0)
                    {
                        throw new SshConnectionException("An established connection was aborted by the server.",
                            DisconnectReason.ConnectionLost);
                    }

                    totalBytesSent += bytesSent;
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        // socket buffer is probably full, wait and try again
                        ThreadAbstraction.Sleep(30);
                    }
                    else
                    {
                        throw; // a non-resumable error occurred
                    }
                }
            }
            while (totalBytesSent < totalBytesToSend);
#elif FEATURE_SOCKET_EAP
            var sendCompleted = new ManualResetEvent(false);
            var sendReceiveToken = new BlockingSendReceiveToken(socket, data, offset, size, sendCompleted);
            var socketAsyncSendArgs = new SocketAsyncEventArgs
                {
                    RemoteEndPoint = socket.RemoteEndPoint,
                    UserToken = sendReceiveToken
                };
            socketAsyncSendArgs.SetBuffer(data, offset, size);
            socketAsyncSendArgs.Completed += SendCompleted;

            try
            {
                if (!socket.SendAsync(socketAsyncSendArgs))
                {
                    // completed synchronously: Completed is not raised, so the token must account for these bytes here
                    sendReceiveToken.Process(socketAsyncSendArgs);
                }

                // no timeout is applied to sends; WaitOne() only returns once the token signals completion
                sendCompleted.WaitOne();

                if (socketAsyncSendArgs.SocketError != SocketError.Success)
                {
                    throw new SocketException((int) socketAsyncSendArgs.SocketError);
                }

                if (sendReceiveToken.TotalBytesTransferred == 0)
                {
                    throw new SshConnectionException("An established connection was aborted by the server.",
                        DisconnectReason.ConnectionLost);
                }
            }
            finally
            {
                // clear the token to avoid the completion wait handle getting used after it's disposed
                socketAsyncSendArgs.UserToken = null;
                socketAsyncSendArgs.Dispose();
                sendCompleted.Dispose();
            }
#else
#error Sending data to a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Determines whether a socket operation that failed with <paramref name="socketError"/> should be retried.
        /// </summary>
        /// <param name="socketError">The error reported for the operation.</param>
        /// <returns>
        /// <c>true</c> for <see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/> and
        /// <see cref="SocketError.NoBufferSpaceAvailable"/>; otherwise, <c>false</c>.
        /// </returns>
        private static bool IsErrorResumable(SocketError socketError)
        {
            switch (socketError)
            {
                case SocketError.WouldBlock:
                case SocketError.IOPending:
                case SocketError.NoBufferSpaceAvailable:
                    return true;
                default:
                    return false;
            }
        }

#if FEATURE_SOCKET_EAP
        /// <summary>
        /// Signals the <see cref="ManualResetEvent"/> stored in <see cref="SocketAsyncEventArgs.UserToken"/>
        /// when an asynchronous connect completes.
        /// </summary>
        private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
        {
            var eventWaitHandle = (ManualResetEvent) e.UserToken;
            if (eventWaitHandle == null)
            {
                return;
            }

            try
            {
                eventWaitHandle.Set();
            }
            catch (ObjectDisposedException)
            {
                // the connect attempt already timed out and its wait handle was disposed; nobody is waiting anymore
            }
        }
#endif // FEATURE_SOCKET_EAP

#if FEATURE_SOCKET_EAP && !FEATURE_SOCKET_SYNC
        /// <summary>
        /// Passes a completed asynchronous receive to the <see cref="ISendReceiveToken"/> stored in
        /// <see cref="SocketAsyncEventArgs.UserToken"/>, if any.
        /// </summary>
        private static void ReceiveCompleted(object sender, SocketAsyncEventArgs e)
        {
            var sendReceiveToken = (ISendReceiveToken) e.UserToken;
            if (sendReceiveToken != null)
            {
                sendReceiveToken.Process(e);
            }
        }

        /// <summary>
        /// Passes a completed asynchronous send to the <see cref="ISendReceiveToken"/> stored in
        /// <see cref="SocketAsyncEventArgs.UserToken"/>, if any.
        /// </summary>
        private static void SendCompleted(object sender, SocketAsyncEventArgs e)
        {
            var sendReceiveToken = (ISendReceiveToken) e.UserToken;
            if (sendReceiveToken != null)
            {
                sendReceiveToken.Process(e);
            }
        }

        /// <summary>
        /// Restarts the send or receive operation that <paramref name="args"/> last performed.
        /// </summary>
        /// <param name="socket">The socket to restart the operation on.</param>
        /// <param name="args">The event args describing the operation.</param>
        /// <returns>
        /// <c>true</c> if the operation is pending and <see cref="SocketAsyncEventArgs.Completed"/> will be raised;
        /// <c>false</c> if it completed synchronously and the caller must process the result itself.
        /// </returns>
        private static bool ResumeOperation(Socket socket, SocketAsyncEventArgs args)
        {
            switch (args.LastOperation)
            {
                case SocketAsyncOperation.Receive:
                    return socket.ReceiveAsync(args);
                case SocketAsyncOperation.Send:
                    return socket.SendAsync(args);
                default:
                    // not expected: tokens are only attached to send and receive operations. Nothing was restarted,
                    // so report "pending" to keep the caller from re-processing the same result.
                    return true;
            }
        }

        /// <summary>
        /// Handles the result of an asynchronous socket operation.
        /// </summary>
        private interface ISendReceiveToken
        {
            void Process(SocketAsyncEventArgs args);
        }

        /// <summary>
        /// Keeps a send or receive going until the requested number of bytes has been transferred, a zero-byte
        /// transfer occurs, or a non-resumable error is reported, and then signals the completion wait handle.
        /// </summary>
        private sealed class BlockingSendReceiveToken : ISendReceiveToken
        {
            private readonly Socket _socket;
            private readonly byte[] _buffer;
            private readonly int _bytesToTransfer;
            private readonly EventWaitHandle _completionWaitHandle;
            private int _offset;

            public BlockingSendReceiveToken(Socket socket, byte[] buffer, int offset, int size, EventWaitHandle completionWaitHandle)
            {
                _socket = socket;
                _buffer = buffer;
                _offset = offset;
                _bytesToTransfer = size;
                _completionWaitHandle = completionWaitHandle;
            }

            /// <summary>
            /// Gets the total number of bytes transferred so far.
            /// </summary>
            public int TotalBytesTransferred { get; private set; }

            public void Process(SocketAsyncEventArgs args)
            {
                // keep looping for as long as resumed operations complete synchronously, because the Completed
                // event (and therefore a new call to this method) does not occur for those
                do
                {
                    if (args.SocketError == SocketError.Success)
                    {
                        TotalBytesTransferred += args.BytesTransferred;

                        if (TotalBytesTransferred == _bytesToTransfer)
                        {
                            // finished transferring specified bytes
                            _completionWaitHandle.Set();
                            return;
                        }

                        if (args.BytesTransferred == 0)
                        {
                            // remote server closed the connection
                            _completionWaitHandle.Set();
                            return;
                        }

                        _offset += args.BytesTransferred;
                        args.SetBuffer(_buffer, _offset, _bytesToTransfer - TotalBytesTransferred);
                    }
                    else if (IsErrorResumable(args.SocketError))
                    {
                        ThreadAbstraction.Sleep(30);
                    }
                    else
                    {
                        // we're dealing with a (fatal) error
                        _completionWaitHandle.Set();
                        return;
                    }
                }
                while (!ResumeOperation(_socket, args));
            }
        }

        /// <summary>
        /// Signals the completion wait handle after the first successful or non-resumable result of a send or
        /// receive, retrying the operation while it fails with a resumable error.
        /// </summary>
        private sealed class PartialSendReceiveToken : ISendReceiveToken
        {
            private readonly Socket _socket;
            private readonly EventWaitHandle _completionWaitHandle;

            public PartialSendReceiveToken(Socket socket, EventWaitHandle completionWaitHandle)
            {
                _socket = socket;
                _completionWaitHandle = completionWaitHandle;
            }

            public void Process(SocketAsyncEventArgs args)
            {
                while (IsErrorResumable(args.SocketError))
                {
                    ThreadAbstraction.Sleep(30);

                    if (ResumeOperation(_socket, args))
                    {
                        // pending; this method is invoked again through the Completed event
                        return;
                    }
                }

                // success or a (fatal) error
                _completionWaitHandle.Set();
            }
        }
#endif // FEATURE_SOCKET_EAP && !FEATURE_SOCKET_SYNC
    }
}