using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet
{
    public partial class Session
    {
        // Byte values that SocketReadLine treats as line terminators or trailing characters.
        private const byte Null = 0x00;
        private const byte CarriageReturn = 0x0d;
        private const byte LineFeed = 0x0a;

#if FEATURE_DIAGNOSTICS_TRACESOURCE
        // Trace source written to by Log; DEBUG builds enable every source level.
        private readonly TraceSource _log =
#if DEBUG
            new TraceSource("SshNet.Logging", SourceLevels.All);
#else
            new TraceSource("SshNet.Logging");
#endif // DEBUG
#endif // FEATURE_DIAGNOSTICS_TRACESOURCE

        /// <summary>
        /// Holds the lock object to ensure read access to the socket is synchronized.
        /// </summary>
        private readonly object _socketReadLock = new object();

        /// <summary>
        /// Gets a value indicating whether the socket is connected.
        /// </summary>
        /// <param name="isConnected">
        /// Receives <c>true</c> if the socket is connected; otherwise, <c>false</c>.
        /// </param>
        /// <remarks>
        /// <para>
        /// As a first check we verify whether <see cref="Socket.Connected"/> is
        /// <c>true</c>. However, this only returns the state of the socket as of
        /// the last I/O operation. Therefore we use the combination of Socket.Poll
        /// with mode SelectRead and Socket.Available to verify if the socket is
        /// still connected.
        /// </para>
        /// <para>
        /// The MSDN documentation mentions the following on the return value of
        /// <see cref="Socket.Poll(int, SelectMode)"/> with mode <see cref="SelectMode.SelectRead"/>:
        /// <list type="bullet">
        ///     <item>
        ///         <description><c>true</c> if data is available for reading;</description>
        ///     </item>
        ///     <item>
        ///         <description><c>true</c> if the connection has been closed, reset, or terminated; otherwise, returns <c>false</c>.</description>
        ///     </item>
        /// </list>
        /// </para>
        /// <para>
        /// Conclusion: when the return value is <c>true</c> - but no data is available for reading - then
        /// the socket is no longer connected.
        /// </para>
        /// <para>
        /// When a <see cref="Socket"/> is used from multiple threads, there's a race condition
        /// between the invocation of <see cref="Socket.Poll(int, SelectMode)"/> and the moment
        /// when the value of <see cref="Socket.Available"/> is obtained. As a workaround, we signal
        /// when bytes are read from the <see cref="Socket"/>.
        /// </para>
        /// </remarks>
partial void IsSocketConnected(ref bool isConnected)
{
    // Socket.Connected only reflects the state as of the last I/O operation, so a
    // false value settles the question but a true value still has to be verified
    isConnected = _socket != null && _socket.Connected;
    if (!isConnected)
    {
        return;
    }

    // hold the lock so that another thread cannot reset the wait handle before this
    // thread has checked whether the "bytes read from socket" signal was received
    lock (_socketReadLock)
    {
        _bytesReadFromSocket.Reset();

        // SelectRead yields true when data can be read, but also when the connection
        // has been closed, reset or terminated (wait time is given in microseconds)
        var closedOrReadable = _socket.Poll(1000, SelectMode.SelectRead);
        if (closedOrReadable && _socket.Available == 0)
        {
            // the poll succeeded yet nothing is available: the socket looks closed.
            // However, a concurrent reader may have consumed the data between the
            // Socket.Poll call and Socket.Available, and its "bytes read" signal is
            // sometimes raised shortly after - so wait briefly for that signal
            isConnected = _bytesReadFromSocket.WaitOne(500);
        }
    }
}

/// <summary>
/// Establishes a socket connection to the specified host and port.
/// </summary>
/// <param name="host">The host name of the server to connect to.</param>
/// <param name="port">The port to connect to.</param>
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
partial void SocketConnect(string host, int port)
{
    // size both socket buffers to hold two maximum-size SSH packets
    const int socketBufferSize = 2 * MaximumSshPacketSize;

    var ipAddress = host.GetIPAddress();
    var timeout = ConnectionInfo.Timeout;
    var ep = new IPEndPoint(ipAddress, port);

    _socket = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
    _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
    _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.SendBuffer, socketBufferSize);
    _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveBuffer, socketBufferSize);

    // log the endpoint that is actually being connected to (the method arguments),
    // formatted culture-independently like the other messages in this method
    Log(string.Format(CultureInfo.InvariantCulture, "Initiating connect to '{0}:{1}'.", host, port));

#if FEATURE_SOCKET_EAP
    // NOTE(review): Task.Wait reports a failed connect as an AggregateException that
    // wraps the SocketException - confirm callers of this build flavor expect that.
    if (!_socket.ConnectAsync(ep).Wait(timeout))
    {
        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
            "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
    }
#else
    var connectResult = _socket.BeginConnect(ep, null, null);
    if (!connectResult.AsyncWaitHandle.WaitOne(timeout, false))
    {
        throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
            "Connection failed to establish within {0:F0} milliseconds.", timeout.TotalMilliseconds));
    }

    // completes the connect and rethrows any error that occurred while connecting
    _socket.EndConnect(connectResult);
