using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;

using Microsoft.Extensions.Logging;

using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Security.Cryptography;

namespace Renci.SshNet.Security
{
    /// <summary>
    /// Represents base class for different key exchange algorithm implementations.
    /// </summary>
    public abstract class KeyExchange : Algorithm, IKeyExchange
    {
        private readonly ILogger _logger;
        private CipherInfo _clientCipherInfo;
        private CipherInfo _serverCipherInfo;
        private HashInfo _clientHashInfo;
        private HashInfo _serverHashInfo;
        private Func<Compressor> _compressorFactory;
        private Func<Compressor> _decompressorFactory;
        private byte[] _exchangeHash;

        /// <summary>
        /// Gets the session.
        /// </summary>
        /// <value>
        /// The session.
        /// </value>
        protected Session Session { get; private set; }

        /// <summary>
        /// Gets or sets key exchange shared key.
        /// </summary>
        /// <value>
        /// The shared key.
        /// </value>
        public byte[] SharedKey { get; protected set; }

        /// <summary>
        /// Gets the exchange hash.
        /// </summary>
        /// <value>The exchange hash.</value>
        /// <remarks>
        /// Computed by <see cref="CalculateHash"/> on first access and cached for all later accesses.
        /// </remarks>
        public byte[] ExchangeHash
        {
            get
            {
                return _exchangeHash ??= CalculateHash();
            }
        }

        /// <summary>
        /// Occurs when host key received.
        /// </summary>
        public event EventHandler<HostKeyEventArgs> HostKeyReceived;

        /// <summary>
        /// Initializes a new instance of the <see cref="KeyExchange"/> class.
        /// </summary>
        protected KeyExchange()
        {
            _logger = SshNetLoggingConfiguration.LoggerFactory.CreateLogger(GetType());
        }

