using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Security.Cryptography;
namespace Renci.SshNet.Security
{
    /// <summary>
    /// Represents base class for different key exchange algorithm implementations.
    /// </summary>
    public abstract class KeyExchange : Algorithm, IDisposable
    {
        private CipherInfo _clientCipherInfo;
        private CipherInfo _serverCipherInfo;
        private HashInfo _clientHashInfo;
        private HashInfo _serverHashInfo;
        private Type _compressionType;
        private Type _decompressionType;

        /// <summary>
        /// Gets the session this key exchange runs on. It is assigned by <see cref="Start"/>.
        /// </summary>
        /// <value>
        /// The session.
        /// </value>
        protected Session Session { get; private set; }

        /// <summary>
        /// Gets or sets key exchange shared key.
        /// </summary>
        /// <value>
        /// The shared key.
        /// </value>
        public BigInteger SharedKey { get; protected set; }

        private byte[] _exchangeHash;

        /// <summary>
        /// Gets the exchange hash.
        /// </summary>
        /// <value>The exchange hash.</value>
        /// <remarks>
        /// Computed by <see cref="CalculateHash"/> on first access and cached for subsequent reads.
        /// </remarks>
        public byte[] ExchangeHash
        {
            get
            {
                if (this._exchangeHash == null)
                {
                    this._exchangeHash = this.CalculateHash();
                }
                return this._exchangeHash;
            }
        }

        /// <summary>
        /// Occurs when host key received.
        /// </summary>
        public event EventHandler HostKeyReceived;

        /// <summary>
        /// Starts key exchange algorithm.
        /// </summary>
        /// <remarks>
        /// Sends the session's client init message. It then negotiates, in order:
        /// <list type="number">
        /// <item>client-to-server encryption</item>
        /// <item>server-to-client encryption</item>
        /// <item>client MAC</item>
        /// <item>server MAC</item>
        /// <item>compression</item>
        /// <item>decompression</item>
        /// </list>
        /// Each negotiated name is recorded on the session's connection info as soon as it is chosen.
        /// If a later category fails, the earlier choices have therefore already been recorded.
        /// </remarks>
        /// <param name="session">The session.</param>
        /// <param name="message">The peer's key exchange init message, whose algorithm lists are matched against the session's configured algorithms.</param>
        /// <exception cref="ArgumentNullException"><paramref name="session"/> or <paramref name="message"/> is null.</exception>
        /// <exception cref="SshConnectionException">No mutually supported algorithm exists for one of the negotiated categories.</exception>
        public virtual void Start(Session session, KeyExchangeInitMessage message)
        {
            if (session == null)
            {
                throw new ArgumentNullException("session");
            }
            if (message == null)
            {
                throw new ArgumentNullException("message");
            }

            this.Session = session;

            this.SendMessage(session.ClientInitMessage);

            var connectionInfo = session.ConnectionInfo;

            //  Determine encryption algorithm
            var clientEncryptionAlgorithmName = SelectAlgorithm(connectionInfo.Encryptions.Keys,
                                                                message.EncryptionAlgorithmsClientToServer,
                                                                "Client encryption algorithm not found");
            connectionInfo.CurrentClientEncryption = clientEncryptionAlgorithmName;

            //  Determine decryption algorithm
            var serverDecryptionAlgorithmName = SelectAlgorithm(connectionInfo.Encryptions.Keys,
                                                                message.EncryptionAlgorithmsServerToClient,
                                                                "Server decryption algorithm not found");
            connectionInfo.CurrentServerEncryption = serverDecryptionAlgorithmName;

            //  Determine client hmac algorithm
            var clientHmacAlgorithmName = SelectAlgorithm(connectionInfo.HmacAlgorithms.Keys,
                                                          message.MacAlgorithmsClientToServer,
                                                          "Client HMAC algorithm not found");
            connectionInfo.CurrentClientHmacAlgorithm = clientHmacAlgorithmName;

            //  Determine server hmac algorithm
            var serverHmacAlgorithmName = SelectAlgorithm(connectionInfo.HmacAlgorithms.Keys,
                                                          message.MacAlgorithmsServerToClient,
                                                          "Server HMAC algorithm not found");
            connectionInfo.CurrentServerHmacAlgorithm = serverHmacAlgorithmName;

            //  Determine compression algorithm.
            //  Like every other category, this takes the FIRST mutual match (RFC 4253 section 7.1).
            //  Picking a different one than the peer would desynchronize the two sides.
            var compressionAlgorithmName = SelectAlgorithm(connectionInfo.CompressionAlgorithms.Keys,
                                                           message.CompressionAlgorithmsClientToServer,
                                                           "Compression algorithm not found");
            connectionInfo.CurrentClientCompressionAlgorithm = compressionAlgorithmName;

            //  Determine decompression algorithm
            var decompressionAlgorithmName = SelectAlgorithm(connectionInfo.CompressionAlgorithms.Keys,
                                                             message.CompressionAlgorithmsServerToClient,
                                                             "Decompression algorithm not found");
            connectionInfo.CurrentServerCompressionAlgorithm = decompressionAlgorithmName;

            this._clientCipherInfo = connectionInfo.Encryptions[clientEncryptionAlgorithmName];
            this._serverCipherInfo = connectionInfo.Encryptions[serverDecryptionAlgorithmName];
            this._clientHashInfo = connectionInfo.HmacAlgorithms[clientHmacAlgorithmName];
            this._serverHashInfo = connectionInfo.HmacAlgorithms[serverHmacAlgorithmName];
            this._compressionType = connectionInfo.CompressionAlgorithms[compressionAlgorithmName];
            this._decompressionType = connectionInfo.CompressionAlgorithms[decompressionAlgorithmName];
        }