#endif // FEATURE_SOCKET_EAP
}

/// <summary>
/// Disposes the socket, which closes the current connection.
/// </summary>
/// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
partial void SocketDisconnect()
{
    _socket.Dispose();
}

/// <summary>
/// Performs a blocking read on the socket until a line is read.
/// </summary>
/// <param name="response">The line read from the socket, or <c>null</c> when the remote server has shutdown and all data has been received.</param>
/// <param name="timeout">A <see cref="TimeSpan"/> that represents the time to wait until a line is read.</param>
/// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
/// <exception cref="SocketException">An error occurred when trying to access the socket.</exception>
partial void SocketReadLine(ref string response, TimeSpan timeout)
{
    var buffer = new List<byte>();
    var data = new byte[1];

    // read data one byte at a time to find end of line and leave any unhandled information
    // in the socket to be processed by subsequent invocations; note that the timeout
    // applies to each single-byte receive, not to the line as a whole
    do
    {
#if FEATURE_SOCKET_TAP
        var receiveTask = _socket.ReceiveAsync(new ArraySegment<byte>(data, 0, data.Length), SocketFlags.None);
        if (!receiveTask.Wait(timeout))
        {
            throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
        }

        var received = receiveTask.Result;
#else
        var asyncResult = _socket.BeginReceive(data, 0, data.Length, SocketFlags.None, null, null);
        if (!asyncResult.AsyncWaitHandle.WaitOne(timeout))
        {
            throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                "Socket read operation has timed out after {0:F0} milliseconds.", timeout.TotalMilliseconds));
        }

        var received = _socket.EndReceive(asyncResult);
#endif // FEATURE_SOCKET_TAP

        if (received == 0)
        {
            // the remote server shut down the socket
            break;
        }

        buffer.Add(data[0]);
    }
    while (!(buffer.Count > 0 && (buffer[buffer.Count - 1] == LineFeed || buffer[buffer.Count - 1] == Null)));

    var count = buffer.Count;
    if (count == 0)
    {
        // nothing was received before the remote server shut down the socket
        response = null;
        return;
    }

    var last = buffer[count - 1];
    if (count == 1 && last == Null)
    {
        // return an empty version string if the buffer consists of only a 0x00 character
        response = string.Empty;
        return;
    }

    // Only strip trailing bytes when the line was really terminated. When the loop ended
    // because the server shut down the socket, the last byte is ordinary data and a CR
    // in front of it must not cause both bytes to be dropped.
    var trailing = 0;
    var terminated = last == LineFeed || last == Null;
    if (terminated && count > 1)
    {
        if (buffer[count - 2] == CarriageReturn)
        {
            // strip trailing CR together with the terminator that follows it
            trailing = 2;
        }
        else if (last == LineFeed)
        {
            // strip trailing LF
            trailing = 1;
        }
    }

    response = SshData.Ascii.GetString(buffer.Take(count - trailing).ToArray());
}

