using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using Microsoft.Extensions.Logging;
using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Security.Cryptography;
namespace Renci.SshNet.Security
{
/// <summary>
/// Represents the base class for the different key exchange algorithm implementations.
/// </summary>
public abstract class KeyExchange : Algorithm, IKeyExchange
{
// Logger created in the constructor for the concrete key exchange type.
private readonly ILogger _logger;

// Cipher descriptors selected by Start() for each direction.
private CipherInfo _clientCipherInfo;
private CipherInfo _serverCipherInfo;

// MAC descriptors selected by Start(); left unassigned for a direction whose cipher is AEAD.
private HashInfo _clientHashInfo;
private HashInfo _serverHashInfo;

// Factories selected by Start() that produce the compressor / decompressor instances.
// The type argument is required: there is no non-generic System.Func delegate, and
// CreateCompressor()/CreateDecompressor() return the factory result as a Compressor.
private Func<Compressor> _compressorFactory;
private Func<Compressor> _decompressorFactory;
/// <summary>
/// Gets the session this key exchange operates on. It is assigned by
/// <see cref="Start(Session, KeyExchangeInitMessage, bool)"/>.
/// </summary>
/// <value>
/// The session.
/// </value>
protected Session Session { get; private set; }
/// <summary>
/// Gets or sets the key exchange shared key. The value is fed into every session key
/// derived by this class (ciphers and MACs).
/// </summary>
/// <value>
/// The shared key.
/// </value>
public byte[] SharedKey { get; protected set; }
// Backing store for ExchangeHash; stays null until the property is read for the first time.
private byte[] _exchangeHash;

/// <summary>
/// Gets the exchange hash. The value is obtained from <see cref="CalculateHash"/> on first
/// access and the stored result is returned on later reads.
/// </summary>
/// <value>The exchange hash.</value>
public byte[] ExchangeHash => _exchangeHash ??= CalculateHash();
/// <summary>
/// Occurs when a host key is received. Subscribers are handed a <see cref="HostKeyEventArgs"/>
/// whose <c>CanTrust</c> value decides whether the key is accepted.
/// </summary>
// NOTE(review): CanTrustHostKey invokes this event with a HostKeyEventArgs instance; confirm
// whether the declaration is meant to be the generic EventHandler<HostKeyEventArgs>.
public event EventHandler HostKeyReceived;
/// <summary>
/// Initializes a new instance of the <see cref="KeyExchange"/> class and creates the logger
/// for the runtime type of the instance.
/// </summary>
protected KeyExchange()
{
    var loggerFactory = SshNetLoggingConfiguration.LoggerFactory;
    _logger = loggerFactory.CreateLogger(GetType());
}
/// <summary>
/// Starts the key exchange: stores the session, optionally sends the client's init message and
/// then selects the encryption, MAC and compression algorithms for both directions by matching
/// our configured algorithms against the ones listed in <paramref name="message"/>.
/// </summary>
/// <param name="session">The session the key exchange runs on.</param>
/// <param name="message">The key exchange init message carrying the peer's algorithm lists.</param>
/// <param name="sendClientInitMessage">
/// <see langword="true"/> to send <c>session.ClientInitMessage</c> before negotiating.
/// </param>
/// <exception cref="SshConnectionException">
/// No common algorithm exists for one of the negotiated categories.
/// </exception>
/// <remarks>
/// For each category the first of our algorithms (in our order) that the peer also lists wins.
/// MAC negotiation is skipped for a direction whose selected cipher is AEAD.
/// </remarks>
public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
    Session = session;

    if (sendClientInitMessage)
    {
        SendMessage(session.ClientInitMessage);
    }

    var connectionInfo = session.ConnectionInfo;
    var encryptions = connectionInfo.Encryptions;

    // Encryption, client to server.
    var clientCipherName = FindFirstCommon(encryptions.Keys, message.EncryptionAlgorithmsClientToServer);

