from include import *
from Crypto.PublicKey import RSA
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, load_privatekey, dump_publickey,dump_privatekey, FILETYPE_ASN1


class cert_upgrade:
    '''
    Re-encrypt the OTP device secrets stored in an AD/LDAP attribute so that
    they are protected by a new certificate instead of one of the old ones.

    Intended usage (presumably driven by a caller outside this class):
    connect() -> retrieve_entries() -> traverse_entries(). For each entry,
    traverse_entries() calls decrypt_and_update() and then
    convert_to_json_and_store().
    '''
    def __init__(
        self, host,
        user_name, password,
        port,
        otp_attribute,
        old_cert_path,
        new_cert_path,
        ):
        '''
        Store the server details, the admin credentials and the certificate paths.

        host, port     - LDAP server address. The Server object is created with
                         its third positional argument set to True (presumably
                         ldap3's use_ssl - confirm).
        user_name      - bind (admin) user
        password       - bind password
        otp_attribute  - name of the attribute that holds the OTP JSON blob
        old_cert_path  - iterable of paths to the old PEM files.
                         verify_certificate() iterates over it, so a single path
                         string would be iterated character by character.
        new_cert_path  - path to the new PEM file. generate_symmetric_key()
                         expects it to hold both the certificate and its
                         private key.
        '''
        self.host = host
        self.user_name = user_name
        self.password = password
        self.server = Server(self.host, port, True)
        self.otp_attribute = otp_attribute
        self.old_cert_path = old_cert_path
        self.new_cert_path = new_cert_path

    def connect(self):
        '''
        Bind to the server with the stored credentials.

        Returns the bound Connection, which is also stored on self.connection.
        On any connection failure the error is logged and the process exits
        with status 1, so calling scripts can detect the failure.
        '''
        try:
            self.connection = Connection(self.server,
                self.user_name, self.password,
                auto_bind=True)
        except exceptions.LDAPSocketOpenError:
            logging.error('unable to reach the server, please check the server credentials and port number ')
            sys.exit(1)
        except exceptions.LDAPExceptionError as er:
            logging.error("could not connect to the server due to the following error %s",er)
            sys.exit(1)
        except Exception:
            # This used to be a bare `except:`, which also swallowed
            # KeyboardInterrupt. Let Ctrl-C propagate instead.
            logging.error("could not connect to the server due to an error\n")
            sys.exit(1)
        logging.info('Connection successful..')
        return self.connection

    def retrieve_entries(
            self, search_base,
            search_filter = '(userParameters=*)',
            search_scope = SUBTREE,
            attributes = ALL_ATTRIBUTES,
            time_limit = 300,
            ):
        '''
        Start a paged search (100 entries per page) below search_base.

        search_filter and search_scope are passed to the search.
        Only self.otp_attribute is requested for each entry. The `attributes`
        and `time_limit` parameters are accepted for backward compatibility
        but are currently NOT used.

        output:
        a generator of entry dicts, or False if creating the search raised.
        The search runs with generator=True, so LDAP errors that happen while
        paging surface when the caller iterates, not here.
        '''
        try:
            return self.connection.extend.standard.paged_search(
                search_base=search_base,
                search_filter=search_filter,
                search_scope=search_scope,
                attributes=[self.otp_attribute],
                paged_size=100,
                generator=True,
            )
        except Exception as er:
            logging.error("Error %s", type(er).__name__)
            return False

    def traverse_entries(self, entries):
        '''
        Re-encrypt and write back the OTP attribute of every entry.

        For each entry that has a 'dn':
        - decrypt_and_update() re-encrypts the device secrets;
        - convert_to_json_and_store() writes the result back.

        The following are (re)set on self:
          crypt                          - fresh crypto.crypto() helper
          total_entries                  - every item yielded by `entries`,
                                           including items without a 'dn'
          number_of_entries              - entries updated successfully
          number_of_entries_not_modified - entries that failed at either step
          not_modified                   - DNs of the failed entries

        An exception raised while iterating `entries` stops the traversal and
        is logged.
        '''
        self.crypt = crypto.crypto()
        entered = False
        self.number_of_entries = 0
        self.total_entries = 0
        self.number_of_entries_not_modified = 0
        self.not_modified = []

        try:
            for entry in entries:
                self.total_entries += 1
                entered = True
                # Items without a DN (presumably search references) have
                # nothing to update.
                if 'dn' not in entry:
                    continue
                DN = entry['dn']
                # Use .get() so that an entry missing the attribute is reported
                # as not modified instead of aborting the whole loop with a
                # KeyError. decrypt_and_update() rejects None.
                attribute = entry.get('attributes', {}).get(self.otp_attribute)
                logging.info("Updating the attribute %s", attribute)
                updated_attribute = self.decrypt_and_update(attribute, DN)
                if updated_attribute is False:
                    self.not_modified.append(DN)
                    self.number_of_entries_not_modified += 1
                    logging.error("error in updating certificate for %s",DN)
                elif self.convert_to_json_and_store(updated_attribute, DN):
                    self.number_of_entries += 1
                    logging.info("updation successful for the user %s",DN)
                else:
                    self.not_modified.append(DN)
                    # A failed write also counts as not modified.
                    self.number_of_entries_not_modified += 1
                    logging.error("error in updation for %s",DN)
        except Exception as er:
            logging.error("Error in traversing the entries due to following error %s", type(er).__name__ )
        if not entered:
            logging.error("No Entries to be updated")

    def decrypt_and_update(
        self, attribute, DN
        ):
        '''
        Re-encrypt every device secret in one entry's OTP attribute value.

        input:
        - attribute: the raw attribute value, either a str or a list. Only the
          first element of a list is processed.
        - DN: used only for logging.

        The value may be plain JSON or url-safe base64 of JSON.
        otpdata.devices maps each device to a "<kid>.<iv>.<ciphertext>"
        string, where every part is url-safe base64. The decoded kid is
        matched against the old certificates by verify_certificate().

        output:
        the decoded dict with otpdata.devices re-encrypted under the new
        certificate, or False on any failure (logged).
        '''
        try:
            if isinstance(attribute, str):
                logging.info("Processing a Single Valued attribute")
                attribute_value = attribute
            elif isinstance(attribute, list):
                logging.info("Processing a Multi valued attribute")
                if not attribute:
                    logging.error("empty attribute value for %s", DN)
                    return False
                attribute_value = attribute[0]
            else:
                logging.error("Operation failed due to unsupported attribute type")
                return False

            if is_json(attribute_value):
                json_decoded_attribute_value = json.loads(attribute_value)
            else:
                decoded_value = urlsafe_b64decode(attribute_value)
                if not is_json(decoded_value):
                    # This used to fall through to a NameError on an
                    # undefined variable.
                    logging.error("attribute of %s is neither JSON nor base64 encoded JSON", DN)
                    return False
                json_decoded_attribute_value = json.loads(decoded_value)

            logging.info('verifying if the encryption is intact')
            try:
                otp_data = json_decoded_attribute_value['otpdata']
            except KeyError:
                logging.info("Cannot store Json decode value due to Exception KeyError")
                return False
            devices = otp_data['devices']

            # Only values of existing keys are replaced, so mutating `devices`
            # while iterating over it is safe (its size does not change).
            for device, secret in devices.items():
                kid, iv, ciphertext = secret.split('.')[:3]
                kid = urlsafe_b64decode(kid)
                iv = urlsafe_b64decode(iv)
                ciphertext = urlsafe_b64decode(ciphertext)

                # Also records the matching old certificate in
                # self.current_cert_path.
                if not self.verify_certificate(kid):
                    logging.error("error in validating the certificate")
                    return False
                key = self.generate_symmetric_key(self.current_cert_path)
                if key is False:
                    logging.error("could not derive the session key for %s", DN)
                    return False
                devices[device] = self.update_header(iv, ciphertext, key)

            # `devices` was updated in place, so the decoded value already
            # holds the re-encrypted secrets.
            return json_decoded_attribute_value
        except Exception as er:
            logging.error("error in processing the encrypted data, skipping the entry %s due to error - %s", DN,type(er).__name__)
            return False

    def update_header(self, init_vector, ciphertext, session_key):
        '''
        Decrypt one device secret with the old session key and re-encrypt it
        for the new certificate.

        input:
        init_vector and ciphertext (decoded parts of the stored secret), and
        session_key (the key derived from the matching old certificate).

        output:
        a str "<kid>.<encrypted secret>". kid is the url-safe base64 of
        self.crypt.hash_sha1(new certificate path). The rest is the decoded
        output of self.crypt.encrypt.

        raises ValueError if the key for the new certificate cannot be derived.
        '''
        decrypted_secret = self.crypt.decrypt(session_key, init_vector, ciphertext).decode()
        new_session_key = self.generate_symmetric_key(self.new_cert_path)
        if new_session_key is False:
            raise ValueError("could not derive the session key for the new certificate")
        encrypted_secret = self.crypt.encrypt(new_session_key, decrypted_secret, self.new_cert_path)

        kid = urlsafe_b64encode(self.crypt.hash_sha1(self.new_cert_path)).decode()
        return kid + '.' + encrypted_secret.decode()

    def verify_certificate(self, kid):
        '''
        Check whether `kid` identifies one of the old certificates.

        kid is compared with self.crypt.hash_sha1(path) for each path in
        self.old_cert_path. On the first match, that path is stored in
        self.current_cert_path and True is returned. Otherwise the method
        returns False.
        '''
        for certificate in self.old_cert_path:
            if kid == self.crypt.hash_sha1(certificate):
                self.current_cert_path = certificate
                return True
        return False

    def convert_to_json_and_store(
        self, values_dict, DN
        ):
        '''
        Serialise values_dict to JSON and write it to the entry's OTP attribute.

        If the dict has more than one top-level key, the JSON text is stored
        url-safe base64 encoded. Otherwise it is stored as plain JSON.
        NOTE(review): this rule does not necessarily reproduce the encoding
        the value was read with. Confirm it is what the attribute's consumer
        expects.

        Returns True if the LDAP modify succeeded, otherwise False (logged).
        '''
        json_to_store = json.dumps(values_dict)
        if len(values_dict) > 1:
            json_to_store = urlsafe_b64encode(json_to_store.encode()).decode()

        # Kept for compatibility; not read anywhere in this class.
        self.target_value = ''
        if not self.update(DN, self.otp_attribute, json_to_store):
            logging.error('could not update the user %s due to an error in writing to the AD\n',DN)
            return False
        logging.info("Update successful for %s ",DN)
        return True

    def update(
        self, DN,
        attribute, new_value,
        ):
        '''
        Replace all values of `attribute` on entry DN with the single value
        new_value.

        Returns True if the server reported result code 0. Otherwise it logs
        the server's description and returns False.
        '''
        self.connection.modify(DN, {attribute: [(MODIFY_REPLACE, [new_value])]})
        if self.connection.result['result'] != 0:
            logging.error("error in modification %s",self.connection.result['description'])
            return False
        return True

    def generate_symmetric_key(self, cert_path):
        '''
        Derive a symmetric key from a PEM file using self.crypt.hkdf.

        The file at cert_path must contain a CERTIFICATE block and a private
        key. The inputs to self.crypt.hkdf are:
        - input key material: the certificate's base64 body with line breaks
          removed (an empty string if there is no CERTIFICATE block);
        - salt: base64 of the DER-encoded private key;
        - length: 32 (presumably bytes - confirm in the crypto module).
        Nothing random is added here. The result depends only on the file
        contents, assuming self.crypt.hkdf is deterministic as HKDF is.

        Returns the derived key, or False (logged) if the file cannot be read
        or the private key cannot be loaded.
        '''
        try:
            # `with` closes the handle; the original left it open.
            with open(cert_path, "rt") as cert_file:
                cert_file_string = cert_file.read()
            match = re.search(r"(?<=-----BEGIN CERTIFICATE-----).*?(?=-----END CERTIFICATE-----)", cert_file_string, flags=re.DOTALL)
            ikm = match.group() if match else ""
            ikm = ikm.replace('\n', '').replace('\r', '').encode('utf-8')

            private_key = load_privatekey(FILETYPE_PEM, cert_file_string)
            salt = b64encode(dump_privatekey(FILETYPE_ASN1, private_key))
        except Exception as er:
            logging.error("Error in processing the certificate, due to %s",type(er).__name__)
            return False
        return self.crypt.hkdf(32, salt, ikm)