        /// <summary>
        /// Returns the first name in <paramref name="clientAlgorithms"/> that also appears in
        /// <paramref name="serverAlgorithms"/>.
        /// </summary>
        /// <remarks>
        /// Preference follows the enumeration order of <paramref name="clientAlgorithms"/>.
        /// NOTE(review): presumably this is the order advertised in the client's init message; confirm.
        /// </remarks>
        /// <param name="clientAlgorithms">Locally supported algorithm names, in preference order.</param>
        /// <param name="serverAlgorithms">Algorithm names offered by the peer.</param>
        /// <param name="errorMessage">Message for the exception thrown when there is no match.</param>
        /// <returns>The negotiated algorithm name.</returns>
        /// <exception cref="SshConnectionException">No non-empty common algorithm name exists.</exception>
        private static string SelectAlgorithm(IEnumerable<string> clientAlgorithms, IEnumerable<string> serverAlgorithms, string errorMessage)
        {
            var algorithmName = (from b in clientAlgorithms
                                 from a in serverAlgorithms
                                 where a == b
                                 select a).FirstOrDefault();

            if (string.IsNullOrEmpty(algorithmName))
            {
                throw new SshConnectionException(errorMessage, DisconnectReason.KeyExchangeFailed);
            }

            return algorithmName;
        }

        /// <summary>
        /// Finishes key exchange algorithm.
        /// </summary>
        /// <remarks>
        /// Sends a <see cref="NewKeysMessage"/> only after <see cref="ValidateExchangeHash"/> succeeds.
        /// </remarks>
        /// <exception cref="SshConnectionException">The exchange hash failed validation.</exception>
        public virtual void Finish()
        {
            //  Validate hash
            if (!this.ValidateExchangeHash())
            {
                throw new SshConnectionException("Key exchange negotiation failed.", DisconnectReason.KeyExchangeFailed);
            }

            this.SendMessage(new NewKeysMessage());
        }

        /// <summary>
        /// Creates the server side cipher to use.
        /// </summary>
        /// <remarks>
        /// The IV is derived with letter 'B' and the key with letter 'D'.
        /// The session id falls back to <see cref="ExchangeHash"/> while the session has none.
        /// </remarks>
        /// <returns>Server cipher.</returns>
        public Cipher CreateServerCipher()
        {
            //  Resolve Session ID
            var sessionId = this.Session.SessionId ?? this.ExchangeHash;

            //  Calculate server to client initial IV
            var serverVector = this.Hash(this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, 'B', sessionId));

