using System;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet.Abstractions
{
    /// <summary>
    /// Blocking helpers for connecting, reading from and writing to a <see cref="Socket"/>.
    /// </summary>
    /// <remarks>
    /// Each operation is implemented on top of whichever socket API the target platform offers; the
    /// implementation is selected at compile time through the <c>FEATURE_SOCKET_*</c> symbols.
    /// </remarks>
    internal static class SocketAbstraction
    {
        /// <summary>
        /// Determines whether data can be read from the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to check.</param>
        /// <returns>
        /// <c>false</c> when the socket is not connected. When polling is supported: <c>true</c> only if
        /// the socket polls as readable and reports available bytes. Otherwise <c>true</c>.
        /// </returns>
        public static bool CanRead(Socket socket)
        {
            if (socket.Connected)
            {
#if FEATURE_SOCKET_POLL
                // a negative time-out makes Poll wait until the socket changes state
                return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
#else
                return true;
#endif // FEATURE_SOCKET_POLL
            }

            return false;
        }

        /// <summary>
        /// Determines whether data can be written to the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to check.</param>
        /// <returns>
        /// <c>false</c> when the socket is not connected. When polling is supported: the result of
        /// polling the socket for writability. Otherwise <c>true</c>.
        /// </returns>
        public static bool CanWrite(Socket socket)
        {
            if (socket.Connected)
            {
#if FEATURE_SOCKET_POLL
                // a negative time-out makes Poll wait until the socket changes state
                return socket.Poll(-1, SelectMode.SelectWrite);
#else
                return true;
#endif // FEATURE_SOCKET_POLL
            }

            return false;
        }

        /// <summary>
        /// Creates a TCP <see cref="Socket"/> with <see cref="Socket.NoDelay"/> enabled and connects it
        /// to the specified endpoint.
        /// </summary>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The amount of time to wait for the connection to be established.</param>
        /// <returns>
        /// The connected <see cref="Socket"/>.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        /// <exception cref="SocketException">The connection attempt failed.</exception>
        /// <remarks>
        /// When the connection cannot be established, the socket created by this method is disposed
        /// before the exception is propagated to the caller.
        /// </remarks>
        public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
            var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {NoDelay = true};

            try
            {
                ConnectCore(socket, remoteEndpoint, connectTimeout);
                return socket;
            }
            catch
            {
                // the caller never gets hold of the socket when connecting fails, so release it here;
                // the cast keeps this working on targets where Socket.Dispose() is not public
                ((IDisposable) socket).Dispose();
                throw;
            }
        }

        /// <summary>
        /// Connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>, blocking for at most
        /// <paramref name="connectTimeout"/>.
        /// </summary>
        private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
#if FEATURE_SOCKET_EAP
            var connectCompleted = new ManualResetEvent(false);
            var args = new SocketAsyncEventArgs
                {
                    UserToken = connectCompleted,
                    RemoteEndPoint = remoteEndpoint
                };
            args.Completed += ConnectCompleted;

            try
            {
                // ConnectAsync returns false when the operation completed synchronously; in that case
                // the Completed event is not raised and the result can be inspected right away
                if (socket.ConnectAsync(args))
                {
                    if (!connectCompleted.WaitOne(connectTimeout))
                        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                            "Connection failed to establish within {0:F0} milliseconds.", connectTimeout.TotalMilliseconds));
                }

                if (args.SocketError != SocketError.Success)
                    throw new SocketException((int) args.SocketError);
            }
            finally
            {
                // detach the handler and reset the token to avoid the waithandle getting used after
                // it's disposed
                args.Completed -= ConnectCompleted;
                args.UserToken = null;
                args.Dispose();
                connectCompleted.Dispose();
            }
#elif FEATURE_SOCKET_APM
            var connectResult = socket.BeginConnect(remoteEndpoint, null, null);
            if (!connectResult.AsyncWaitHandle.WaitOne(connectTimeout, false))
                throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                    "Connection failed to establish within {0:F0} milliseconds.", connectTimeout.TotalMilliseconds));
            socket.EndConnect(connectResult);
#elif FEATURE_SOCKET_TAP
            if (!socket.ConnectAsync(remoteEndpoint).Wait(connectTimeout))
                throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                    "Connection failed to establish within {0:F0} milliseconds.", connectTimeout.TotalMilliseconds));
#else
            #error Connecting to a remote endpoint is not implemented.
