using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Security.Cryptography;

namespace Renci.SshNet.Security
{
    /// <summary>
    /// Represents base class for different key exchange algorithm implementations.
    /// </summary>
    public abstract class KeyExchange : Algorithm, IKeyExchange
    {
        private CipherInfo _clientCipherInfo;
        private CipherInfo _serverCipherInfo;
        private HashInfo _clientHashInfo;
        private HashInfo _serverHashInfo;
        private Type _compressionType;
        private Type _decompressionType;

        /// <summary>
        /// Gets or sets the session.
        /// </summary>
        /// <value>
        /// The session.
        /// </value>
        protected Session Session { get; private set; }

        /// <summary>
        /// Gets or sets key exchange shared key.
        /// </summary>
        /// <value>
        /// The shared key.
        /// </value>
        public BigInteger SharedKey { get; protected set; }

        private byte[] _exchangeHash;

        /// <summary>
        /// Gets the exchange hash.
        /// </summary>
        /// <value>The exchange hash.</value>
        /// <remarks>
        /// Computed by <see cref="CalculateHash"/> on first access and cached for later accesses.
        /// </remarks>
        public byte[] ExchangeHash
        {
            get
            {
                if (_exchangeHash == null)
                {
                    _exchangeHash = CalculateHash();
                }

                return _exchangeHash;
            }
        }

        /// <summary>
        /// Occurs when host key received.
        /// </summary>
        public event EventHandler<HostKeyEventArgs> HostKeyReceived;

        /// <summary>
        /// Starts key exchange algorithm.
        /// </summary>
        /// <param name="session">The session.</param>
        /// <param name="message">Key exchange init message received from the server.</param>
        /// <exception cref="SshConnectionException">
        /// No common encryption, HMAC, compression or decompression algorithm could be found.
        /// </exception>
        public virtual void Start(Session session, KeyExchangeInitMessage message)
        {
            Session = session;

            SendMessage(session.ClientInitMessage);

            // Determine client-to-server encryption algorithm
            var clientEncryptionAlgorithmName = FindCommonAlgorithms(
                session.ConnectionInfo.Encryptions.Keys,
                message.EncryptionAlgorithmsClientToServer).FirstOrDefault();
            if (string.IsNullOrEmpty(clientEncryptionAlgorithmName))
            {
                throw new SshConnectionException("Client encryption algorithm not found", DisconnectReason.KeyExchangeFailed);
            }
            session.ConnectionInfo.CurrentClientEncryption = clientEncryptionAlgorithmName;

            // Determine server-to-client encryption algorithm
            var serverDecryptionAlgorithmName = FindCommonAlgorithms(
                session.ConnectionInfo.Encryptions.Keys,
                message.EncryptionAlgorithmsServerToClient).FirstOrDefault();
            if (string.IsNullOrEmpty(serverDecryptionAlgorithmName))
            {
                throw new SshConnectionException("Server decryption algorithm not found", DisconnectReason.KeyExchangeFailed);
            }
            session.ConnectionInfo.CurrentServerEncryption = serverDecryptionAlgorithmName;

            // Determine client-to-server HMAC algorithm
            var clientHmacAlgorithmName = FindCommonAlgorithms(
                session.ConnectionInfo.HmacAlgorithms.Keys,
                message.MacAlgorithmsClientToServer).FirstOrDefault();
            if (string.IsNullOrEmpty(clientHmacAlgorithmName))
            {
                throw new SshConnectionException("Client HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
            }
            session.ConnectionInfo.CurrentClientHmacAlgorithm = clientHmacAlgorithmName;

            // Determine server-to-client HMAC algorithm
            var serverHmacAlgorithmName = FindCommonAlgorithms(
                session.ConnectionInfo.HmacAlgorithms.Keys,
                message.MacAlgorithmsServerToClient).FirstOrDefault();
            if (string.IsNullOrEmpty(serverHmacAlgorithmName))
            {
                throw new SshConnectionException("Server HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
            }
            session.ConnectionInfo.CurrentServerHmacAlgorithm = serverHmacAlgorithmName;

            // Determine client-to-server compression algorithm.
            // NOTE(review): this picks the LAST common algorithm, unlike the FIRST used for
            // ciphers/MACs (RFC 4253 7.1 says "first on the client's list"). Preserved as-is;
            // confirm whether this is intentional before changing.
            var compressionAlgorithmName = FindCommonAlgorithms(
                session.ConnectionInfo.CompressionAlgorithms.Keys,
                message.CompressionAlgorithmsClientToServer).LastOrDefault();
            if (string.IsNullOrEmpty(compressionAlgorithmName))
            {
                throw new SshConnectionException("Compression algorithm not found", DisconnectReason.KeyExchangeFailed);
            }
            session.ConnectionInfo.CurrentClientCompressionAlgorithm = compressionAlgorithmName;

            // Determine server-to-client (decompression) algorithm; same LAST-match rule as above.
            var decompressionAlgorithmName = FindCommonAlgorithms(
                session.ConnectionInfo.CompressionAlgorithms.Keys,
                message.CompressionAlgorithmsServerToClient).LastOrDefault();
            if (string.IsNullOrEmpty(decompressionAlgorithmName))
            {
                throw new SshConnectionException("Decompression algorithm not found", DisconnectReason.KeyExchangeFailed);
            }
            session.ConnectionInfo.CurrentServerCompressionAlgorithm = decompressionAlgorithmName;

            _clientCipherInfo = session.ConnectionInfo.Encryptions[clientEncryptionAlgorithmName];
            _serverCipherInfo = session.ConnectionInfo.Encryptions[serverDecryptionAlgorithmName];
            _clientHashInfo = session.ConnectionInfo.HmacAlgorithms[clientHmacAlgorithmName];
            _serverHashInfo = session.ConnectionInfo.HmacAlgorithms[serverHmacAlgorithmName];
            _compressionType = session.ConnectionInfo.CompressionAlgorithms[compressionAlgorithmName];
            _decompressionType = session.ConnectionInfo.CompressionAlgorithms[decompressionAlgorithmName];
        }

        /// <summary>
        /// Finishes key exchange algorithm.
        /// </summary>
        /// <exception cref="SshConnectionException">The exchange hash could not be validated.</exception>
        public virtual void Finish()
        {
            // Validate hash
            if (ValidateExchangeHash())
            {
                SendMessage(new NewKeysMessage());
            }
            else
            {
                throw new SshConnectionException("Key exchange negotiation failed.", DisconnectReason.KeyExchangeFailed);
            }
        }

        /// <summary>
        /// Creates the server side cipher to use.
        /// </summary>
        /// <returns>Server cipher.</returns>
        public Cipher CreateServerCipher()
        {
            var sessionId = ResolveSessionId();

            // Calculate server to client initial IV
            var serverVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'B', sessionId));

            // Calculate server to client encryption key, extended to the cipher's key size
            var serverKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'D', sessionId));
            serverKey = GenerateSessionKey(SharedKey, ExchangeHash, serverKey, _serverCipherInfo.KeySize / 8);