/// <summary>
/// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
/// </summary>
/// <param name="length">The number of bytes to read.</param>
/// <param name="buffer">The buffer to read to.</param>
/// <exception cref="SshConnectionException">The socket is closed.</exception>
/// <exception cref="SshConnectionException">The read failed; the causing <see cref="SocketException"/> is attached as inner exception.</exception>
partial void SocketRead(int length, ref byte[] buffer)
{
    if (length == 0)
    {
        // nothing was requested; issuing a zero-byte receive would yield 0 bytes,
        // which the loop below would misreport as a lost connection
        return;
    }

    var receivedTotal = 0; // how many bytes have been received so far

    do
    {
        int receivedBytes;

        try
        {
            receivedBytes = _socket.Receive(buffer, receivedTotal, length - receivedTotal, SocketFlags.None);
        }
        catch (SocketException exp)
        {
            var isTransient = exp.SocketErrorCode == SocketError.WouldBlock ||
                              exp.SocketErrorCode == SocketError.IOPending ||
                              exp.SocketErrorCode == SocketError.NoBufferSpaceAvailable;
            if (!isTransient)
            {
                throw new SshConnectionException(exp.Message, DisconnectReason.ConnectionLost, exp);
            }

            // socket buffer is probably empty, wait and try again
            ThreadAbstraction.Sleep(30);
            continue;
        }

        if (receivedBytes <= 0)
        {
            // 2012-09-11: Kenneth_aa
            // When Disconnect or Dispose is called, this throws SshConnectionException(), which...
            // 1 - goes up to ReceiveMessage()
            // 2 - up again to MessageListener()
            // which is where there is a catch-all exception block so it can notify event listeners.
            // 3 - MessageListener then again calls RaiseError().
            // There the exception is checked for the exception thrown here (ConnectionLost), and if it matches it will not call Session.SendDisconnect().
            //
            // Adding a check for _isDisconnecting causes ReceiveMessage() to throw SshConnectionException: "Bad packet length {0}".
            //
            if (_isDisconnecting)
            {
                throw new SshConnectionException("An established connection was aborted by the software in your host machine.", DisconnectReason.ConnectionLost);
            }

            throw new SshConnectionException("An established connection was aborted by the server.", DisconnectReason.ConnectionLost);
        }

        // signal that bytes have been read from the socket
        // this is used to improve accuracy of Session.IsSocketConnected
        _bytesReadFromSocket.Set();
        receivedTotal += receivedBytes;
    }
    while (receivedTotal < length);
}

/// <summary>
/// Writes the specified data to the server.
/// </summary>
/// <param name="data">The data to write to the server.</param>
/// <param name="offset">The zero-based offset in <paramref name="data"/> at which to begin taking data from.</param>
/// <param name="length">The number of bytes of <paramref name="data"/> to write.</param>
/// <exception cref="SshOperationTimeoutException">The write has timed-out.</exception>
/// <exception cref="SocketException">The write failed.</exception>
private void SocketWrite(byte[] data, int offset, int length)
{
    // NOTE(review): this body does not enforce a timeout itself and never throws
    // SshOperationTimeoutException - confirm where the documented timeout comes from.
    var totalBytesSent = 0; // how many bytes are already sent

    do
    {
        try
        {
            totalBytesSent += _socket.Send(data, offset + totalBytesSent, length - totalBytesSent, SocketFlags.None);
        }
        catch (SocketException ex)
        {
            var isTransient = ex.SocketErrorCode == SocketError.WouldBlock ||
                              ex.SocketErrorCode == SocketError.IOPending ||
                              ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable;
            if (!isTransient)
            {
                // any serious error occurred
                throw;
            }

            // socket buffer is probably full, wait and try again
            ThreadAbstraction.Sleep(30);
        }
    }
    while (totalBytesSent < length);
}