        /// <inheritdoc/>
        public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
        {
            Session = session;

            if (sendClientInitMessage)
            {
                SendMessage(session.ClientInitMessage);
            }

            // Every algorithm category below is negotiated independently per direction: the first
            // entry of our own list (in its enumeration order) that the server also offers is chosen.

            // Determine client encryption algorithm
            var clientEncryptionAlgorithmName = SelectAlgorithm(session.ConnectionInfo.Encryptions.Keys, message.EncryptionAlgorithmsClientToServer);

            if (_logger.IsEnabled(LogLevel.Trace))
            {
                _logger.LogTrace("[{SessionId}] Encryption client to server: we offer {WeOffer}", Session.SessionIdHex, session.ConnectionInfo.Encryptions.Keys.Join(","));
                _logger.LogTrace("[{SessionId}] Encryption client to server: they offer {TheyOffer}", Session.SessionIdHex, message.EncryptionAlgorithmsClientToServer.Join(","));
            }

            if (string.IsNullOrEmpty(clientEncryptionAlgorithmName))
            {
                throw new SshConnectionException("Client encryption algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentClientEncryption = clientEncryptionAlgorithmName;
            _clientCipherInfo = session.ConnectionInfo.Encryptions[clientEncryptionAlgorithmName];

            // Determine server encryption algorithm
            var serverDecryptionAlgorithmName = SelectAlgorithm(session.ConnectionInfo.Encryptions.Keys, message.EncryptionAlgorithmsServerToClient);

            if (_logger.IsEnabled(LogLevel.Trace))
            {
                _logger.LogTrace("[{SessionId}] Encryption server to client: we offer {WeOffer}", Session.SessionIdHex, session.ConnectionInfo.Encryptions.Keys.Join(","));
                _logger.LogTrace("[{SessionId}] Encryption server to client: they offer {TheyOffer}", Session.SessionIdHex, message.EncryptionAlgorithmsServerToClient.Join(","));
            }

            if (string.IsNullOrEmpty(serverDecryptionAlgorithmName))
            {
                throw new SshConnectionException("Server decryption algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentServerEncryption = serverDecryptionAlgorithmName;
            _serverCipherInfo = session.ConnectionInfo.Encryptions[serverDecryptionAlgorithmName];

            // AEAD ciphers authenticate the data themselves, so a separate MAC is only negotiated
            // for non-AEAD ciphers; otherwise the corresponding HashInfo stays null.
            if (!_clientCipherInfo.IsAead)
            {
                // Determine client hmac algorithm
                var clientHmacAlgorithmName = SelectAlgorithm(session.ConnectionInfo.HmacAlgorithms.Keys, message.MacAlgorithmsClientToServer);

                if (_logger.IsEnabled(LogLevel.Trace))
                {
                    _logger.LogTrace("[{SessionId}] MAC client to server: we offer {WeOffer}", Session.SessionIdHex, session.ConnectionInfo.HmacAlgorithms.Keys.Join(","));
                    _logger.LogTrace("[{SessionId}] MAC client to server: they offer {TheyOffer}", Session.SessionIdHex, message.MacAlgorithmsClientToServer.Join(","));
                }

                if (string.IsNullOrEmpty(clientHmacAlgorithmName))
                {
                    throw new SshConnectionException("Client HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
                }

                session.ConnectionInfo.CurrentClientHmacAlgorithm = clientHmacAlgorithmName;
                _clientHashInfo = session.ConnectionInfo.HmacAlgorithms[clientHmacAlgorithmName];
            }

            if (!_serverCipherInfo.IsAead)
            {
                // Determine server hmac algorithm
                var serverHmacAlgorithmName = SelectAlgorithm(session.ConnectionInfo.HmacAlgorithms.Keys, message.MacAlgorithmsServerToClient);

                if (_logger.IsEnabled(LogLevel.Trace))
                {
                    _logger.LogTrace("[{SessionId}] MAC server to client: we offer {WeOffer}", Session.SessionIdHex, session.ConnectionInfo.HmacAlgorithms.Keys.Join(","));
                    _logger.LogTrace("[{SessionId}] MAC server to client: they offer {TheyOffer}", Session.SessionIdHex, message.MacAlgorithmsServerToClient.Join(","));
                }

                if (string.IsNullOrEmpty(serverHmacAlgorithmName))
                {
                    throw new SshConnectionException("Server HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
                }

                session.ConnectionInfo.CurrentServerHmacAlgorithm = serverHmacAlgorithmName;
                _serverHashInfo = session.ConnectionInfo.HmacAlgorithms[serverHmacAlgorithmName];
            }

            // Determine compression algorithm
            var compressionAlgorithmName = SelectAlgorithm(session.ConnectionInfo.CompressionAlgorithms.Keys, message.CompressionAlgorithmsClientToServer);

            if (_logger.IsEnabled(LogLevel.Trace))
            {
                _logger.LogTrace("[{SessionId}] Compression client to server: we offer {WeOffer}", Session.SessionIdHex, session.ConnectionInfo.CompressionAlgorithms.Keys.Join(","));
                _logger.LogTrace("[{SessionId}] Compression client to server: they offer {TheyOffer}", Session.SessionIdHex, message.CompressionAlgorithmsClientToServer.Join(","));
            }

            if (string.IsNullOrEmpty(compressionAlgorithmName))
            {
                throw new SshConnectionException("Compression algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentClientCompressionAlgorithm = compressionAlgorithmName;
            _compressorFactory = session.ConnectionInfo.CompressionAlgorithms[compressionAlgorithmName];

            // Determine decompression algorithm
            var decompressionAlgorithmName = SelectAlgorithm(session.ConnectionInfo.CompressionAlgorithms.Keys, message.CompressionAlgorithmsServerToClient);

            if (_logger.IsEnabled(LogLevel.Trace))
            {
                _logger.LogTrace("[{SessionId}] Compression server to client: we offer {WeOffer}", Session.SessionIdHex, session.ConnectionInfo.CompressionAlgorithms.Keys.Join(","));
                _logger.LogTrace("[{SessionId}] Compression server to client: they offer {TheyOffer}", Session.SessionIdHex, message.CompressionAlgorithmsServerToClient.Join(","));
            }

            if (string.IsNullOrEmpty(decompressionAlgorithmName))
            {
                throw new SshConnectionException("Decompression algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentServerCompressionAlgorithm = decompressionAlgorithmName;
            _decompressorFactory = session.ConnectionInfo.CompressionAlgorithms[decompressionAlgorithmName];
        }

        /// <summary>
        /// Finishes key exchange algorithm.
        /// </summary>
        /// <exception cref="SshConnectionException">The exchange hash could not be validated.</exception>
        public virtual void Finish()
        {
            if (!ValidateExchangeHash())
            {
                throw new SshConnectionException("Key exchange negotiation failed.", DisconnectReason.KeyExchangeFailed);
            }

            SendMessage(new NewKeysMessage());
        }

        /// <summary>
        /// Creates the server side cipher to use.
        /// </summary>
        /// <param name="isAead"><see langword="true"/> to indicate the cipher is AEAD, <see langword="false"/> to indicate the cipher is not AEAD.</param>
        /// <returns>Server cipher.</returns>
        public Cipher CreateServerCipher(out bool isAead)
        {
            isAead = _serverCipherInfo.IsAead;

            // The first exchange hash of a connection doubles as the session id.
            var sessionId = Session.SessionId ?? ExchangeHash;

            // Calculate server to client initial IV ('B')
            var serverVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'B', sessionId));

            // Calculate server to client encryption key ('D'), extended to the cipher's key size
            var serverKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'D', sessionId));
            serverKey = GenerateSessionKey(SharedKey, ExchangeHash, serverKey, _serverCipherInfo.KeySize / 8);

            _logger.LogDebug("[{SessionId}] Creating {ServerEncryption} server cipher.", Session.SessionIdHex, Session.ConnectionInfo.CurrentServerEncryption);

            return _serverCipherInfo.Cipher(serverKey, serverVector);
        }

        /// <summary>
        /// Creates the client side cipher to use.
        /// </summary>
        /// <param name="isAead"><see langword="true"/> to indicate the cipher is AEAD, <see langword="false"/> to indicate the cipher is not AEAD.</param>
        /// <returns>Client cipher.</returns>
        public Cipher CreateClientCipher(out bool isAead)
        {
            isAead = _clientCipherInfo.IsAead;

            // The first exchange hash of a connection doubles as the session id.
            var sessionId = Session.SessionId ?? ExchangeHash;

            // Calculate client to server initial IV ('A')
            var clientVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'A', sessionId));

            // Calculate client to server encryption key ('C'), extended to the cipher's key size
            var clientKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'C', sessionId));
            clientKey = GenerateSessionKey(SharedKey, ExchangeHash, clientKey, _clientCipherInfo.KeySize / 8);

            _logger.LogDebug("[{SessionId}] Creating {ClientEncryption} client cipher.", Session.SessionIdHex, Session.ConnectionInfo.CurrentClientEncryption);

            return _clientCipherInfo.Cipher(clientKey, clientVector);
        }

        /// <summary>
        /// Creates the server side hash algorithm to use.
        /// </summary>
        /// <param name="isEncryptThenMAC"><see langword="true"/> to enable encrypt-then-MAC, <see langword="false"/> to use encrypt-and-MAC.</param>
        /// <returns>
        /// The server-side hash algorithm, or <see langword="null"/> when no MAC was negotiated
        /// (the server cipher is AEAD).
        /// </returns>
        public HashAlgorithm CreateServerHash(out bool isEncryptThenMAC)
        {
            if (_serverHashInfo is null)
            {
                isEncryptThenMAC = false;
                return null;
            }

            isEncryptThenMAC = _serverHashInfo.IsEncryptThenMAC;

            // The first exchange hash of a connection doubles as the session id.
            var sessionId = Session.SessionId ?? ExchangeHash;

            // Server to client integrity key ('F'), extended to the MAC's key size
            var serverKey = GenerateSessionKey(SharedKey,
                                               ExchangeHash,
                                               Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'F', sessionId)),
                                               _serverHashInfo.KeySize / 8);

            _logger.LogDebug("[{SessionId}] Creating {ServerHmacAlgorithm} server hmac algorithm.", Session.SessionIdHex, Session.ConnectionInfo.CurrentServerHmacAlgorithm);

            return _serverHashInfo.HashAlgorithm(serverKey);
        }

