from include import *
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA1
from Crypto.Cipher import PKCS1_OAEP
from base64 import urlsafe_b64decode,urlsafe_b64encode
import time
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import hashlib
import hmac
from math import ceil
import binascii
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, load_privatekey, dump_publickey,dump_privatekey
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import json

class crypto:
    '''
    Stateless collection of helpers for AES-GCM encryption/decryption,
    certificate SHA-1 fingerprints and HKDF-SHA256 key derivation.

    Methods only work on their arguments; the one exception is
    ``verify_certificate``, which records the matching path on the
    ``ldap`` object it is given.
    '''

    # HKDF-SHA256 parameters (RFC 5869): HashLen and the maximum output length.
    _HASH_LEN = 32
    _HKDF_MAX_LENGTH = 255 * _HASH_LEN

    def __init__(self):
        pass

    def encrypt(self, key, message, cert_path):
        '''
        Encrypt ``message`` with AES-GCM under ``key`` using a fresh random
        12-byte nonce.

        input -
        key       - AES key bytes
        message   - str; encoded as UTF-8 before encryption
        cert_path - unused; kept so existing callers keep working

        returns :
        urlsafe_b64encode(nonce) + b"." + urlsafe_b64encode(ciphertext)  (bytes)

        NOTE(review): the GCM authentication tag is computed but discarded,
        so the receiver cannot detect tampering. Adding it would change the
        wire format shared with the peer, so the format is left unchanged.
        '''
        plaintext = message.encode('utf-8')
        # A new random nonce per message: GCM must never reuse a (key, nonce) pair.
        nonce = get_random_bytes(12)
        cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
        ciphertext, _tag = cipher.encrypt_and_digest(plaintext)
        return urlsafe_b64encode(nonce) + b"." + urlsafe_b64encode(ciphertext)

    def decrypt(self, session_key, init_vector, ciphertext, tag=None):
        '''
        Decrypt AES-GCM ``ciphertext``.

        input -
        session_key - AES key the message was encrypted with
        init_vector - the GCM nonce the message was encrypted with
        ciphertext  - raw ciphertext bytes (presumably already base64-decoded
                      by the caller)
        tag         - optional GCM authentication tag. When given, the tag is
                      verified and a mismatch counts as a decryption failure.
                      When omitted (the historical behaviour) NO integrity
                      check is performed.

        returns
        the plaintext bytes, or None if decryption failed (the failure is
        logged, not raised).
        '''
        try:
            cipher = AES.new(session_key, AES.MODE_GCM, nonce=init_vector)
            if tag is None:
                return cipher.decrypt(ciphertext)
            # Raises ValueError when the tag does not authenticate the data.
            return cipher.decrypt_and_verify(ciphertext, tag)
        except (KeyError, ValueError):
            logging.error("incorrect decryption")
            return None

    def hash_sha1(self, cert_path):
        '''
        Return the SHA-1 fingerprint of the PEM certificate at ``cert_path``
        as raw bytes.

        ``X509.digest`` yields a colon-separated hex string (e.g. b"AB:CD:..."),
        so the colons are stripped before hex-decoding.
        '''
        with open(cert_path, "rb") as cert_file:
            cert = load_certificate(FILETYPE_PEM, cert_file.read())
        hex_fingerprint = cert.digest("sha1").replace(b":", b"")
        return binascii.unhexlify(hex_fingerprint)

    def verify_certificate(self, kid, ldap):
        '''
        Check whether any candidate certificate matches the stored key id.

        input -
        kid  - fingerprint stored in the AD; compared against ``hash_sha1``
               output, so it must be raw SHA-1 bytes to ever match
        ldap - object exposing ``cert_path`` (iterable of certificate paths);
               on a match its ``current_cert_path`` is set to that path

        output
        True if some certificate's fingerprint equals ``kid``, else False
        '''
        for certificate in ldap.cert_path:
            if kid == self.hash_sha1(certificate):
                ldap.current_cert_path = certificate
                return True
        return False

    def hmac_sha256(self, key, data):
        '''
        Return the raw HMAC-SHA256 digest of ``data`` under ``key``.
        '''
        return hmac.new(key, data, hashlib.sha256).digest()

    def hkdf(self, length, ikm, salt, info=b""):
        '''
        HKDF-SHA256 (RFC 5869): derive ``length`` bytes of key material.

        input -
        length - number of output bytes, at most 255 * 32 = 8160
        ikm    - input keying material
        salt   - optional salt; an empty value or None means "no salt", which
                 RFC 5869 replaces with HashLen (32) zero bytes
        info   - optional context/application-specific information

        returns
        ``length`` bytes of output keying material (b"" when length <= 0)

        raises
        ValueError if ``length`` exceeds the RFC 5869 limit
        '''
        if length > self._HKDF_MAX_LENGTH:
            raise ValueError(
                "hkdf length must be at most %d bytes" % self._HKDF_MAX_LENGTH)
        # Extract: PRK = HMAC(salt, IKM).
        prk = self.hmac_sha256(salt if salt else bytes(self._HASH_LEN), ikm)
        # Expand: T(i) = HMAC(PRK, T(i-1) | info | i), with i starting at 1.
        block = b""
        blocks = []
        for counter in range(1, ceil(length / self._HASH_LEN) + 1):
            block = self.hmac_sha256(prk, block + info + bytes([counter]))
            blocks.append(block)
        return b"".join(blocks)[:length]