/// <summary>
/// Writes a verbose trace event when tracing support is compiled in.
/// </summary>
/// <param name="text">The message to log.</param>
[Conditional("DEBUG")]
partial void Log(string text)
{
#if FEATURE_DIAGNOSTICS_TRACESOURCE
    _log.TraceEvent(TraceEventType.Verbose, 1, text);
#endif // FEATURE_DIAGNOSTICS_TRACESOURCE
}

#if ASYNC_SOCKET_READ
// NOTE(review): this file also contains a "partial void SocketRead(int, ref byte[])"
// implementation; defining ASYNC_SOCKET_READ as-is would declare SocketRead twice.

/// <summary>
/// Reads <paramref name="length"/> bytes from the socket using asynchronous receives,
/// blocking the caller until the read completed, failed or the connection was lost.
/// </summary>
/// <param name="length">The number of bytes to read.</param>
/// <param name="buffer">The buffer to read to.</param>
/// <exception cref="SshConnectionException">The connection was lost.</exception>
private void SocketRead(int length, ref byte[] buffer)
{
    var state = new SocketReadState(_socket, length, ref buffer);

    _socket.BeginReceive(buffer, 0, length, SocketFlags.None, SocketReceiveCallback, state);

    var readResult = state.Wait();
    switch (readResult)
    {
        case SocketReadResult.Complete:
            break;
        case SocketReadResult.ConnectionLost:
            if (_isDisconnecting)
            {
                throw new SshConnectionException(
                    "An established connection was aborted by the software in your host machine.",
                    DisconnectReason.ConnectionLost);
            }

            throw new SshConnectionException("An established connection was aborted by the server.",
                DisconnectReason.ConnectionLost);
        case SocketReadResult.Failed:
            var socketException = state.Exception as SocketException;
            if (socketException != null && socketException.SocketErrorCode == SocketError.ConnectionAborted)
            {
                // on ConnectionAborted: hand back a fresh zero-filled buffer, disconnect
                // the session and return without throwing
                buffer = new byte[length];
                Disconnect();
                return;
            }

            // NOTE(review): rethrowing the captured exception resets its stack trace
            throw state.Exception;
    }
}

/// <summary>
/// Completion callback for the asynchronous receives started by <c>SocketRead</c>; keeps
/// receiving until all requested bytes arrived and records the outcome on the read state.
/// </summary>
/// <param name="ar">The result of the asynchronous receive; its state object is the <see cref="SocketReadState"/>.</param>
private void SocketReceiveCallback(IAsyncResult ar)
{
    // a direct cast fails fast with a meaningful exception should the state object ever
    // be of another type, instead of a NullReferenceException on first use
    var state = (SocketReadState) ar.AsyncState;
    var socket = state.Socket;

    try
    {
        var bytesReceived = socket.EndReceive(ar);
        if (bytesReceived > 0)
        {
            // signal that bytes have been read from the socket
            _bytesReadFromSocket.Set();

            state.BytesRead += bytesReceived;
            if (state.BytesRead < state.TotalBytesToRead)
            {
                socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
                    SocketFlags.None, SocketReceiveCallback, state);
            }
            else
            {
                // we received all bytes that we wanted, so lets mark the read
                // complete
                state.Complete();
            }
        }
        else
        {
            // the remote host shut down the connection; this could also have been
            // triggered by a SSH_MSG_DISCONNECT sent by the client
            state.ConnectionLost();
        }
    }
    catch (SocketException ex)
    {
        var isTransient = ex.SocketErrorCode == SocketError.WouldBlock ||
                          ex.SocketErrorCode == SocketError.IOPending ||
                          ex.SocketErrorCode == SocketError.NoBufferSpaceAvailable;
        if (!isTransient)
        {
            state.Fail(ex);
            return;
        }

        // socket buffer is probably empty, wait and try again
        ThreadAbstraction.Sleep(30);

        try
        {
            socket.BeginReceive(state.Buffer, state.BytesRead, state.TotalBytesToRead - state.BytesRead,
                SocketFlags.None, SocketReceiveCallback, state);
        }
        catch (Exception retryFailure)
        {
            // an exception escaping this callback would never reach the thread blocked
            // in SocketReadState.Wait, so report it through the state instead
            state.Fail(retryFailure);
        }
    }
    catch (Exception ex)
    {
        state.Fail(ex);
    }
}