        /// <summary>
        /// Creates the client side hash algorithm to use.
        /// </summary>
        /// <param name="isEncryptThenMAC"><see langword="true"/> to enable encrypt-then-MAC, <see langword="false"/> to use encrypt-and-MAC.</param>
        /// <returns>
        /// The client-side hash algorithm, or <see langword="null"/> when no MAC was negotiated
        /// (the client cipher is AEAD).
        /// </returns>
        public HashAlgorithm CreateClientHash(out bool isEncryptThenMAC)
        {
            if (_clientHashInfo is null)
            {
                isEncryptThenMAC = false;
                return null;
            }

            isEncryptThenMAC = _clientHashInfo.IsEncryptThenMAC;

            // The first exchange hash of a connection doubles as the session id.
            var sessionId = Session.SessionId ?? ExchangeHash;

            // Client to server integrity key ('E'), extended to the MAC's key size
            var clientKey = GenerateSessionKey(SharedKey,
                                               ExchangeHash,
                                               Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'E', sessionId)),
                                               _clientHashInfo.KeySize / 8);

            _logger.LogDebug("[{SessionId}] Creating {ClientHmacAlgorithm} client hmac algorithm.", Session.SessionIdHex, Session.ConnectionInfo.CurrentClientHmacAlgorithm);

            return _clientHashInfo.HashAlgorithm(clientKey);
        }

        /// <summary>
        /// Creates the compression algorithm to use to deflate data.
        /// </summary>
        /// <returns>
        /// The compression method, or <see langword="null"/> when <see cref="Start"/> has not negotiated one.
        /// </returns>
        public Compressor CreateCompressor()
        {
            if (_compressorFactory is null)
            {
                return null;
            }

            _logger.LogDebug("[{SessionId}] Creating {CompressionAlgorithm} client compressor.", Session.SessionIdHex, Session.ConnectionInfo.CurrentClientCompressionAlgorithm);

            var compressor = _compressorFactory();
            compressor.Init(Session);

            return compressor;
        }

        /// <summary>
        /// Creates the compression algorithm to use to inflate data.
        /// </summary>
        /// <returns>
        /// The decompression method, or <see langword="null"/> when <see cref="Start"/> has not negotiated one.
        /// </returns>
        public Compressor CreateDecompressor()
        {
            if (_decompressorFactory is null)
            {
                return null;
            }

            _logger.LogDebug("[{SessionId}] Creating {ServerCompressionAlgorithm} server decompressor.", Session.SessionIdHex, Session.ConnectionInfo.CurrentServerCompressionAlgorithm);

            var decompressor = _decompressorFactory();
            decompressor.Init(Session);

            return decompressor;
        }

        /// <summary>
        /// Determines whether the specified host key can be trusted.
        /// </summary>
        /// <param name="host">The host algorithm.</param>
        /// <returns>
        /// <see langword="true"/> if the specified host can be trusted; otherwise, <see langword="false"/>.
        /// When nobody subscribes to <see cref="HostKeyReceived"/>, every host key is trusted.
        /// </returns>
        protected bool CanTrustHostKey(KeyHostAlgorithm host)
        {
            // Copy to a local so a concurrent unsubscribe cannot null the delegate between check and call.
            var handlers = HostKeyReceived;
            if (handlers is null)
            {
                return true;
            }

            var args = new HostKeyEventArgs(host);
            handlers(this, args);
            return args.CanTrust;
        }

        /// <summary>
        /// Validates the exchange hash.
        /// </summary>
        /// <returns><see langword="true"/> if exchange hash is valid; otherwise <see langword="false"/>.</returns>
        protected abstract bool ValidateExchangeHash();

        /// <summary>
        /// Verifies the server's signature over the exchange hash and asks <see cref="CanTrustHostKey"/>
        /// whether the host key is trusted.
        /// </summary>
        /// <param name="encodedKey">The server host key blob (starts with the key format identifier).</param>
        /// <param name="encodedSignature">The signature blob (starts with the signature format identifier).</param>
        /// <returns><see langword="true"/> if the signature verifies and the host key is trusted.</returns>
        /// <exception cref="SshConnectionException">The resulting host key algorithm is not supported.</exception>
        private protected bool ValidateExchangeHash(byte[] encodedKey, byte[] encodedSignature)
        {
            var exchangeHash = CalculateHash();

            // We need to inspect both the key and signature format identifiers to find the correct
            // HostAlgorithm instance. Example cases:

            // Key identifier                   Signature identifier | Algorithm name
            // ssh-rsa                          ssh-rsa              | ssh-rsa
            // ssh-rsa                          rsa-sha2-256         | rsa-sha2-256
            // ssh-rsa-cert-v01@openssh.com     ssh-rsa              | ssh-rsa-cert-v01@openssh.com
            // ssh-rsa-cert-v01@openssh.com     rsa-sha2-256         | rsa-sha2-256-cert-v01@openssh.com

            var signatureData = new KeyHostAlgorithm.SignatureKeyData();
            signatureData.Load(encodedSignature);

            string keyName;
            using (var keyReader = new SshDataStream(encodedKey))
            {
                keyName = keyReader.ReadString();
            }

            string algorithmName;

            if (signatureData.AlgorithmName.StartsWith("rsa-sha2", StringComparison.Ordinal))
            {
                algorithmName = keyName.Replace("ssh-rsa", signatureData.AlgorithmName);
            }
            else
            {
                algorithmName = keyName;
            }

            // The server controls both identifiers, so an unknown combination must be reported as a
            // key exchange failure rather than leaking a KeyNotFoundException.
            if (!Session.ConnectionInfo.HostKeyAlgorithms.TryGetValue(algorithmName, out var keyAlgorithmFactory))
            {
                throw new SshConnectionException($"Host key algorithm '{algorithmName}' is not supported.", DisconnectReason.KeyExchangeFailed);
            }

            var keyAlgorithm = keyAlgorithmFactory(encodedKey);

            Session.ConnectionInfo.CurrentHostKeyAlgorithm = algorithmName;

            return keyAlgorithm.VerifySignatureBlob(exchangeHash, signatureData.Signature) && CanTrustHostKey(keyAlgorithm);
        }

        /// <summary>
        /// Calculates key exchange hash value.
        /// </summary>
        /// <returns>Key exchange hash.</returns>
        protected abstract byte[] CalculateHash();

        /// <summary>
        /// Hashes the specified data bytes.
        /// </summary>
        /// <param name="hashData">The hash data.</param>
        /// <returns>
        /// The hash of the data.
        /// </returns>
        protected abstract byte[] Hash(byte[] hashData);

        /// <summary>
        /// Sends SSH message to the server.
        /// </summary>
        /// <param name="message">The message.</param>
        protected void SendMessage(Message message)
        {
            Session.SendMessage(message);
        }

        /// <summary>
        /// Picks the first algorithm of <paramref name="clientAlgorithms"/> (in its enumeration order)
        /// that also appears in <paramref name="serverAlgorithms"/>, using ordinal comparison.
        /// </summary>
        /// <param name="clientAlgorithms">The algorithm names we support.</param>
        /// <param name="serverAlgorithms">The algorithm names offered by the server.</param>
        /// <returns>The negotiated algorithm name, or <see langword="null"/> when the lists share no entry.</returns>
        private static string SelectAlgorithm(IEnumerable<string> clientAlgorithms, IEnumerable<string> serverAlgorithms)
        {
            return clientAlgorithms.FirstOrDefault(clientAlgorithm => serverAlgorithms.Contains(clientAlgorithm, StringComparer.Ordinal));
        }

        /// <summary>
        /// Extends <paramref name="key"/> until it holds at least <paramref name="size"/> bytes.
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="key">The initial key material (K1).</param>
        /// <param name="size">The minimum number of bytes required.</param>
        /// <returns>
        /// The session key. It may be longer than <paramref name="size"/>; it is not truncated here.
        /// </returns>
        private byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, byte[] key, int size)
        {
            var result = new List<byte>(key);

            // RFC 4253 section 7.2:
            //   K2 = HASH(K || H || K1)
            //   K3 = HASH(K || H || K1 || K2)
            //   ...
            // Each extension block hashes ALL key material derived so far, not just K1.
            while (size > result.Count)
            {
                var sessionKeyAdjustment = new SessionKeyAdjustment
                {
                    SharedKey = sharedKey,
                    ExchangeHash = exchangeHash,
                    Key = result.ToArray(),
                };

                result.AddRange(Hash(sessionKeyAdjustment.GetBytes()));
            }

            return result.ToArray();
        }

        /// <summary>
        /// Builds the pre-image <c>K || H || X || session_id</c> for the initial key derivation step.
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="p">The single-character key purpose identifier ('A'..'F').</param>
        /// <param name="sessionId">The session id.</param>
        /// <returns>
        /// The serialized bytes to be hashed.
        /// </returns>
        private static byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, char p, byte[] sessionId)
        {
            var sessionKeyGeneration = new SessionKeyGeneration
            {
                SharedKey = sharedKey,
                ExchangeHash = exchangeHash,
                Char = p,
                SessionId = sessionId
            };

            return sessionKeyGeneration.GetBytes();
        }

        /// <summary>
        /// Serializes <c>string(SharedKey) || ExchangeHash || (byte)Char || SessionId</c>.
        /// </summary>
        private sealed class SessionKeyGeneration : SshData
        {
            public byte[] SharedKey { get; set; }

            public byte[] ExchangeHash { get; set; }

            public char Char { get; set; }

            public byte[] SessionId { get; set; }

            /// <summary>
            /// Gets the size of the message in bytes.
            /// </summary>
            /// <value>
            /// The size of the messages in bytes.
            /// </value>
            protected override int BufferCapacity
            {
                get
                {
                    var capacity = base.BufferCapacity;
                    capacity += 4; // SharedKey length
                    capacity += SharedKey.Length; // SharedKey
                    capacity += ExchangeHash.Length; // ExchangeHash
                    capacity += 1; // Char
                    capacity += SessionId.Length; // SessionId
                    return capacity;
                }
            }

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
                WriteBinaryString(SharedKey);
                Write(ExchangeHash);
                Write((byte)Char);
                Write(SessionId);
            }
        }

        /// <summary>
        /// Serializes <c>string(SharedKey) || ExchangeHash || Key</c> for key extension steps.
        /// </summary>
        private sealed class SessionKeyAdjustment : SshData
        {
            public byte[] SharedKey { get; set; }

            public byte[] ExchangeHash { get; set; }

            public byte[] Key { get; set; }

            /// <summary>
            /// Gets the size of the message in bytes.
            /// </summary>
            /// <value>
            /// The size of the messages in bytes.
            /// </value>
            protected override int BufferCapacity
            {
                get
                {
                    var capacity = base.BufferCapacity;
                    capacity += 4; // SharedKey length
                    capacity += SharedKey.Length; // SharedKey
                    capacity += ExchangeHash.Length; // ExchangeHash
                    capacity += Key.Length; // Key
                    return capacity;
                }
            }

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
                WriteBinaryString(SharedKey);
                Write(ExchangeHash);
                Write(Key);
            }
        }

        #region IDisposable Members

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// </summary>
        /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
        protected virtual void Dispose(bool disposing)
        {
        }

        #endregion
    }
}