            // Create server cipher
            return _serverCipherInfo.Cipher(serverKey, serverVector);
        }

        /// <summary>
        /// Creates the client side cipher to use.
        /// </summary>
        /// <returns>Client cipher.</returns>
        public Cipher CreateClientCipher()
        {
            var sessionId = ResolveSessionId();

            // Calculate client to server initial IV
            var clientVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'A', sessionId));

            // Calculate client to server encryption key, extended to the cipher's key size
            var clientKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'C', sessionId));
            clientKey = GenerateSessionKey(SharedKey, ExchangeHash, clientKey, _clientCipherInfo.KeySize / 8);

            // Create client cipher
            return _clientCipherInfo.Cipher(clientKey, clientVector);
        }

        /// <summary>
        /// Creates the server side hash algorithm to use.
        /// </summary>
        /// <returns>Hash algorithm.</returns>
        public HashAlgorithm CreateServerHash()
        {
            var sessionId = ResolveSessionId();

            // Server to client integrity key, extended to the HMAC's key size
            var serverKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'F', sessionId));
            serverKey = GenerateSessionKey(SharedKey, ExchangeHash, serverKey, _serverHashInfo.KeySize / 8);

            return _serverHashInfo.HashAlgorithm(serverKey);
        }

        /// <summary>
        /// Creates the client side hash algorithm to use.
        /// </summary>
        /// <returns>Hash algorithm.</returns>
        public HashAlgorithm CreateClientHash()
        {
            var sessionId = ResolveSessionId();

            // Client to server integrity key, extended to the HMAC's key size
            var clientKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'E', sessionId));
            clientKey = GenerateSessionKey(SharedKey, ExchangeHash, clientKey, _clientHashInfo.KeySize / 8);

            return _clientHashInfo.HashAlgorithm(clientKey);
        }

        /// <summary>
        /// Creates the compression algorithm to use to deflate data.
        /// </summary>
        /// <returns>
        /// Compression method, or <c>null</c> when the negotiated client-to-server
        /// compression maps to no type.
        /// </returns>
        public Compressor CreateCompressor()
        {
            if (_compressionType == null)
            {
                return null;
            }

            var compressor = _compressionType.CreateInstance<Compressor>();

            compressor.Init(Session);

            return compressor;
        }

        /// <summary>
        /// Creates the compression algorithm to use to inflate data.
        /// </summary>
        /// <returns>
        /// Compression method, or <c>null</c> when the negotiated server-to-client
        /// compression maps to no type.
        /// </returns>
        public Compressor CreateDecompressor()
        {
            // Check the decompression type itself: the two directions are negotiated
            // independently, so _compressionType says nothing about this one.
            if (_decompressionType == null)
            {
                return null;
            }

            var decompressor = _decompressionType.CreateInstance<Compressor>();

            decompressor.Init(Session);

            return decompressor;
        }

        /// <summary>
        /// Determines whether the specified host key can be trusted.
        /// </summary>
        /// <param name="host">The host algorithm.</param>
        /// <returns>
        /// <c>true</c> if the specified host can be trusted; otherwise, <c>false</c>.
        /// When no <see cref="HostKeyReceived"/> handler is attached, the key is trusted.
        /// </returns>
        protected bool CanTrustHostKey(KeyHostAlgorithm host)
        {
            // Copy to a local so the null check and the invocation see the same delegate
            var handlers = HostKeyReceived;
            if (handlers != null)
            {
                var args = new HostKeyEventArgs(host);
                handlers(this, args);
                return args.CanTrust;
            }

            return true;
        }

        /// <summary>
        /// Validates the exchange hash.
        /// </summary>
        /// <returns><c>true</c> if exchange hash is valid; otherwise <c>false</c>.</returns>
        protected abstract bool ValidateExchangeHash();

        /// <summary>
        /// Calculates key exchange hash value.
        /// </summary>
        /// <returns>Key exchange hash.</returns>
        protected abstract byte[] CalculateHash();

        /// <summary>
        /// Hashes the specified data bytes.
        /// </summary>
        /// <param name="hashData">The hash data.</param>
        /// <returns>
        /// Hashed bytes (SHA-1 by default; derived classes may override).
        /// </returns>
        protected virtual byte[] Hash(byte[] hashData)
        {
            using (var sha1 = HashAlgorithmFactory.CreateSHA1())
            {
                return sha1.ComputeHash(hashData, 0, hashData.Length);
            }
        }

        /// <summary>
        /// Sends SSH message to the server.
        /// </summary>
        /// <param name="message">The message.</param>
        protected void SendMessage(Message message)
        {
            Session.SendMessage(message);
        }

        /// <summary>
        /// Returns the algorithm names from <paramref name="clientAlgorithms"/> that also appear in
        /// <paramref name="serverAlgorithms"/>, in the enumeration order of <paramref name="clientAlgorithms"/>.
        /// </summary>
        /// <param name="clientAlgorithms">Algorithm names supported locally.</param>
        /// <param name="serverAlgorithms">Algorithm names offered by the server.</param>
        /// <returns>The (lazily evaluated) sequence of common algorithm names.</returns>
        private static IEnumerable<string> FindCommonAlgorithms(IEnumerable<string> clientAlgorithms, IEnumerable<string> serverAlgorithms)
        {
            return from b in clientAlgorithms
                   from a in serverAlgorithms
                   where a == b
                   select a;
        }

        /// <summary>
        /// Returns the session's existing session ID, falling back to this exchange's
        /// <see cref="ExchangeHash"/> when the session has no ID yet.
        /// </summary>
        /// <returns>The session identifier to use in key derivation.</returns>
        private byte[] ResolveSessionId()
        {
            return Session.SessionId ?? ExchangeHash;
        }

        /// <summary>
        /// Extends a derived key until it holds at least <paramref name="size"/> bytes.
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="key">The initial key material (K1).</param>
        /// <param name="size">The minimum required key size, in bytes.</param>
        /// <returns>
        /// The key material. It is not truncated and may be longer than <paramref name="size"/>.
        /// </returns>
        /// <remarks>
        /// Each extra block is HASH(K || H || K1 || ... || Kn), computed over all key material
        /// produced so far (RFC 4253, section 7.2).
        /// </remarks>
        private byte[] GenerateSessionKey(BigInteger sharedKey, byte[] exchangeHash, byte[] key, int size)
        {
            var result = new List<byte>(key);

            while (size > result.Count)
            {
                var adjustment = new SessionKeyAdjustment
                {
                    SharedKey = sharedKey,
                    ExchangeHash = exchangeHash,
                    // Hash over everything generated so far, not just K1; on the first
                    // iteration this equals the original key.
                    Key = result.ToArray(),
                };

                result.AddRange(Hash(adjustment.GetBytes()));
            }

            return result.ToArray();
        }

        /// <summary>
        /// Generates the session key input HASH(K || H || X || session_id) before hashing.
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="p">The single-character key label (e.g. 'A'..'F').</param>
        /// <param name="sessionId">The session id.</param>
        /// <returns>The serialized bytes to be hashed.</returns>
        private byte[] GenerateSessionKey(BigInteger sharedKey, byte[] exchangeHash, char p, byte[] sessionId)
        {
            return new SessionKeyGeneration
            {
                SharedKey = sharedKey,
                ExchangeHash = exchangeHash,
                Char = p,
                SessionId = sessionId,
            }.GetBytes();
        }

        /// <summary>
        /// Serializes K (length-prefixed), H, the key label byte and the session id.
        /// </summary>
        private class SessionKeyGeneration : SshData
        {
#if TUNING
            private byte[] _sharedKey;

            public BigInteger SharedKey
            {
                private get
                {
                    return _sharedKey.ToBigInteger();
                }
                set
                {
                    _sharedKey = value.ToByteArray().Reverse();
                }
            }
#else
            public BigInteger SharedKey { get; set; }
#endif

            public byte[] ExchangeHash { get; set; }

            public char Char { get; set; }

            public byte[] SessionId { get; set; }

#if TUNING
            /// <summary>
            /// Gets the size of the message in bytes.
            /// </summary>
            /// <value>
            /// The size of the messages in bytes.
            /// </value>
            protected override int BufferCapacity
            {
                get
                {
                    var capacity = base.BufferCapacity;
                    capacity += 4; // SharedKey length
                    capacity += _sharedKey.Length; // SharedKey
                    capacity += ExchangeHash.Length; // ExchangeHash
                    capacity += 1; // Char
                    capacity += SessionId.Length; // SessionId
                    return capacity;
                }
            }
#endif

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
#if TUNING
                WriteBinaryString(_sharedKey);
#else
                Write(SharedKey);
#endif
                Write(ExchangeHash);
                Write((byte)Char);
                Write(SessionId);
            }
        }

        /// <summary>
        /// Serializes K (length-prefixed), H and the key material produced so far.
        /// </summary>
        private class SessionKeyAdjustment : SshData
        {
#if TUNING
            private byte[] _sharedKey;

            public BigInteger SharedKey
            {
                private get
                {
                    return _sharedKey.ToBigInteger();
                }
                set
                {
                    _sharedKey = value.ToByteArray().Reverse();
                }
            }
#else
            public BigInteger SharedKey { get; set; }
#endif

            public byte[] ExchangeHash { get; set; }

            public byte[] Key { get; set; }

#if TUNING
            /// <summary>
            /// Gets the size of the message in bytes.
            /// </summary>
            /// <value>
            /// The size of the messages in bytes.
            /// </value>
            protected override int BufferCapacity
            {
                get
                {
                    var capacity = base.BufferCapacity;
                    capacity += 4; // SharedKey length
                    capacity += _sharedKey.Length; // SharedKey
                    capacity += ExchangeHash.Length; // ExchangeHash
                    capacity += Key.Length; // Key
                    return capacity;
                }
            }
#endif

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
#if TUNING
                WriteBinaryString(_sharedKey);
#else
                Write(SharedKey);
#endif
                Write(ExchangeHash);
                Write(Key);
            }
        }

        #region IDisposable Members

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// </summary>
        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
        protected virtual void Dispose(bool disposing)
        {
        }

        /// <summary>
        /// Releases unmanaged resources and performs other cleanup operations before the
        /// <see cref="KeyExchange"/> is reclaimed by garbage collection.
        /// </summary>
        ~KeyExchange()
        {
            // Do not re-create Dispose clean-up code here.
            // Calling Dispose(false) is optimal in terms of
            // readability and maintainability.
            Dispose(false);
        }

        #endregion
    }
}