            //  Calculate server to client encryption
            var serverKey = this.Hash(this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, 'D', sessionId));
            serverKey = this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, serverKey, this._serverCipherInfo.KeySize / 8);

            //  Create server cipher
            return this._serverCipherInfo.Cipher(serverKey, serverVector);
        }

        /// <summary>
        /// Creates the client side cipher to use.
        /// </summary>
        /// <remarks>
        /// The IV is derived with letter 'A' and the key with letter 'C'.
        /// The session id falls back to <see cref="ExchangeHash"/> while the session has none.
        /// </remarks>
        /// <returns>Client cipher.</returns>
        public Cipher CreateClientCipher()
        {
            //  Resolve Session ID
            var sessionId = this.Session.SessionId ?? this.ExchangeHash;

            //  Calculate client to server initial IV
            var clientVector = this.Hash(this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, 'A', sessionId));

            //  Calculate client to server encryption
            var clientKey = this.Hash(this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, 'C', sessionId));
            clientKey = this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, clientKey, this._clientCipherInfo.KeySize / 8);

            //  Create client cipher
            return this._clientCipherInfo.Cipher(clientKey, clientVector);
        }

        /// <summary>
        /// Creates the server side hash algorithm to use.
        /// </summary>
        /// <remarks>
        /// The integrity key is derived with letter 'F'.
        /// </remarks>
        /// <returns>Hash algorithm.</returns>
        public HashAlgorithm CreateServerHash()
        {
            //  Resolve Session ID
            var sessionId = this.Session.SessionId ?? this.ExchangeHash;

            var serverKey = this.Hash(this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, 'F', sessionId));
            serverKey = this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, serverKey, this._serverHashInfo.KeySize / 8);

            return this._serverHashInfo.HashAlgorithm(serverKey);
        }

        /// <summary>
        /// Creates the client side hash algorithm to use.
        /// </summary>
        /// <remarks>
        /// The integrity key is derived with letter 'E'.
        /// </remarks>
        /// <returns>Hash algorithm.</returns>
        public HashAlgorithm CreateClientHash()
        {
            //  Resolve Session ID
            var sessionId = this.Session.SessionId ?? this.ExchangeHash;

            var clientKey = this.Hash(this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, 'E', sessionId));
            clientKey = this.GenerateSessionKey(this.SharedKey, this.ExchangeHash, clientKey, this._clientHashInfo.KeySize / 8);

            return this._clientHashInfo.HashAlgorithm(clientKey);
        }

        /// <summary>
        /// Creates the compression algorithm to use to deflate data.
        /// </summary>
        /// <returns>
        /// Compression method, or <c>null</c> when the negotiated compression algorithm maps to no type.
        /// </returns>
        public Compressor CreateCompressor()
        {
            if (this._compressionType == null)
            {
                return null;
            }

            var compressor = this._compressionType.CreateInstance<Compressor>();
            compressor.Init(this.Session);
            return compressor;
        }

        /// <summary>
        /// Creates the compression algorithm to use to inflate data.
        /// </summary>
        /// <returns>
        /// Compression method, or <c>null</c> when the negotiated decompression algorithm maps to no type.
        /// </returns>
        public Compressor CreateDecompressor()
        {
            //  Must test the decompression type itself; the two directions are negotiated independently.
            if (this._decompressionType == null)
            {
                return null;
            }

            var decompressor = this._decompressionType.CreateInstance<Compressor>();
            decompressor.Init(this.Session);
            return decompressor;
        }

        /// <summary>
        /// Determines whether the specified host key can be trusted.
        /// </summary>
        /// <remarks>
        /// Raises <see cref="HostKeyReceived"/> and returns the handlers' verdict.
        /// When no handler is subscribed, the key is trusted.
        /// </remarks>
        /// <param name="host">The host algorithm.</param>
        /// <returns>
        ///   <c>true</c> if the specified host can be trusted; otherwise, <c>false</c>.
        /// </returns>
        protected bool CanTrustHostKey(KeyHostAlgorithm host)
        {
            //  Copy to a local so the null check and the invocation see the same delegate.
            var handlers = HostKeyReceived;
            if (handlers != null)
            {
                var args = new HostKeyEventArgs(host);
                handlers(this, args);
                return args.CanTrust;
            }
            return true;
        }

        /// <summary>
        /// Validates the exchange hash.
        /// </summary>
        /// <returns><c>true</c> if exchange hash is valid; otherwise <c>false</c>.</returns>
        protected abstract bool ValidateExchangeHash();

        /// <summary>
        /// Calculates key exchange hash value.
        /// </summary>
        /// <returns>Key exchange hash.</returns>
        protected abstract byte[] CalculateHash();

        /// <summary>
        /// Hashes the specified data bytes.
        /// </summary>
        /// <remarks>
        /// Uses SHA-1 here; derived classes may override to use another hash function.
        /// </remarks>
        /// <param name="hashData">The hash data.</param>
        /// <returns>
        /// Hashed bytes.
        /// </returns>
        protected virtual byte[] Hash(byte[] hashData)
        {
            using (var sha1 = new SHA1Hash())
            {
                return sha1.ComputeHash(hashData, 0, hashData.Length);
            }
        }

        /// <summary>
        /// Sends SSH message to the server.
        /// </summary>
        /// <param name="message">The message.</param>
        protected void SendMessage(Message message)
        {
            this.Session.SendMessage(message);
        }

        /// <summary>
        /// Extends initial key material until it is at least <paramref name="size"/> bytes long.
        /// </summary>
        /// <remarks>
        /// Per RFC 4253 section 7.2, each extension block is
        /// <c>HASH(K || H || K1 || ... || K(n-1))</c>, i.e. it covers all material produced so far.
        /// Whole hash blocks are appended, so the result may be longer than <paramref name="size"/>;
        /// no truncation happens here.
        /// </remarks>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="key">The initial key material (K1).</param>
        /// <param name="size">The minimum required length in bytes.</param>
        /// <returns>The extended key material.</returns>
        private byte[] GenerateSessionKey(BigInteger sharedKey, byte[] exchangeHash, byte[] key, int size)
        {
            var result = new List<byte>(key);
            while (size > result.Count)
            {
                result.AddRange(this.Hash(new _SessionKeyAdjustment
                {
                    SharedKey = sharedKey,
                    ExchangeHash = exchangeHash,
                    //  Hash everything generated so far, not just the initial block.
                    Key = result.ToArray(),
                }.GetBytes()));
            }
            return result.ToArray();
        }

        /// <summary>
        /// Builds the serialized input <c>K || H || letter || session_id</c> for a key derivation hash.
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="p">The derivation letter, written as a single byte.</param>
        /// <param name="sessionId">The session id.</param>
        /// <returns>The bytes to be hashed.</returns>
        private byte[] GenerateSessionKey(BigInteger sharedKey, byte[] exchangeHash, char p, byte[] sessionId)
        {
            return new _SessionKeyGeneration
            {
                SharedKey = sharedKey,
                ExchangeHash = exchangeHash,
                Char = p,
                SessionId = sessionId,
            }.GetBytes();
        }

        /// <summary>
        /// Serializes the shared key, exchange hash, derivation letter (one byte) and session id.
        /// It is write-only: loading is not supported.
        /// </summary>
        private class _SessionKeyGeneration : SshData
        {
            public BigInteger SharedKey { get; set; }
            public byte[] ExchangeHash { get; set; }
            public char Char { get; set; }
            public byte[] SessionId { get; set; }

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
                this.Write(this.SharedKey);
                this.Write(this.ExchangeHash);
                this.Write((byte)this.Char);
                this.Write(this.SessionId);
            }
        }

        /// <summary>
        /// Serializes the shared key, exchange hash and the key material produced so far.
        /// It is write-only: loading is not supported.
        /// </summary>
        private class _SessionKeyAdjustment : SshData
        {
            public BigInteger SharedKey { get; set; }
            public byte[] ExchangeHash { get; set; }
            public byte[] Key { get; set; }

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
                this.Write(this.SharedKey);
                this.Write(this.ExchangeHash);
                this.Write(this.Key);
            }
        }

        #region IDisposable Members

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// This base implementation does nothing; derived classes may override.
        /// </summary>
        /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
        protected virtual void Dispose(bool disposing)
        {
        }

        /// <summary>
        /// Releases unmanaged resources and performs other cleanup operations before the
        /// <see cref="KeyExchange"/> is reclaimed by garbage collection.
        /// </summary>
        ~KeyExchange()
        {
            // Do not re-create Dispose clean-up code here.
            // Calling Dispose(false) is optimal in terms of
            // readability and maintainability.
            Dispose(false);
        }

        #endregion
    }
}