    if (_logger.IsEnabled(LogLevel.Trace))
    {
        _logger.LogTrace(
            "[{SessionId}] Encryption client to server: we offer {WeOffer}",
            Session.SessionIdHex,
            encryptions.Keys.Join(","));
        _logger.LogTrace(
            "[{SessionId}] Encryption client to server: they offer {TheyOffer}",
            Session.SessionIdHex,
            message.EncryptionAlgorithmsClientToServer.Join(","));
    }

    if (string.IsNullOrEmpty(clientCipherName))
    {
        throw new SshConnectionException("Client encryption algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentClientEncryption = clientCipherName;
    _clientCipherInfo = encryptions[clientCipherName];

    // Encryption, server to client.
    var serverCipherName = FindFirstCommon(encryptions.Keys, message.EncryptionAlgorithmsServerToClient);

    if (_logger.IsEnabled(LogLevel.Trace))
    {
        _logger.LogTrace(
            "[{SessionId}] Encryption server to client: we offer {WeOffer}",
            Session.SessionIdHex,
            encryptions.Keys.Join(","));
        _logger.LogTrace(
            "[{SessionId}] Encryption server to client: they offer {TheyOffer}",
            Session.SessionIdHex,
            message.EncryptionAlgorithmsServerToClient.Join(","));
    }

    if (string.IsNullOrEmpty(serverCipherName))
    {
        throw new SshConnectionException("Server decryption algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentServerEncryption = serverCipherName;
    _serverCipherInfo = encryptions[serverCipherName];

    // MAC, client to server (skipped when the client cipher is AEAD).
    if (!_clientCipherInfo.IsAead)
    {
        var macAlgorithms = connectionInfo.HmacAlgorithms;
        var clientMacName = FindFirstCommon(macAlgorithms.Keys, message.MacAlgorithmsClientToServer);

        if (_logger.IsEnabled(LogLevel.Trace))
        {
            _logger.LogTrace(
                "[{SessionId}] MAC client to server: we offer {WeOffer}",
                Session.SessionIdHex,
                macAlgorithms.Keys.Join(","));
            _logger.LogTrace(
                "[{SessionId}] MAC client to server: they offer {TheyOffer}",
                Session.SessionIdHex,
                message.MacAlgorithmsClientToServer.Join(","));
        }

        if (string.IsNullOrEmpty(clientMacName))
        {
            throw new SshConnectionException("Client HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
        }

        connectionInfo.CurrentClientHmacAlgorithm = clientMacName;
        _clientHashInfo = macAlgorithms[clientMacName];
    }

    // MAC, server to client (skipped when the server cipher is AEAD).
    if (!_serverCipherInfo.IsAead)
    {
        var macAlgorithms = connectionInfo.HmacAlgorithms;
        var serverMacName = FindFirstCommon(macAlgorithms.Keys, message.MacAlgorithmsServerToClient);

        if (_logger.IsEnabled(LogLevel.Trace))
        {
            _logger.LogTrace(
                "[{SessionId}] MAC server to client: we offer {WeOffer}",
                Session.SessionIdHex,
                macAlgorithms.Keys.Join(","));
            _logger.LogTrace(
                "[{SessionId}] MAC server to client: they offer {TheyOffer}",
                Session.SessionIdHex,
                message.MacAlgorithmsServerToClient.Join(","));
        }

        if (string.IsNullOrEmpty(serverMacName))
        {
            throw new SshConnectionException("Server HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
        }

        connectionInfo.CurrentServerHmacAlgorithm = serverMacName;
        _serverHashInfo = macAlgorithms[serverMacName];
    }

    var compressionAlgorithms = connectionInfo.CompressionAlgorithms;

    // Compression, client to server.
    var compressorName = FindFirstCommon(compressionAlgorithms.Keys, message.CompressionAlgorithmsClientToServer);

    if (_logger.IsEnabled(LogLevel.Trace))
    {
        _logger.LogTrace(
            "[{SessionId}] Compression client to server: we offer {WeOffer}",
            Session.SessionIdHex,
            compressionAlgorithms.Keys.Join(","));
        _logger.LogTrace(
            "[{SessionId}] Compression client to server: they offer {TheyOffer}",
            Session.SessionIdHex,
            message.CompressionAlgorithmsClientToServer.Join(","));
    }

    if (string.IsNullOrEmpty(compressorName))
    {
        throw new SshConnectionException("Compression algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentClientCompressionAlgorithm = compressorName;
    _compressorFactory = compressionAlgorithms[compressorName];

    // Compression, server to client.
    var decompressorName = FindFirstCommon(compressionAlgorithms.Keys, message.CompressionAlgorithmsServerToClient);

    if (_logger.IsEnabled(LogLevel.Trace))
    {
        _logger.LogTrace(
            "[{SessionId}] Compression server to client: we offer {WeOffer}",
            Session.SessionIdHex,
            compressionAlgorithms.Keys.Join(","));
        _logger.LogTrace(
            "[{SessionId}] Compression server to client: they offer {TheyOffer}",
            Session.SessionIdHex,
            message.CompressionAlgorithmsServerToClient.Join(","));
    }

    if (string.IsNullOrEmpty(decompressorName))
    {
        throw new SshConnectionException("Decompression algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentServerCompressionAlgorithm = decompressorName;
    _decompressorFactory = compressionAlgorithms[decompressorName];
}

/// <summary>
/// Walks <paramref name="ours"/> in order and returns the first entry of
/// <paramref name="theirs"/> that equals the current one, or <see langword="null"/> when the
/// two sequences have no entry in common.
/// </summary>
/// <param name="ours">The algorithm names we support, in our order of preference.</param>
/// <param name="theirs">The algorithm names listed by the peer.</param>
private static string FindFirstCommon(IEnumerable<string> ours, IEnumerable<string> theirs)
{
    foreach (var mine in ours)
    {
        foreach (var offered in theirs)
        {
            if (offered == mine)
            {
                return offered;
            }
        }
    }

    return null;
}
/// <summary>
/// Finishes the key exchange: validates the exchange hash and, when it is valid, sends a
/// <see cref="NewKeysMessage"/>.
/// </summary>
/// <exception cref="SshConnectionException">The exchange hash did not validate.</exception>
public virtual void Finish()
{
    var hashIsValid = ValidateExchangeHash();
    if (hashIsValid)
    {
        SendMessage(new NewKeysMessage());
        return;
    }

    throw new SshConnectionException("Key exchange negotiation failed.", DisconnectReason.KeyExchangeFailed);
}
/// <summary>
/// Creates the server side cipher to use.
/// </summary>
/// <param name="isAead">
/// <see langword="true"/> to indicate the cipher is AEAD, <see langword="false"/> to indicate the cipher is not AEAD.
/// </param>
/// <returns>Server cipher.</returns>
public Cipher CreateServerCipher(out bool isAead)
{
    var cipherInfo = _serverCipherInfo;
    isAead = cipherInfo.IsAead;

    // The session id falls back to the exchange hash while the session has none yet.
    var sessionId = Session.SessionId ?? ExchangeHash;
    var sharedKey = SharedKey;
    var exchangeHash = ExchangeHash;

    // 'B' selects the server to client initial IV.
    var initialVector = Hash(GenerateSessionKey(sharedKey, exchangeHash, 'B', sessionId));

    // 'D' selects the server to client encryption key; the first hash block is then
    // extended until it covers KeySize / 8 bytes.
    var firstKeyBlock = Hash(GenerateSessionKey(sharedKey, exchangeHash, 'D', sessionId));
    var cipherKey = GenerateSessionKey(sharedKey, exchangeHash, firstKeyBlock, cipherInfo.KeySize / 8);

    _logger.LogDebug(
        "[{SessionId}] Creating {ServerEncryption} server cipher.",
        Session.SessionIdHex,
        Session.ConnectionInfo.CurrentServerEncryption);

    return cipherInfo.Cipher(cipherKey, initialVector);
}
/// <summary>
/// Creates the client side cipher to use.
/// </summary>
/// <param name="isAead">
/// <see langword="true"/> to indicate the cipher is AEAD, <see langword="false"/> to indicate the cipher is not AEAD.
/// </param>
/// <returns>Client cipher.</returns>
public Cipher CreateClientCipher(out bool isAead)
{
    var cipherInfo = _clientCipherInfo;
    isAead = cipherInfo.IsAead;

    // The session id falls back to the exchange hash while the session has none yet.
    var sessionId = Session.SessionId ?? ExchangeHash;
    var sharedKey = SharedKey;
    var exchangeHash = ExchangeHash;

    // 'A' selects the client to server initial IV.
    var initialVector = Hash(GenerateSessionKey(sharedKey, exchangeHash, 'A', sessionId));

    // 'C' selects the client to server encryption key; the first hash block is then
    // extended until it covers KeySize / 8 bytes.
    var firstKeyBlock = Hash(GenerateSessionKey(sharedKey, exchangeHash, 'C', sessionId));
    var cipherKey = GenerateSessionKey(sharedKey, exchangeHash, firstKeyBlock, cipherInfo.KeySize / 8);

    _logger.LogDebug(
        "[{SessionId}] Creating {ClientEncryption} client cipher.",
        Session.SessionIdHex,
        Session.ConnectionInfo.CurrentClientEncryption);

    return cipherInfo.Cipher(cipherKey, initialVector);
}
/// <summary>
/// Creates the server side hash algorithm to use.
/// </summary>
/// <param name="isEncryptThenMAC">
/// <see langword="true"/> to enable encrypt-then-MAC, <see langword="false"/> to use encrypt-and-MAC.
/// Always <see langword="false"/> when no server MAC algorithm was selected.
/// </param>
/// <returns>
/// The server-side hash algorithm, or <see langword="null"/> when no server MAC algorithm was selected.
/// </returns>
public HashAlgorithm CreateServerHash(out bool isEncryptThenMAC)
{
    var hashInfo = _serverHashInfo;
    if (hashInfo is null)
    {
        isEncryptThenMAC = false;
        return null;
    }

    isEncryptThenMAC = hashInfo.IsEncryptThenMAC;

    // The session id falls back to the exchange hash while the session has none yet.
    var sessionId = Session.SessionId ?? ExchangeHash;

    // 'F' selects the server to client integrity key; the first hash block is then
    // extended until it covers KeySize / 8 bytes.
    var firstKeyBlock = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'F', sessionId));
    var macKey = GenerateSessionKey(SharedKey, ExchangeHash, firstKeyBlock, hashInfo.KeySize / 8);

    _logger.LogDebug(
        "[{SessionId}] Creating {ServerHmacAlgorithm} server hmac algorithm.",
        Session.SessionIdHex,
        Session.ConnectionInfo.CurrentServerHmacAlgorithm);

    return hashInfo.HashAlgorithm(macKey);
}
/// <summary>
/// Creates the client side hash algorithm to use.
/// </summary>
/// <param name="isEncryptThenMAC">
/// <see langword="true"/> to enable encrypt-then-MAC, <see langword="false"/> to use encrypt-and-MAC.
/// Always <see langword="false"/> when no client MAC algorithm was selected.
/// </param>
/// <returns>
/// The client-side hash algorithm, or <see langword="null"/> when no client MAC algorithm was selected.
/// </returns>
public HashAlgorithm CreateClientHash(out bool isEncryptThenMAC)
{
    var hashInfo = _clientHashInfo;
    if (hashInfo is null)
    {
        isEncryptThenMAC = false;
        return null;
    }

    isEncryptThenMAC = hashInfo.IsEncryptThenMAC;

    // The session id falls back to the exchange hash while the session has none yet.
    var sessionId = Session.SessionId ?? ExchangeHash;

    // 'E' selects the client to server integrity key; the first hash block is then
    // extended until it covers KeySize / 8 bytes.
    var firstKeyBlock = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'E', sessionId));
    var macKey = GenerateSessionKey(SharedKey, ExchangeHash, firstKeyBlock, hashInfo.KeySize / 8);

    _logger.LogDebug(
        "[{SessionId}] Creating {ClientHmacAlgorithm} client hmac algorithm.",
        Session.SessionIdHex,
        Session.ConnectionInfo.CurrentClientHmacAlgorithm);

    return hashInfo.HashAlgorithm(macKey);
}
/// <summary>
/// Creates the compression algorithm to use to deflate data.
/// </summary>
/// <returns>
/// The compression method initialized with the current session, or <see langword="null"/>
/// when the selected algorithm has no compressor factory.
/// </returns>
public Compressor CreateCompressor()
{
    var factory = _compressorFactory;
    if (factory is null)
    {
        return null;
    }

    _logger.LogDebug(
        "[{SessionId}] Creating {CompressionAlgorithm} client compressor.",
        Session.SessionIdHex,
        Session.ConnectionInfo.CurrentClientCompressionAlgorithm);

    var result = factory();
    result.Init(Session);
    return result;
}
/// <summary>
/// Creates the compression algorithm to use to inflate data.
/// </summary>
/// <returns>
/// The decompression method initialized with the current session, or <see langword="null"/>
/// when the selected algorithm has no decompressor factory.
/// </returns>
public Compressor CreateDecompressor()
{
    var factory = _decompressorFactory;
    if (factory is null)
    {
        return null;
    }

    _logger.LogDebug(
        "[{SessionId}] Creating {ServerCompressionAlgorithm} server decompressor.",
        Session.SessionIdHex,
        Session.ConnectionInfo.CurrentServerCompressionAlgorithm);

    var result = factory();
    result.Init(Session);
    return result;
}
/// <summary>
/// Determines whether the specified host key can be trusted by raising
/// <see cref="HostKeyReceived"/> and reading back the subscribers' verdict.
/// </summary>
/// <param name="host">The host algorithm.</param>
/// <returns>
/// <see langword="true"/> if the specified host can be trusted; otherwise, <see langword="false"/>.
/// When nobody is subscribed to the event the key is trusted.
/// </returns>
protected bool CanTrustHostKey(KeyHostAlgorithm host)
{
    // Take a snapshot of the delegate before testing and invoking it.
    var subscribers = HostKeyReceived;
    if (subscribers is null)
    {
        return true;
    }

    var eventArgs = new HostKeyEventArgs(host);
    subscribers(this, eventArgs);
    return eventArgs.CanTrust;
}
/// <summary>
/// Validates the exchange hash. <see cref="Finish"/> calls this and aborts the key exchange
/// when the result is <see langword="false"/>.
/// </summary>
/// <returns><see langword="true"/> if the exchange hash is valid; otherwise, <see langword="false"/>.</returns>
protected abstract bool ValidateExchangeHash();
/// <summary>
/// Verifies the signature over the exchange hash with the given host key and, when it
/// verifies, asks the <see cref="HostKeyReceived"/> subscribers whether the key is trusted.
/// </summary>
/// <param name="encodedKey">The encoded host key; its leading string is the key format identifier.</param>
/// <param name="encodedSignature">The encoded signature, including its algorithm identifier.</param>
/// <returns>
/// <see langword="true"/> if the signature verifies and the host key is trusted; otherwise, <see langword="false"/>.
/// </returns>
/// <exception cref="SshConnectionException">
/// The host key algorithm derived from the key and signature is not one of the configured host key algorithms.
/// </exception>
private protected bool ValidateExchangeHash(byte[] encodedKey, byte[] encodedSignature)
{
    var exchangeHash = CalculateHash();

    // We need to inspect both the key and signature format identifiers to find the correct
    // HostAlgorithm instance. Example cases:
    // Key identifier                 Signature identifier | Algorithm name
    // ssh-rsa                        ssh-rsa              | ssh-rsa
    // ssh-rsa                        rsa-sha2-256         | rsa-sha2-256
    // ssh-rsa-cert-v01@openssh.com   ssh-rsa              | ssh-rsa-cert-v01@openssh.com
    // ssh-rsa-cert-v01@openssh.com   rsa-sha2-256         | rsa-sha2-256-cert-v01@openssh.com
    var signatureData = new KeyHostAlgorithm.SignatureKeyData();
    signatureData.Load(encodedSignature);

    string keyName;
    using (var keyReader = new SshDataStream(encodedKey))
    {
        keyName = keyReader.ReadString();
    }

    string algorithmName;
    if (signatureData.AlgorithmName.StartsWith("rsa-sha2", StringComparison.Ordinal))
    {
        algorithmName = keyName.Replace("ssh-rsa", signatureData.AlgorithmName);
    }
    else
    {
        algorithmName = keyName;
    }

    // The name is derived from data sent by the peer, so it may not be a configured algorithm.
    // Fail the key exchange explicitly instead of letting the dictionary indexer throw.
    if (!Session.ConnectionInfo.HostKeyAlgorithms.TryGetValue(algorithmName, out var keyAlgorithmFactory))
    {
        throw new SshConnectionException($"Host key algorithm '{algorithmName}' not found", DisconnectReason.KeyExchangeFailed);
    }

    var keyAlgorithm = keyAlgorithmFactory(encodedKey);
    Session.ConnectionInfo.CurrentHostKeyAlgorithm = algorithmName;

    // Trust is only asked for once the signature has verified (short-circuit &&).
    return keyAlgorithm.VerifySignatureBlob(exchangeHash, signatureData.Signature) && CanTrustHostKey(keyAlgorithm);
}
/// <summary>
/// Calculates the key exchange hash value. The result backs <see cref="ExchangeHash"/> and is
/// the data whose signature is verified during validation.
/// </summary>
/// <returns>Key exchange hash.</returns>
protected abstract byte[] CalculateHash();
/// <summary>
/// Hashes the specified data bytes. This class uses it to derive all session key material.
/// </summary>
/// <param name="hashData">The hash data.</param>
/// <returns>
/// The hash of the data.
/// </returns>
protected abstract byte[] Hash(byte[] hashData);
/// <summary>
/// Sends an SSH message to the server by forwarding it to the current <see cref="Session"/>.
/// </summary>
/// <param name="message">The message.</param>
protected void SendMessage(Message message) => Session.SendMessage(message);
/// <summary>
/// Extends an initial key block until at least <paramref name="size"/> bytes of key material
/// are available.
/// </summary>
/// <param name="sharedKey">The shared key.</param>
/// <param name="exchangeHash">The exchange hash.</param>
/// <param name="key">The first key block (K1).</param>
/// <param name="size">The minimum number of key bytes required.</param>
/// <returns>
/// The session key: <paramref name="key"/> followed by as many additional hash blocks as were
/// needed. The result is not truncated, so it may be longer than <paramref name="size"/>.
/// </returns>
/// <remarks>
/// Follows RFC 4253 section 7.2: K2 = HASH(K || H || K1), K3 = HASH(K || H || K1 || K2), ...
/// Each round must hash the entire key produced so far. Hashing only K1 in every round yields
/// K3 == K2 and wrong key material whenever more than two blocks are required.
/// </remarks>
private byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, byte[] key, int size)
{
    var result = new List<byte>(key);

    while (size > result.Count)
    {
        var sessionKeyAdjustment = new SessionKeyAdjustment
        {
            SharedKey = sharedKey,
            ExchangeHash = exchangeHash,

            // Entire key so far (K1 || K2 || ...), not just the first block.
            Key = result.ToArray(),
        };

        result.AddRange(Hash(sessionKeyAdjustment.GetBytes()));
    }

    return result.ToArray();
}
/// <summary>
/// Builds the byte sequence that is hashed to obtain the first block of a session key:
/// the shared key (length-prefixed), the exchange hash, the selector character as a single
/// byte and the session id.
/// </summary>
/// <param name="sharedKey">The shared key.</param>
/// <param name="exchangeHash">The exchange hash.</param>
/// <param name="p">The character selecting which key is derived.</param>
/// <param name="sessionId">The session id.</param>
/// <returns>
/// The serialized, not yet hashed, session key input.
/// </returns>
private static byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, char p, byte[] sessionId)
{
    return new SessionKeyGeneration
    {
        SharedKey = sharedKey,
        ExchangeHash = exchangeHash,
        Char = p,
        SessionId = sessionId,
    }.GetBytes();
}
// Write-only payload describing the input of the first key derivation hash:
// SharedKey (length-prefixed) || ExchangeHash || Char || SessionId.
private sealed class SessionKeyGeneration : SshData
{
    public byte[] SharedKey { get; set; }

    public byte[] ExchangeHash { get; set; }

    public char Char { get; set; }

    public byte[] SessionId { get; set; }

    /// <summary>
    /// Gets the number of bytes written by <see cref="SaveData"/>, on top of the base capacity.
    /// </summary>
    /// <value>
    /// The size of the serialized data in bytes.
    /// </value>
    protected override int BufferCapacity
    {
        get
        {
            const int sharedKeyLengthPrefix = 4;
            const int selectorByte = 1;

            return base.BufferCapacity
                   + sharedKeyLengthPrefix
                   + SharedKey.Length
                   + ExchangeHash.Length
                   + selectorByte
                   + SessionId.Length;
        }
    }

    // This type is only ever serialized, never parsed.
    protected override void LoadData() => throw new NotImplementedException();

    protected override void SaveData()
    {
        WriteBinaryString(SharedKey);
        Write(ExchangeHash);
        Write((byte)Char);
        Write(SessionId);
    }
}
// Write-only payload describing the input of a key extension hash:
// SharedKey (length-prefixed) || ExchangeHash || Key.
private sealed class SessionKeyAdjustment : SshData
{
    public byte[] SharedKey { get; set; }

    public byte[] ExchangeHash { get; set; }

    public byte[] Key { get; set; }

    /// <summary>
    /// Gets the number of bytes written by <see cref="SaveData"/>, on top of the base capacity.
    /// </summary>
    /// <value>
    /// The size of the serialized data in bytes.
    /// </value>
    protected override int BufferCapacity
    {
        get
        {
            const int sharedKeyLengthPrefix = 4;

            return base.BufferCapacity
                   + sharedKeyLengthPrefix
                   + SharedKey.Length
                   + ExchangeHash.Length
                   + Key.Length;
        }
    }

    // This type is only ever serialized, never parsed.
    protected override void LoadData() => throw new NotImplementedException();

    protected override void SaveData()
    {
        WriteBinaryString(SharedKey);
        Write(ExchangeHash);
        Write(Key);
    }
}
#region IDisposable Members

/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// Runs <see cref="Dispose(bool)"/> with <see langword="true"/> and suppresses finalization of this instance.
/// </summary>
public void Dispose()
{
    Dispose(true);
    GC.SuppressFinalize(this);
}

/// <summary>
/// Releases unmanaged and - optionally - managed resources. The implementation in this class is empty.
/// </summary>
/// <param name="disposing">
/// <see langword="true"/> to release both managed and unmanaged resources;
/// <see langword="false"/> to release only unmanaged resources.
/// </param>
protected virtual void Dispose(bool disposing)
{
}

#endregion
}
}