from include import *
from Crypto.PublicKey import RSA
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, load_privatekey, dump_publickey,dump_privatekey
from OpenSSL import crypto
import binascii
class ldapaccess:
    
    def __init__(
            self, host,
            user_name, password,
            port,
            source_attribute,
            target_attribute,
            cert_path,
            encryption
            ):
        '''
        remember the directory server, the bind credentials and the
        tool settings for later calls

        input:
        host, port - address of the directory server
        user_name, password - credentials kept on the instance for connect()
        source_attribute, target_attribute - attribute names used by the
            search and store methods
        cert_path, encryption - stored unchanged on the instance

        other methods also read self.target_value, self.current_cert_path
        and self.crypt; those are not set here and must be assigned by the
        caller before the methods that use them are called
        '''
        # tool specific settings
        self.encryption = encryption
        self.cert_path = cert_path
        self.target_attribute = target_attribute
        self.source_attribute = source_attribute

        # bind credentials, consumed later by connect()
        self.password = password
        self.user_name = user_name

        # server handle; the third positional argument is passed as True
        # (presumably the use_ssl flag of ldap3's Server - TODO confirm)
        self.host = host
        self.server = Server(host, port, True)

    def connect(self):
        '''
        bind to the server with the stored credentials

        input - none, uses self.server, self.user_name and self.password
        output - the connection handle, also kept in self.connection

        on any failure the error is logged and the process exits with
        status 1 (SystemExit is raised), so this method never returns
        without a connection
        '''
        try:
            self.connection = Connection(
                self.server,
                self.user_name, self.password,
                auto_bind=True,
            )
        except exceptions.LDAPSocketOpenError:
            logging.error('Unable to reach the server, please check the server credentials and port number ')
            sys.exit(1)
        except exceptions.LDAPExceptionError as er:
            logging.error("Could not connect to the server due to the following error %s",er)
            sys.exit(1)
        except Exception:
            # Exception rather than a bare except so that Ctrl-C and
            # SystemExit are not reported as a connection error;
            # logging.exception keeps the traceback of the unexpected error
            logging.exception("Could not connect to the server due to an error\n")
            sys.exit(1)
        logging.info('Connection Successful..')
        return self.connection


    def retrieve_entries(
            self,search_base,
            search_filter = '(userParameters=*)',
            search_scope = SUBTREE,
            attributes = ALL_ATTRIBUTES,
            time_limit = 300, 
            ):
        '''
        start a paged search (page size 100) below search_base and hand back
        the result of paged_search called with generator=True

        input:
        search_base - base of the search
        search_filter - filter for the entries to return
        search_scope - scope forwarded to the search
        attributes, time_limit - accepted for backward compatibility but
            NOT forwarded: the search always requests only
            self.source_attribute and self.target_attribute and sets no
            time limit. Forwarding them would change what default callers
            get, because their defaults differ from what is sent today.

        output:
        the object returned by paged_search (with generator=True,
            presumably a lazy generator of entries - confirm with ldap3)
        False - starting the search raised an exception

        NOTE(review): if the search is lazy, server side errors surface
        while the caller iterates, not in this method - confirm.
        '''
        try:
            return self.connection.extend.standard.paged_search(
                search_base = search_base,
                search_filter = search_filter,
                search_scope = search_scope,
                attributes = [self.source_attribute, self.target_attribute],
                paged_size = 100,
                generator = True,
            )
        except Exception as er:
            # keep the "False on failure" contract but leave a trace of why
            logging.error("could not start the search under %s due to %s", search_base, type(er).__name__)
            return False


    def convert_to_json_and_store(
        self, values_dict, DN
        ):
        '''
        wrap values_dict under the "otpdata" key, serialise it to json and
        write it to self.target_attribute of the entry DN

        self.target_value (the current content of the target attribute) is
        read here; it is not set in __init__ and must be assigned by the
        caller first. Three cases:

        1. target is empty ("", "." or []) or target and source attribute
           are the same: the plain json string is stored and
           self.target_value is reset to ''
        2. target holds urlsafe base64 that decodes to json: "otpdata" is
           added to that json and the whole object is stored again as
           urlsafe base64; if "otpdata" is already present nothing is
           written
        3. anything else (including a value that is not valid base64):
           a warning is printed and nothing is written

        input - values_dict: the data to store, DN: entry to update
        output -
        True - the attribute was updated
        False - nothing was written or the write failed
        '''
        target_is_blank = self.target_value in ("", ".", [])
        if target_is_blank or self.target_attribute == self.source_attribute:
            data_to_store = json.dumps({"otpdata": values_dict})
            logging.info("data_to_store %s",data_to_store)
            self.target_value = ''
            updation_result = self.update(
                DN, self.target_attribute, data_to_store
                )
        else:
            # decode once; a value that is not base64 at all is foreign
            # data and belongs to case 3 instead of raising
            try:
                existing_attribute = urlsafe_b64decode(self.target_value)
            except (TypeError, ValueError):
                existing_attribute = None

            if existing_attribute is not None and is_json(existing_attribute):
                json_object = json.loads(existing_attribute)
                if "otpdata" not in json_object:
                    json_object["otpdata"] = values_dict
                    data_to_store = urlsafe_b64encode(json.dumps(json_object).encode('utf-8')).decode()
                    updation_result = self.update(
                        DN, self.target_attribute, data_to_store,
                    )
                else:
                    updation_result = False
                    logging.error("Attribute otp already exist\n")
            else:
                # existing data is never overwritten
                print("WARN, The target attribute already has some data, ignoring the case.")
                updation_result = False

        if updation_result:
            logging.info("Update successful for %s ",DN)
            return True

        logging.error('could not update the user %s due to an error in writing to the AD\n',DN)
        return False


    def update(
        self, DN,
        attribute, new_value,
        ):
        '''
        replace the value of one attribute on the entry DN

        input - DN of the entry, attribute name and the new value
        output -
        True - the result code of the modify operation is 0
        False - any other result code (its description is logged)
        '''
        # a single MODIFY_REPLACE change carrying exactly one value
        changes = {attribute: [(MODIFY_REPLACE, [new_value])]}
        self.connection.modify(DN, changes)

        outcome = self.connection.result
        if outcome['result'] == 0:
            return True

        logging.error("error in modification %s",outcome['description'])
        return False

    def generate_symmetric_key(self,):
        '''
        derive the symmetric key from the PEM file at self.current_cert_path
        (self.current_cert_path and self.crypt must be set by the caller)

        the file is read once and used twice:
        - the base64 text between the BEGIN/END CERTIFICATE markers, with
          line breaks removed, is passed to hkdf as ikm
        - the private key found in the same file, dumped as ASN.1 and
          base64 encoded, is passed to hkdf as salt

        output - the value returned by self.crypt.hkdf(32, salt, ikm)
            (presumably a 32 byte key - confirm against the crypt helper)
        False - the file could not be read or the private key could not be
            loaded; an exception raised by hkdf itself is not caught here
        '''
        try:
            # context manager so the file handle is closed even on error
            with open(self.current_cert_path, "rt") as cert_file:
                cert_file_string = cert_file.read()

            # input key material: body of the certificate block
            # NOTE(review): when the file holds no certificate block the
            # ikm silently becomes empty - confirm this is acceptable
            m = re.search(r"(?<=-----BEGIN CERTIFICATE-----).*?(?=-----END CERTIFICATE-----)", cert_file_string, flags=re.DOTALL)
            ikm = m.group() if m else ""
            ikm = ikm.replace('\n', '').replace('\r', '').encode('utf-8')

            # salt: base64 of the ASN.1 dump of the private key
            key = load_privatekey(FILETYPE_PEM, cert_file_string)
            salt = b64encode(dump_privatekey(crypto.FILETYPE_ASN1, key))

        except Exception as er:
            logging.error("Error in processing the certificate, due to %s",type(er).__name__) 
            return False

        return self.crypt.hkdf(32, salt, ikm)

    def encrypt_and_update(self,
        attribute_list, DN):
        '''
        encrypt the OTP secret of every device of one user and store the
        result through convert_to_json_and_store

        input -
        attribute_list - list of plaintext device strings, each shaped like
            <device name>=<secret>&<key>=<value>&...
            an entry without '=' is tolerated only in the last position
        DN - distinguished name of the user entry to update

        for each device a json object {"secret": <secret>, <key>: <value>...}
        is encrypted with self.crypt.encrypt and stored as
        "<kid>.<encrypted json>" under the device name, where kid is the
        urlsafe base64 of self.crypt.hash_sha1(self.cert_path)
        (self.crypt must be set by the caller)

        output -
        True - all devices were encrypted and the store succeeded
        False - an entry was malformed, the symmetric key could not be
            generated, an exception was raised, or the store failed
        '''
        try:
            device_dict = {}
            last_index = len(attribute_list) - 1

            # the key and the key id depend only on the certificate files,
            # so they are computed once, on the first device that needs them
            symmetric_key = None
            kid = None

            for index, entry in enumerate(attribute_list):
                device_name, separator, payload = entry.partition('=')
                if not separator:
                    if index != last_index:
                        logging.error("error in data: invalid format,\n skipping the entry for DN:%s", DN)
                        return False
                    continue

                secret_data = payload.split('&')
                secret_dict = {"secret": secret_data[0]}
                # secret_data[0] is the secret itself; only the fields after
                # it are key=value pairs. Parsing it as a pair would add a
                # bogus key whenever the secret contains '=' (e.g. padding)
                for info in secret_data[1:]:
                    info_key, has_value, info_item = info.partition('=')
                    if has_value:
                        secret_dict[info_key] = info_item
                secret_json = json.dumps(secret_dict)

                if symmetric_key is None:
                    symmetric_key = self.generate_symmetric_key()
                    if symmetric_key is False:
                        raise ValueError("symmetric key generation failed")
                    kid = urlsafe_b64encode(self.crypt.hash_sha1(self.cert_path)).decode()

                encrypted_secret = self.crypt.encrypt(symmetric_key, secret_json, self.cert_path)
                device_dict[device_name] = kid + '.' + encrypted_secret.decode()

            return self.convert_to_json_and_store({"devices": device_dict}, DN)
        except Exception as er:
            logging.error('could not update the user %s due to an error in updating %s\n',DN,type(er).__name__)
            return False
 