SocketExtensions.cs 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #if !NET6_0_OR_GREATER
  2. using System;
  3. using System.Net;
  4. using System.Net.Sockets;
  5. using System.Runtime.CompilerServices;
  6. using System.Threading;
  7. using System.Threading.Tasks;
  8. namespace Renci.SshNet.Abstractions
  9. {
  // Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/
  /// <summary>
  /// Cancellable, awaitable wrappers over the <see cref="SocketAsyncEventArgs"/>-based <see cref="Socket"/> APIs.
  /// </summary>
  /// <remarks>
  /// Cancellation only completes the returned task as cancelled; it does not abort the underlying socket
  /// operation. NOTE(review): a still-pending operation may presumably complete later (for a receive, into the
  /// caller's buffer), so callers likely need to close the socket after cancelling — confirm at call sites.
  /// </remarks>
  internal static class SocketExtensions
  {
      /// <summary>
      /// Single-use <see cref="SocketAsyncEventArgs"/> that is also its own awaiter, so a socket operation
      /// started through <see cref="ExecuteAsync"/> can be awaited directly.
      /// </summary>
      private sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, INotifyCompletion
      {
          /// <summary>
          /// Marker stored in <see cref="_continuationAction"/> once <see cref="SetCompleted"/> has run.
          /// A continuation registered after that point must be scheduled by <see cref="INotifyCompletion.OnCompleted"/> itself.
          /// </summary>
          private static readonly Action CompletedSentinel = () => { };

          // Written by the cancellation callback (any thread) and read by ExecuteAsync/GetResult on other threads.
          private volatile bool _isCancelled;

          // States: null = not completed and nobody waiting; CompletedSentinel = completed;
          // any other delegate = the awaiting continuation, registered before completion.
          private Action _continuationAction;

          public AwaitableSocketAsyncEventArgs()
          {
              // Asynchronous (non-synchronous) completions of the socket operation arrive through this event.
              Completed += (sender, e) => SetCompleted();
          }

          /// <summary>
          /// Gets a value indicating whether the operation has completed or been cancelled (awaiter pattern).
          /// </summary>
          public bool IsCompleted { get; private set; }

          /// <summary>
          /// Starts the socket operation <paramref name="func"/> on this instance.
          /// </summary>
          /// <param name="func">
          /// A <c>Socket.*Async(SocketAsyncEventArgs)</c> method. It returns <see langword="true"/> when the operation
          /// is pending and <see langword="false"/> when it completed synchronously.
          /// </param>
          /// <returns>This instance, which acts as its own awaiter.</returns>
          public AwaitableSocketAsyncEventArgs ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
          {
              // CancellationToken.Register runs the callback synchronously when the token is already signaled,
              // so cancellation may have been observed before we get here. SetCancelled has then completed
              // this instance (or is about to, on another thread). Starting a socket operation now would leave it
              // running against args the caller is about to dispose, with nobody observing its result.
              if (_isCancelled)
              {
                  return this;
              }

              // false => completed synchronously; the Completed event is not raised in that case.
              if (!func(this))
              {
                  SetCompleted();
              }

              return this;
          }

          /// <summary>
          /// Marks the operation as completed and invokes the registered continuation, if any, inline on the
          /// calling thread.
          /// </summary>
          /// <remarks>
          /// This can run more than once, e.g. once from cancellation and again when the socket operation really
          /// completes. Every call after the first swaps out <see cref="CompletedSentinel"/> and just invokes that no-op.
          /// </remarks>
          private void SetCompleted()
          {
              IsCompleted = true;

              var continuation = Interlocked.Exchange(ref _continuationAction, CompletedSentinel);
              if (continuation is not null)
              {
                  continuation();
              }
          }

          /// <summary>
          /// Completes this instance as cancelled; <see cref="GetResult"/> will then throw
          /// <see cref="TaskCanceledException"/>. The underlying socket operation is not aborted.
          /// </summary>
          public void SetCancelled()
          {
              _isCancelled = true;
              SetCompleted();
          }

          /// <summary>
          /// Returns this instance as the awaiter (awaiter pattern).
          /// </summary>
          /// <returns>This instance.</returns>
          public AwaitableSocketAsyncEventArgs GetAwaiter()
          {
              return this;
          }

          /// <summary>
          /// Registers the continuation to invoke when the operation completes (awaiter pattern).
          /// </summary>
          /// <param name="continuation">The continuation supplied by the compiler-generated state machine.</param>
          void INotifyCompletion.OnCompleted(Action continuation)
          {
              // If SetCompleted already ran, it will never invoke this continuation.
              // In that case run it asynchronously on the thread pool.
              if (_continuationAction == CompletedSentinel ||
                  Interlocked.CompareExchange(ref _continuationAction, continuation, comparand: null) == CompletedSentinel)
              {
                  _ = Task.Run(continuation);
              }
          }

          /// <summary>
          /// Ends the await, surfacing cancellation or socket failure as an exception.
          /// </summary>
          /// <exception cref="TaskCanceledException"><see cref="SetCancelled"/> was called.</exception>
          /// <exception cref="InvalidOperationException">Called before the operation completed.</exception>
          /// <exception cref="SocketException">The socket operation reported an error.</exception>
          public void GetResult()
          {
              if (_isCancelled)
              {
                  throw new TaskCanceledException();
              }

              if (!IsCompleted)
              {
                  // We don't support sync/async
                  throw new InvalidOperationException("The asynchronous operation has not yet completed.");
              }

              if (SocketError != SocketError.Success)
              {
                  throw new SocketException((int)SocketError);
              }
          }
      }

      /// <summary>
      /// Connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>.
      /// </summary>
      /// <param name="socket">The socket to connect.</param>
      /// <param name="remoteEndpoint">The endpoint to connect to.</param>
      /// <param name="cancellationToken">
      /// Completes the returned task as cancelled when signaled; it does not abort the pending connect itself.
      /// </param>
      /// <returns>A task that completes when the connect finishes.</returns>
      /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is or becomes cancelled.</exception>
      /// <exception cref="SocketException">The connect attempt failed.</exception>
      public static async Task ConnectAsync(this Socket socket, EndPoint remoteEndpoint, CancellationToken cancellationToken)
      {
          cancellationToken.ThrowIfCancellationRequested();

          using (var args = new AwaitableSocketAsyncEventArgs())
          {
              args.RemoteEndPoint = remoteEndpoint;
              await ExecuteWithCancellationAsync(args, socket.ConnectAsync, cancellationToken).ConfigureAwait(false);
          }
      }

      /// <summary>
      /// Receives up to <paramref name="length"/> bytes from <paramref name="socket"/> into
      /// <paramref name="buffer"/> starting at <paramref name="offset"/>.
      /// </summary>
      /// <param name="socket">The socket to read from.</param>
      /// <param name="buffer">The destination buffer.</param>
      /// <param name="offset">The index in <paramref name="buffer"/> at which to start writing.</param>
      /// <param name="length">The maximum number of bytes to receive.</param>
      /// <param name="cancellationToken">
      /// Completes the returned task as cancelled when signaled; it does not abort the pending receive itself.
      /// </param>
      /// <returns>The number of bytes transferred, as reported by the socket operation.</returns>
      /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is or becomes cancelled.</exception>
      /// <exception cref="SocketException">The receive failed.</exception>
      public static async Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
      {
          cancellationToken.ThrowIfCancellationRequested();

          using (var args = new AwaitableSocketAsyncEventArgs())
          {
              args.SetBuffer(buffer, offset, length);
              await ExecuteWithCancellationAsync(args, socket.ReceiveAsync, cancellationToken).ConfigureAwait(false);

              // Read before the using block disposes args.
              return args.BytesTransferred;
          }
      }

      /// <summary>
      /// Starts <paramref name="operation"/> on <paramref name="args"/> and awaits it.
      /// Completes early as cancelled if <paramref name="cancellationToken"/> is signaled first.
      /// </summary>
      /// <param name="args">The single-use awaitable args, already configured for the operation.</param>
      /// <param name="operation">The <c>Socket.*Async(SocketAsyncEventArgs)</c> method to start.</param>
      /// <param name="cancellationToken">Token whose cancellation completes <paramref name="args"/> as cancelled.</param>
      /// <returns>A task that completes when the operation completes or is cancelled.</returns>
      private static async Task ExecuteWithCancellationAsync(AwaitableSocketAsyncEventArgs args, Func<SocketAsyncEventArgs, bool> operation, CancellationToken cancellationToken)
      {
          // The registration is disposed before the caller disposes args.
          // CancellationTokenRegistration disposal waits for a callback running on another thread, so
          // SetCancelled cannot race with the disposal of args.
#if NET || NETSTANDARD2_1_OR_GREATER
          await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs)o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false))
#else
          using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs)o).SetCancelled(), args, useSynchronizationContext: false))
#endif // NET || NETSTANDARD2_1_OR_GREATER
          {
              // The custom awaiter never captures a SynchronizationContext; its continuation runs on the
              // completing thread (or the thread pool), so there is no ConfigureAwait here.
              await args.ExecuteAsync(operation);
          }
      }
  }
  108. }
  109. #endif