#endif
        }

        /// <summary>
        /// Reads and discards data from the specified <see cref="Socket"/> until a read yields no bytes.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <remarks>
        /// Data is read in chunks of at most 256 bytes using <see cref="ReadPartial"/> with a time-out of
        /// 500 milliseconds. Exceptions thrown by <see cref="ReadPartial"/> are not handled here.
        /// </remarks>
        public static void ClearReadBuffer(Socket socket)
        {
            var timeout = TimeSpan.FromMilliseconds(500);
            var buffer = new byte[256];
            int bytesReceived;

            do
            {
                bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout);
            }
            while (bytesReceived > 0);
        }

        /// <summary>
        /// Performs a single receive on the specified <see cref="Socket"/>, returning as soon as any
        /// data was received.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
        /// <param name="size">The maximum number of bytes to receive.</param>
        /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
        /// <returns>
        /// The number of bytes received.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
        {
#if FEATURE_SOCKET_SYNC
            socket.ReceiveTimeout = (int) timeout.TotalMilliseconds;

            try
            {
                return socket.Receive(buffer, offset, size, SocketFlags.None);
            }
            catch (SocketException ex)
            {
                if (ex.SocketErrorCode == SocketError.TimedOut)
                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                        "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
                throw;
            }
#elif FEATURE_SOCKET_EAP
            var receiveCompleted = new ManualResetEvent(false);
            var sendReceiveToken = new PartialSendReceiveToken(socket, receiveCompleted);
            var args = new SocketAsyncEventArgs
                {
                    RemoteEndPoint = socket.RemoteEndPoint,
                    UserToken = sendReceiveToken
                };
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            try
            {
                if (socket.ReceiveAsync(args))
                {
                    if (!receiveCompleted.WaitOne(timeout))
                        throw new SshOperationTimeoutException(
                            string.Format(
                                CultureInfo.InvariantCulture,
                                "Socket read operation has timed out after {0:F0} milliseconds.",
                                timeout.TotalMilliseconds));
                }

                if (args.SocketError != SocketError.Success)
                    throw new SocketException((int) args.SocketError);

                return args.BytesTransferred;
            }
            finally
            {
                // initialize token to avoid the waithandle getting used after it's disposed
                args.UserToken = null;
                args.Dispose();
                receiveCompleted.Dispose();
            }
#else
            #error Receiving data from a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Keeps receiving data from the specified <see cref="Socket"/> and hands every chunk to
        /// <paramref name="processReceivedBytesAction"/> until the connection is closed.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
        /// <param name="size">The maximum number of bytes to receive in a single read.</param>
        /// <param name="processReceivedBytesAction">
        /// The action to invoke for each chunk; it receives the buffer, the offset of the data in the
        /// buffer and the number of bytes that were read.
        /// </param>
        /// <exception cref="SocketException">The read failed with an error that does not indicate that the connection was closed.</exception>
        public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
        {
#if FEATURE_SOCKET_SYNC
            // do not time-out receive
            socket.ReceiveTimeout = 0;

            while (socket.Connected)
            {
                try
                {
                    var bytesRead = socket.Receive(buffer, offset, size, SocketFlags.None);
                    if (bytesRead == 0)
                        break;

                    processReceivedBytesAction(buffer, offset, bytesRead);
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                        continue;

                    switch (ex.SocketErrorCode)
                    {
                        case SocketError.ConnectionAborted:
                        case SocketError.ConnectionReset:
                            // connection was closed
                            return;
                        case SocketError.Interrupted:
                            // connection was closed because FIN/ACK was not received in time after
                            // shutting down the (send part of the) socket
                            return;
                        default:
                            throw; // throw any other error
                    }
                }
            }
#elif FEATURE_SOCKET_EAP
            var completionWaitHandle = new ManualResetEvent(false);
            var readToken = new ContinuousReceiveToken(socket, processReceivedBytesAction, completionWaitHandle);
            var args = new SocketAsyncEventArgs
                {
                    RemoteEndPoint = socket.RemoteEndPoint,
                    UserToken = readToken
                };
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            try
            {
                if (!socket.ReceiveAsync(args))
                {
                    // completed synchronously: the Completed event is not raised in that case
                    ReceiveCompleted(null, args);
                }

                completionWaitHandle.WaitOne();
            }
            finally
            {
                // initialize token to avoid the waithandle getting used after it's disposed
                args.UserToken = null;
                args.Dispose();
                completionWaitHandle.Dispose();
            }

            if (readToken.Exception != null)
                throw readToken.Exception;
#else
            #error Receiving data from a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Reads a byte from the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
        /// <returns>
        /// The byte read, or <c>-1</c> if the socket was closed.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        public static int ReadByte(Socket socket, TimeSpan timeout)
        {
            var buffer = new byte[1];
            if (Read(socket, buffer, 0, 1, timeout) == 0)
                return -1;

            return buffer[0];
        }

        /// <summary>
        /// Sends a byte using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="value">The value to send.</param>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void SendByte(Socket socket, byte value)
        {
            var buffer = new[] {value};
            Send(socket, buffer, 0, 1);
        }

        /// <summary>
        /// Receives data from a bound <see cref="Socket"/> into a receive buffer, blocking until the
        /// requested number of bytes was received.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
        /// <param name="size">The number of bytes to receive.</param>
        /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
        /// <returns>
        /// The number of bytes received. The synchronous implementation returns <c>0</c> as soon as a
        /// receive yields no bytes; the event-based implementation returns the number of bytes
        /// transferred up to that point.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// If no data is available for reading, the method will block until data is available or the
        /// time-out value was exceeded. If the time-out value was exceeded, the call will throw a
        /// <see cref="SshOperationTimeoutException"/>. Errors for which
        /// <see cref="IsErrorResumable(SocketError)"/> returns <c>true</c> are retried after a short
        /// sleep instead of being reported.
        /// </remarks>
        public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
        {
#if FEATURE_SOCKET_SYNC
            var totalBytesRead = 0;
            var totalBytesToRead = size;

            socket.ReceiveTimeout = (int) timeout.TotalMilliseconds;

            do
            {
                try
                {
                    var bytesRead = socket.Receive(buffer, offset + totalBytesRead, totalBytesToRead - totalBytesRead, SocketFlags.None);
                    if (bytesRead == 0)
                        return 0;

                    totalBytesRead += bytesRead;
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        ThreadAbstraction.Sleep(30);
                        continue;
                    }

                    if (ex.SocketErrorCode == SocketError.TimedOut)
                        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                            "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));

                    throw;
                }
            }
            while (totalBytesRead < totalBytesToRead);

            return totalBytesRead;
#elif FEATURE_SOCKET_EAP
            var receiveCompleted = new ManualResetEvent(false);
            var sendReceiveToken = new BlockingSendReceiveToken(socket, buffer, offset, size, receiveCompleted);

            var args = new SocketAsyncEventArgs
                {
                    UserToken = sendReceiveToken,
                    RemoteEndPoint = socket.RemoteEndPoint
                };
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            try
            {
                if (!socket.ReceiveAsync(args))
                {
                    // completed synchronously: the Completed event is not raised in that case, so
                    // let the token account for the received bytes (and resume reading if needed)
                    ReceiveCompleted(null, args);
                }

                if (!receiveCompleted.WaitOne(timeout))
                    throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                        "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));

                if (args.SocketError != SocketError.Success)
                    throw new SocketException((int) args.SocketError);

                return sendReceiveToken.TotalBytesTransferred;
            }
            finally
            {
                // initialize token to avoid the waithandle getting used after it's disposed
                args.UserToken = null;
                args.Dispose();
                receiveCompleted.Dispose();
            }
#else
            #error Receiving data from a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Sends all bytes of <paramref name="data"/> using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="data">The data to send.</param>
        /// <exception cref="SshConnectionException">No bytes could be sent.</exception>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void Send(Socket socket, byte[] data)
        {
            Send(socket, data, 0, data.Length);
        }

        /// <summary>
        /// Sends <paramref name="size"/> bytes of <paramref name="data"/>, starting at
        /// <paramref name="offset"/>, using the specified <see cref="Socket"/>. The call blocks until
        /// all bytes were sent.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="data">The buffer that holds the data to send.</param>
        /// <param name="offset">The position in <paramref name="data"/> of the first byte to send.</param>
        /// <param name="size">The number of bytes to send.</param>
        /// <exception cref="SshConnectionException">No bytes could be sent.</exception>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void Send(Socket socket, byte[] data, int offset, int size)
        {
#if FEATURE_SOCKET_SYNC
            var totalBytesSent = 0;  // how many bytes are already sent
            var totalBytesToSend = size;

            do
            {
                try
                {
                    var bytesSent = socket.Send(data, offset + totalBytesSent, totalBytesToSend - totalBytesSent, SocketFlags.None);
                    if (bytesSent == 0)
                        throw new SshConnectionException("An established connection was aborted by the server.",
                            DisconnectReason.ConnectionLost);

                    totalBytesSent += bytesSent;
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        // socket buffer is probably full, wait and try again
                        ThreadAbstraction.Sleep(30);
                    }
                    else
                        throw;  // any serious error occurred
                }
            }
            while (totalBytesSent < totalBytesToSend);
#elif FEATURE_SOCKET_EAP
            var sendCompleted = new ManualResetEvent(false);
            var sendReceiveToken = new BlockingSendReceiveToken(socket, data, offset, size, sendCompleted);
            var socketAsyncSendArgs = new SocketAsyncEventArgs
                {
                    RemoteEndPoint = socket.RemoteEndPoint,
                    UserToken = sendReceiveToken
                };
            socketAsyncSendArgs.SetBuffer(data, offset, size);
            socketAsyncSendArgs.Completed += SendCompleted;

            try
            {
                if (!socket.SendAsync(socketAsyncSendArgs))
                {
                    // completed synchronously: the Completed event is not raised in that case, so
                    // let the token account for the sent bytes (and resume sending if needed)
                    SendCompleted(null, socketAsyncSendArgs);
                }

                if (!sendCompleted.WaitOne())
                    throw new SocketException((int) SocketError.TimedOut);

                if (socketAsyncSendArgs.SocketError != SocketError.Success)
                    throw new SocketException((int) socketAsyncSendArgs.SocketError);

                if (sendReceiveToken.TotalBytesTransferred == 0)
                    throw new SshConnectionException("An established connection was aborted by the server.",
                        DisconnectReason.ConnectionLost);
            }
            finally
            {
                // initialize token to avoid the completion waithandle getting used after it's disposed
                socketAsyncSendArgs.UserToken = null;
                socketAsyncSendArgs.Dispose();
                sendCompleted.Dispose();
            }
#else
            #error Sending data to a Socket is not implemented.
#endif
        }

        /// <summary>
        /// Determines whether a socket operation that failed with the specified error can be retried.
        /// </summary>
        /// <param name="socketError">The error to check.</param>
        /// <returns>
        /// <c>true</c> for <see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/> and
        /// <see cref="SocketError.NoBufferSpaceAvailable"/>; otherwise, <c>false</c>.
        /// </returns>
        public static bool IsErrorResumable(SocketError socketError)
        {
            switch (socketError)
            {
                case SocketError.WouldBlock:
                case SocketError.IOPending:
                case SocketError.NoBufferSpaceAvailable:
                    return true;
                default:
                    return false;
            }
        }

#if FEATURE_SOCKET_EAP
        /// <summary>
        /// Signals the wait handle stored in the user token once an asynchronous connect completed.
        /// </summary>
        private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
        {
            var eventWaitHandle = (ManualResetEvent) e.UserToken;
            if (eventWaitHandle != null)
                eventWaitHandle.Set();
        }
#endif // FEATURE_SOCKET_EAP

#if FEATURE_SOCKET_EAP && !FEATURE_SOCKET_SYNC
        /// <summary>
        /// Forwards a completed receive operation to the <see cref="Token"/> stored in the user token.
        /// </summary>
        private static void ReceiveCompleted(object sender, SocketAsyncEventArgs e)
        {
            var sendReceiveToken = (Token) e.UserToken;
            if (sendReceiveToken != null)
                sendReceiveToken.Process(e);
        }

        /// <summary>
        /// Forwards a completed send operation to the <see cref="Token"/> stored in the user token.
        /// </summary>
        private static void SendCompleted(object sender, SocketAsyncEventArgs e)
        {
            var sendReceiveToken = (Token) e.UserToken;
            if (sendReceiveToken != null)
                sendReceiveToken.Process(e);
        }

        /// <summary>
        /// Starts the last operation (receive or send) of <paramref name="args"/> again.
        /// </summary>
        /// <returns>
        /// <c>true</c> when there is nothing left to process right now: either the operation is pending
        /// and the <see cref="SocketAsyncEventArgs.Completed"/> event will be raised, or the last
        /// operation is neither a receive nor a send and nothing was started. <c>false</c> when the
        /// operation completed synchronously, in which case the event is not raised and the caller
        /// must process the result itself.
        /// </returns>
        private static bool ResumeOperation(Socket socket, SocketAsyncEventArgs args)
        {
            switch (args.LastOperation)
            {
                case SocketAsyncOperation.Receive:
                    return socket.ReceiveAsync(args);
                case SocketAsyncOperation.Send:
                    return socket.SendAsync(args);
                default:
                    return true;
            }
        }

        /// <summary>
        /// Handles the completion of an asynchronous socket operation.
        /// </summary>
        private interface Token
        {
            void Process(SocketAsyncEventArgs args);
        }

        /// <summary>
        /// Token that keeps resuming a receive or send until the requested number of bytes was
        /// transferred, the remote side closed the connection, or a fatal error occurred.
        /// </summary>
        private class BlockingSendReceiveToken : Token
        {
            private readonly int _bytesToTransfer;
            private readonly EventWaitHandle _completionWaitHandle;
            private readonly Socket _socket;
            private readonly byte[] _buffer;
            private int _offset;

            public BlockingSendReceiveToken(Socket socket, byte[] buffer, int offset, int size, EventWaitHandle completionWaitHandle)
            {
                _socket = socket;
                _buffer = buffer;
                _offset = offset;
                _bytesToTransfer = size;
                _completionWaitHandle = completionWaitHandle;
            }

            /// <summary>
            /// Gets the total number of bytes transferred so far.
            /// </summary>
            public int TotalBytesTransferred { get; private set; }

            public void Process(SocketAsyncEventArgs args)
            {
                // loop (rather than recurse) for as long as the resumed operation completes synchronously
                do
                {
                    if (args.SocketError == SocketError.Success)
                    {
                        TotalBytesTransferred += args.BytesTransferred;

                        if (TotalBytesTransferred == _bytesToTransfer)
                        {
                            // finished transferring specified bytes
                            _completionWaitHandle.Set();
                            return;
                        }

                        if (args.BytesTransferred == 0)
                        {
                            // remote server closed the connection
                            _completionWaitHandle.Set();
                            return;
                        }

                        // continue with the part of the buffer that was not transferred yet
                        _offset += args.BytesTransferred;
                        args.SetBuffer(_buffer, _offset, _bytesToTransfer - TotalBytesTransferred);
                    }
                    else if (IsErrorResumable(args.SocketError))
                    {
                        ThreadAbstraction.Sleep(30);
                    }
                    else
                    {
                        // we're dealing with a (fatal) error
                        _completionWaitHandle.Set();
                        return;
                    }
                }
                while (!ResumeOperation(_socket, args));
            }
        }

        /// <summary>
        /// Token that completes as soon as a single receive or send succeeded or failed with a fatal
        /// error; resumable errors are retried.
        /// </summary>
        private class PartialSendReceiveToken : Token
        {
            private readonly EventWaitHandle _completionWaitHandle;
            private readonly Socket _socket;

            public PartialSendReceiveToken(Socket socket, EventWaitHandle completionWaitHandle)
            {
                _socket = socket;
                _completionWaitHandle = completionWaitHandle;
            }

            public void Process(SocketAsyncEventArgs args)
            {
                // loop (rather than recurse) for as long as the resumed operation completes synchronously
                do
                {
                    if (args.SocketError == SocketError.Success)
                    {
                        _completionWaitHandle.Set();
                        return;
                    }

                    if (!IsErrorResumable(args.SocketError))
                    {
                        // we're dealing with a (fatal) error
                        _completionWaitHandle.Set();
                        return;
                    }

                    ThreadAbstraction.Sleep(30);
                }
                while (!ResumeOperation(_socket, args));
            }
        }

        /// <summary>
        /// Token that hands every received chunk to a callback and keeps receiving until the remote
        /// socket was closed or a fatal error occurred.
        /// </summary>
        private class ContinuousReceiveToken : Token
        {
            private readonly EventWaitHandle _completionWaitHandle;
            private readonly Socket _socket;
            private readonly Action<byte[], int, int> _processReceivedBytesAction;

            public ContinuousReceiveToken(Socket socket, Action<byte[], int, int> processReceivedBytesAction, EventWaitHandle completionWaitHandle)
            {
                _socket = socket;
                _processReceivedBytesAction = processReceivedBytesAction;
                _completionWaitHandle = completionWaitHandle;
            }

            /// <summary>
            /// Gets the error that ended the receive loop, or <c>null</c> when it ended because the
            /// remote socket was closed or the operation was aborted.
            /// </summary>
            public Exception Exception { get; private set; }

            public void Process(SocketAsyncEventArgs args)
            {
                // loop (rather than recurse) for as long as the resumed operation completes synchronously
                do
                {
                    if (args.SocketError == SocketError.Success)
                    {
                        if (args.BytesTransferred == 0)
                        {
                            // remote socket was closed
                            _completionWaitHandle.Set();
                            return;
                        }

                        _processReceivedBytesAction(args.Buffer, args.Offset, args.BytesTransferred);
                    }
                    else if (IsErrorResumable(args.SocketError))
                    {
                        ThreadAbstraction.Sleep(30);
                    }
                    else
                    {
                        // we're dealing with a (fatal) error; an aborted operation is not reported
                        if (args.SocketError != SocketError.OperationAborted)
                        {
                            Exception = new SocketException((int) args.SocketError);
                        }

                        _completionWaitHandle.Set();
                        return;
                    }
                }
                while (!ResumeOperation(_socket, args));
            }
        }
#endif // FEATURE_SOCKET_EAP && !FEATURE_SOCKET_SYNC
    }
}