using System; using System.Collections.Generic; using System.Globalization; using System.Net; using System.Net.Sockets; using System.Security.Cryptography; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Renci.SshNet.Common; using Renci.SshNet.Compression; using Renci.SshNet.Connection; using Renci.SshNet.Messages; using Renci.SshNet.Messages.Transport; using Renci.SshNet.Security; using Renci.SshNet.Security.Cryptography; using Renci.SshNet.Tests.Common; namespace Renci.SshNet.Tests.Classes { [TestClass] public abstract class SessionTest_ConnectingBase { internal Mock ServiceFactoryMock { get; private set; } internal Mock SocketFactoryMock { get; private set; } internal Mock ConnectorMock { get; private set; } private Mock _protocolVersionExchangeMock; private Mock _keyExchangeMock; private Mock _clientAuthenticationMock; private IPEndPoint _serverEndPoint; private string[] _keyExchangeAlgorithms; private bool _authenticationStarted; private SocketFactory _socketFactory; protected Random Random { get; private set; } protected byte[] SessionId { get; private set; } protected ConnectionInfo ConnectionInfo { get; private set; } protected IList DisconnectedRegister { get; private set; } protected IList> DisconnectReceivedRegister { get; private set; } protected IList ErrorOccurredRegister { get; private set; } protected AsyncSocketListener ServerListener { get; private set; } protected IList ServerBytesReceivedRegister { get; private set; } protected Session Session { get; private set; } protected Socket ClientSocket { get; private set; } protected Socket ServerSocket { get; private set; } protected SshIdentification ServerIdentification { get; set; } protected virtual bool ServerSupportsStrictKex { get; } protected virtual bool ServerResetsSequenceAfterSendingNewKeys { get { return ServerSupportsStrictKex; } } protected uint ServerOutboundPacketSequence { get; set; } [TestInitialize] public void Setup() { CreateMocks(); SetupData(); SetupMocks(); } protected virtual void ActionBeforeKexInit() { } protected virtual void ActionAfterKexInit() { } [TestCleanup] public void TearDown() { if (ServerListener != null) { ServerListener.Dispose(); ServerListener = null; } if (ServerSocket != null) { ServerSocket.Dispose(); ServerSocket = null; } if (Session != null) { Session.Dispose(); Session = null; } if (ClientSocket != null && ClientSocket.Connected) { ClientSocket.Shutdown(SocketShutdown.Both); ClientSocket.Dispose(); } } protected virtual void SetupData() { Random = new Random(); _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122); ConnectionInfo = new ConnectionInfo( _serverEndPoint.Address.ToString(), _serverEndPoint.Port, "user", new PasswordAuthenticationMethod("user", "password")) { Timeout = TimeSpan.FromSeconds(20) }; _keyExchangeAlgorithms = ServerSupportsStrictKex ? [Random.Next().ToString(CultureInfo.InvariantCulture), "kex-strict-s-v00@openssh.com"] : [Random.Next().ToString(CultureInfo.InvariantCulture)]; SessionId = new byte[10]; Random.NextBytes(SessionId); DisconnectedRegister = new List(); DisconnectReceivedRegister = new List>(); ErrorOccurredRegister = new List(); ServerBytesReceivedRegister = new List(); ServerIdentification = new SshIdentification("2.0", "OurServerStub"); _authenticationStarted = false; _socketFactory = new SocketFactory(); Session = new Session(ConnectionInfo, ServiceFactoryMock.Object, SocketFactoryMock.Object); Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args); Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args); Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args); ServerListener = new AsyncSocketListener(_serverEndPoint) { ShutdownRemoteCommunicationSocket = false }; ServerListener.Connected += socket => { ServerSocket = socket; ActionBeforeKexInit(); var keyExchangeInitMessage = new KeyExchangeInitMessage { CompressionAlgorithmsClientToServer = new string[0], CompressionAlgorithmsServerToClient = new string[0], EncryptionAlgorithmsClientToServer = new string[0], EncryptionAlgorithmsServerToClient = new string[0], KeyExchangeAlgorithms = _keyExchangeAlgorithms, LanguagesClientToServer = new string[0], LanguagesServerToClient = new string[0], MacAlgorithmsClientToServer = new string[0], MacAlgorithmsServerToClient = new string[0], ServerHostKeyAlgorithms = new string[0] }; var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null); _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None); ServerOutboundPacketSequence++; }; ServerListener.BytesReceived += (received, socket) => { ServerBytesReceivedRegister.Add(received); if (received.Length > 5 && received[5] == 20) { ActionAfterKexInit(); var newKeysMessage = new NewKeysMessage(); var newKeys = newKeysMessage.GetPacket(8, null); _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None); if (ServerResetsSequenceAfterSendingNewKeys) { ServerOutboundPacketSequence = 0; } else { ServerOutboundPacketSequence++; } if (!_authenticationStarted) { var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication) .Build(ServerOutboundPacketSequence); var hash = Abstractions.CryptoAbstraction.CreateSHA256().ComputeHash(serviceAcceptMessage); var packet = new byte[serviceAcceptMessage.Length - 4 + hash.Length]; Array.Copy(serviceAcceptMessage, 4, packet, 0, serviceAcceptMessage.Length - 4); Array.Copy(hash, 0, packet, serviceAcceptMessage.Length - 4, hash.Length); _ = ServerSocket.Send(packet, 0, packet.Length, SocketFlags.None); ServerOutboundPacketSequence++; _authenticationStarted = true; } } }; ServerListener.Start(); ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo); } private void CreateMocks() { ServiceFactoryMock = new Mock(MockBehavior.Strict); SocketFactoryMock = new Mock(MockBehavior.Strict); ConnectorMock = new Mock(MockBehavior.Strict); _protocolVersionExchangeMock = new Mock(MockBehavior.Strict); _keyExchangeMock = new Mock(MockBehavior.Strict); _clientAuthenticationMock = new Mock(MockBehavior.Strict); } private void SetupMocks() { _ = ServiceFactoryMock.Setup(p => p.CreateConnector(ConnectionInfo, SocketFactoryMock.Object)) .Returns(ConnectorMock.Object); _ = ConnectorMock.Setup(p => p.Connect(ConnectionInfo)) .Returns(ClientSocket); _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange()) .Returns(_protocolVersionExchangeMock.Object); _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout)) .Returns(() => ServerIdentification); _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, _keyExchangeAlgorithms)).Returns(_keyExchangeMock.Object); _ = _keyExchangeMock.Setup(p => p.Name) .Returns(_keyExchangeAlgorithms[0]); _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false)); _ = _keyExchangeMock.Setup(p => p.ExchangeHash) .Returns(SessionId); _ = _keyExchangeMock.Setup(p => p.CreateServerCipher(out It.Ref.IsAny)) .Returns((ref bool serverAead) => { serverAead = false; return (Cipher)null; }); _ = _keyExchangeMock.Setup(p => p.CreateClientCipher(out It.Ref.IsAny)) .Returns((ref bool clientAead) => { clientAead = false; return (Cipher)null; }); _ = _keyExchangeMock.Setup(p => p.CreateServerHash(out It.Ref.IsAny)) .Returns((ref bool serverEtm) => { serverEtm = false; return SHA256.Create(); }); _ = _keyExchangeMock.Setup(p => p.CreateClientHash(out It.Ref.IsAny)) .Returns((ref bool clientEtm) => { clientEtm = false; return (HashAlgorithm)null; }); _ = _keyExchangeMock.Setup(p => p.CreateCompressor()) .Returns((Compressor)null); _ = _keyExchangeMock.Setup(p => p.CreateDecompressor()) .Returns((Compressor)null); _ = _keyExchangeMock.Setup(p => p.Dispose()); _ = ServiceFactoryMock.Setup(p => p.CreateClientAuthentication()) .Returns(_clientAuthenticationMock.Object); _ = _clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session)); } private class ServiceAcceptMessageBuilder { private readonly ServiceName _serviceName; private ServiceAcceptMessageBuilder(ServiceName serviceName) { _serviceName = serviceName; } public static ServiceAcceptMessageBuilder Create(ServiceName serviceName) { return new ServiceAcceptMessageBuilder(serviceName); } public byte[] Build(uint sequence) { var serviceName = _serviceName.ToArray(); var target = new ServiceAcceptMessage(); var sshDataStream = new SshDataStream(4 + 4 + 1 + 1 + 4 + serviceName.Length); sshDataStream.Write(sequence); sshDataStream.Write((uint)(sshDataStream.Capacity - 8)); //sequence and packet length sshDataStream.WriteByte(0); // padding length sshDataStream.WriteByte(target.MessageNumber); sshDataStream.WriteBinary(serviceName); return sshDataStream.ToArray(); } } } }