ConnectorBase.cs 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. using System;
  2. using System.Net;
  3. using System.Net.Sockets;
  4. using System.Threading;
  5. using System.Threading.Tasks;
  6. using Renci.SshNet.Abstractions;
  7. using Renci.SshNet.Common;
  8. using Renci.SshNet.Messages.Transport;
  namespace Renci.SshNet.Connection
  {
      /// <summary>
      /// Base class for <see cref="IConnector"/> implementations. Provides shared helpers for
      /// opening TCP sockets and for performing blocking reads on them.
      /// </summary>
      internal abstract class ConnectorBase : IConnector
      {
          /// <summary>
          /// Size, in bytes, applied to both the send and the receive buffer of every socket opened by
          /// <see cref="SocketConnect"/> and <see cref="SocketConnectAsync"/>.
          /// </summary>
          /// <remarks>
          /// Defined once so that the synchronous and asynchronous connect paths configure sockets identically.
          /// </remarks>
          private const int SocketBufferSize = 10 * Session.MaximumSshPacketSize;

          /// <summary>
          /// Initializes a new instance of the <see cref="ConnectorBase"/> class.
          /// </summary>
          /// <param name="socketFactory">The factory used to create sockets.</param>
          /// <exception cref="ArgumentNullException"><paramref name="socketFactory"/> is <see langword="null"/>.</exception>
          protected ConnectorBase(ISocketFactory socketFactory)
          {
              if (socketFactory is null)
              {
                  throw new ArgumentNullException(nameof(socketFactory));
              }

              SocketFactory = socketFactory;
          }

          /// <summary>
          /// Gets the factory used to create sockets.
          /// </summary>
          internal ISocketFactory SocketFactory { get; }

          /// <summary>
          /// Implemented by derived connectors to establish a socket connection based on <paramref name="connectionInfo"/>.
          /// </summary>
          /// <param name="connectionInfo">The connection information to use.</param>
          /// <returns>The resulting <see cref="Socket"/>.</returns>
          public abstract Socket Connect(IConnectionInfo connectionInfo);

          /// <summary>
          /// Implemented by derived connectors to asynchronously establish a socket connection based on
          /// <paramref name="connectionInfo"/>.
          /// </summary>
          /// <param name="connectionInfo">The connection information to use.</param>
          /// <param name="cancellationToken">The cancellation token to observe.</param>
          /// <returns>A task whose result is the resulting <see cref="Socket"/>.</returns>
          public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);

          /// <summary>
          /// Establishes a socket connection to the specified endpoint.
          /// </summary>
          /// <param name="endPoint">The <see cref="EndPoint"/> representing the server to connect to.</param>
          /// <param name="timeout">The maximum time to wait for the connection to be established.</param>
          /// <returns>
          /// The connected TCP socket, with its send and receive buffers set to <see cref="SocketBufferSize"/>.
          /// </returns>
          /// <exception cref="SshOperationTimeoutException">The connection failed to establish within <paramref name="timeout"/>.</exception>
          /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
          protected Socket SocketConnect(EndPoint endPoint, TimeSpan timeout)
          {
              DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}'.", endPoint));

              var socket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);

              try
              {
                  SocketAbstraction.Connect(socket, endPoint, timeout);
                  ConfigureSocketBuffers(socket);
                  return socket;
              }
              catch (Exception)
              {
                  // Do not leak the socket if connecting or configuring it fails.
                  socket.Dispose();
                  throw;
              }
          }

          /// <summary>
          /// Establishes a socket connection to the specified endpoint.
          /// </summary>
          /// <param name="endPoint">The <see cref="EndPoint"/> representing the server to connect to.</param>
          /// <param name="cancellationToken">The cancellation token to observe.</param>
          /// <returns>
          /// A task whose result is the connected TCP socket, with its send and receive buffers set to
          /// <see cref="SocketBufferSize"/> (the same configuration as <see cref="SocketConnect"/>).
          /// </returns>
          /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
          /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
          /// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
          protected async Task<Socket> SocketConnectAsync(EndPoint endPoint, CancellationToken cancellationToken)
          {
              // Fail fast, before a socket is allocated.
              cancellationToken.ThrowIfCancellationRequested();

              DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}'.", endPoint));

              var socket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);

              try
              {
                  await SocketAbstraction.ConnectAsync(socket, endPoint, cancellationToken).ConfigureAwait(false);
                  ConfigureSocketBuffers(socket);
                  return socket;
              }
              catch (Exception)
              {
                  // Do not leak the socket if connecting or configuring it fails.
                  socket.Dispose();
                  throw;
              }
          }

          /// <summary>
          /// Applies <see cref="SocketBufferSize"/> to the send and receive buffers of <paramref name="socket"/>.
          /// </summary>
          /// <param name="socket">The socket to configure.</param>
          private static void ConfigureSocketBuffers(Socket socket)
          {
              socket.SendBufferSize = SocketBufferSize;
              socket.ReceiveBufferSize = SocketBufferSize;
          }

          /// <summary>
          /// Performs a blocking read of a single byte from the socket, without a read timeout.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <returns>The byte that was read.</returns>
          /// <exception cref="SshConnectionException">The read returned zero bytes, which is treated as the connection having been closed.</exception>
          /// <exception cref="SocketException">The read failed.</exception>
          protected static byte SocketReadByte(Socket socket)
          {
              return SocketReadByte(socket, Timeout.InfiniteTimeSpan);
          }

          /// <summary>
          /// Performs a blocking read of a single byte from the socket.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="readTimeout">The maximum time to wait for the byte to be received.</param>
          /// <returns>The byte that was read.</returns>
          /// <exception cref="SshConnectionException">The read returned zero bytes, which is treated as the connection having been closed.</exception>
          /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
          /// <exception cref="SocketException">The read failed.</exception>
          protected static byte SocketReadByte(Socket socket, TimeSpan readTimeout)
          {
              var buffer = new byte[1];
              _ = SocketRead(socket, buffer, 0, 1, readTimeout);
              return buffer[0];
          }

          /// <summary>
          /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
          /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
          /// <param name="length">The number of bytes to read.</param>
          /// <returns>
          /// The number of bytes read.
          /// </returns>
          /// <exception cref="SshConnectionException">The read returned zero bytes, which is treated as the connection having been closed.</exception>
          /// <exception cref="SocketException">The read failed.</exception>
          protected static int SocketRead(Socket socket, byte[] buffer, int offset, int length)
          {
              return SocketRead(socket, buffer, offset, length, Timeout.InfiniteTimeSpan);
          }

          /// <summary>
          /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
          /// </summary>
          /// <param name="socket">The <see cref="Socket"/> to read from.</param>
          /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
          /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
          /// <param name="length">The number of bytes to read.</param>
          /// <param name="readTimeout">The maximum time to wait until <paramref name="length"/> bytes have been received.</param>
          /// <returns>
          /// The number of bytes read; never zero.
          /// </returns>
          /// <exception cref="SshConnectionException">The read returned zero bytes, which is treated as the connection having been closed.</exception>
          /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
          /// <exception cref="SocketException">The read failed.</exception>
          protected static int SocketRead(Socket socket, byte[] buffer, int offset, int length, TimeSpan readTimeout)
          {
              var bytesRead = SocketAbstraction.Read(socket, buffer, offset, length, readTimeout);

              // A zero-byte result is reported as a lost connection rather than returned to the caller.
              if (bytesRead == 0)
              {
                  throw new SshConnectionException("An established connection was aborted by the server.",
                      DisconnectReason.ConnectionLost);
              }

              return bytesRead;
          }
      }
  }