/// <summary>
/// Tracks the progress and outcome of a single asynchronous socket read.
/// </summary>
private class SocketReadState
{
    private SocketReadResult _result;

    /// <summary>
    /// WaitHandle to signal that read from socket has completed (either successfully
    /// or with failure)
    /// </summary>
    private System.Threading.EventWaitHandle _socketReadComplete;

    /// <summary>
    /// Initializes a new read state with an unsignaled completion handle.
    /// </summary>
    /// <param name="socket">The socket to read from.</param>
    /// <param name="totalBytesToRead">The total number of bytes to read.</param>
    /// <param name="buffer">The buffer to hold the bytes that have been read.</param>
    public SocketReadState(Socket socket, int totalBytesToRead, ref byte[] buffer)
    {
        Socket = socket;
        TotalBytesToRead = totalBytesToRead;
        Buffer = buffer;
        _socketReadComplete = new System.Threading.ManualResetEvent(false);
    }

    /// <summary>
    /// Gets the <see cref="Socket"/> to read from.
    /// </summary>
    /// <value>
    /// The <see cref="Socket"/> to read from.
    /// </value>
    public Socket Socket { get; private set; }

    /// <summary>
    /// Gets or sets the number of bytes that have been read from the <see cref="Socket"/>.
    /// </summary>
    /// <value>
    /// The number of bytes that have been read from the <see cref="Socket"/>.
    /// </value>
    public int BytesRead { get; set; }

    /// <summary>
    /// Gets the total number of bytes to read from the <see cref="Socket"/>.
    /// </summary>
    /// <value>
    /// The total number of bytes to read from the <see cref="Socket"/>.
    /// </value>
    public int TotalBytesToRead { get; private set; }

    /// <summary>
    /// Gets the buffer to hold the bytes that have been read.
    /// </summary>
    /// <value>
    /// The buffer to hold the bytes that have been read.
    /// </value>
    public byte[] Buffer { get; private set; }

    /// <summary>
    /// Gets the exception that was thrown while reading from the
    /// <see cref="Socket"/>.
    /// </summary>
    /// <value>
    /// The exception that was thrown while reading from the <see cref="Socket"/>,
    /// or <c>null</c> if no exception was thrown.
    /// </value>
    public Exception Exception { get; private set; }

    /// <summary>
    /// Signals that the total number of bytes has been read successfully.
    /// </summary>
    public void Complete()
    {
        _result = SocketReadResult.Complete;
        _socketReadComplete.Set();
    }

    /// <summary>
    /// Signals that the socket read failed.
    /// </summary>
    /// <param name="cause">The <see cref="Exception"/> that caused the read to fail.</param>
    public void Fail(Exception cause)
    {
        Exception = cause;
        _result = SocketReadResult.Failed;
        _socketReadComplete.Set();
    }

    /// <summary>
    /// Signals that the connection to the server was lost.
    /// </summary>
    public void ConnectionLost()
    {
        _result = SocketReadResult.ConnectionLost;
        _socketReadComplete.Set();
    }

    /// <summary>
    /// Blocks until the read was signaled as completed, failed or lost, then
    /// disposes the wait handle.
    /// </summary>
    /// <returns>The outcome of the read.</returns>
    public SocketReadResult Wait()
    {
        _socketReadComplete.WaitOne();
        _socketReadComplete.Dispose();
        _socketReadComplete = null;
        return _result;
    }
}

/// <summary>
/// Possible outcomes of an asynchronous socket read.
/// </summary>
private enum SocketReadResult
{
    Complete,
    ConnectionLost,
    Failed
}
#endif
    }
}