#!/usr/bin/python
########################################################################################
#
#  Copyright (c) 2018-2022 Citrix Systems, Inc.
#  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the Citrix Systems, Inc. nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL CITRIX SYSTEMS, INC. BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
########################################################################################

import json
import os
import re
import signal
import copy
import socket

from google_compute_engine import metadata_watcher
from google_compute_engine import logger
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials
from libnitrocli import nitro_cli, nitro_cli_output_parser


# PID file path of the cloud HA agent daemon.
CLOUDHADAEMON_PID = "/var/run/nscloudhaagent.pid"

# IAM permissions the VIP-scaling code needs. vipScaling._check_iam() asks the
# testIamPermissions API for exactly this list and reports what is missing.
# Grouped by the Compute Engine resource each permission applies to.
REQUIRED_IAM_PERMS = (
    # Reserved addresses (EIPs): listed, looked up by name, bound to rules.
    ["compute.addresses.list", "compute.addresses.get", "compute.addresses.use"]
    # Forwarding rules created/deleted per requested VIP and protocol.
    + ["compute.forwardingRules.create", "compute.forwardingRules.delete",
       "compute.forwardingRules.get", "compute.forwardingRules.list"]
    # Resources referenced when creating the target instance and rules.
    + ["compute.instances.use", "compute.subnetworks.use"]
    # The per-instance target instance every rule points at.
    + ["compute.targetInstances.create", "compute.targetInstances.list",
       "compute.targetInstances.use"]
)

#########################################################################
### XXX FreeBSD 11.x: socket.getaddrinfo() prefers IPV6 by default on
###                   11.x. Wrap socket.getaddrinfo() to use IPV4 instead
old_getaddrinfo = socket.getaddrinfo


def new_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
    """Call the original socket.getaddrinfo(), mapping AF_UNSPEC to AF_INET.

    A family the caller asks for explicitly (e.g. AF_INET6) is passed
    through untouched; only the "any family" default is narrowed to IPv4.
    """
    effective_family = socket.AF_INET if family == socket.AF_UNSPEC else family
    return old_getaddrinfo(host, port, effective_family, type, proto, flags)


socket.getaddrinfo = new_getaddrinfo
##########################################################################


# Lower-case IP protocol names accepted for GCP forwarding rules.
gcp_supported_protocols = "ah esp icmp sctp tcp udp".split()

# This class provides utilities for configuring additional VIPs on an ADC running on GCP.
class vipScaling():
    """Manage additional VIPs (GCP forwarding rules) for a Citrix ADC.

    The desired VIPs are read from the ``vips`` instance-metadata attribute
    (see handle_vip_changes()). For each (EIP, protocol) pair a regional
    forwarding rule named
    ``<proto>-<address with '.' replaced by '-'>-adcinternal`` is created.
    All of these rules target one per-instance target instance named
    ``<instance name>-adcinternal``. Rules that target it but are no longer
    requested are deleted.
    """

    def __init__(self, logger):
        """Read project/zone/instance from the metadata server and load owned rules.

        Args:
            logger: logger used for all diagnostics (debug/info/error are called).
        """
        # Parsed "vips" entries, each a one-item dict {eip_name: protocol},
        # for the last applied request and for the request being processed.
        self.prev_vips_input = []
        self.cur_vips_input = []
        # Entries to remove / add, computed by _update_buckets().
        self.delete_bucket = []
        self.add_bucket = []
        # Rules found in the cloud that target this ADC's target instance.
        self.forwarding_rules_list = []
        # Rules this ADC created or adopted for the current request. Anything
        # in forwarding_rules_list but not in this list is treated as stale.
        self.forwarding_rules_list_local = []
        self.__nitro = nitro_cli(timeout=500)
        self.__parser = nitro_cli_output_parser()
        data = metadata_watcher.MetadataWatcher().GetMetadata()
        # Project ID for this request.
        self.project = data["project"]["projectId"]
        # Zone name, taken from the ".../zones/<zone>" metadata value.
        self.zone = data["instance"]["zone"].split("zones/")[1]
        # Region scoping this request: the zone without its "-<suffix>" part.
        self.region = self.zone.rsplit("-", 1)[0]
        # Name of this instance.
        self.instance = data["instance"]["name"]
        self.target_instance_name = self.instance + "-adcinternal"
        self.logger = logger
        # Update the list of self-created forwarding rules.
        self._populate_adc_owned_fw_rules()

    def handle_vip_changes(self, result):
        """Reconcile this ADC's GCP forwarding rules with the ``vips`` metadata.

        Expected input is in JSON format, with the EIP name as key and a list
        of protocols as value, e.g.
            {
                "eip-2": ["TCP", "UDP"],
                "eip-1": ["ICMP", "TCP", "AH"]
            }
        * EIPs should be external and not in use. Note that the EIP's *name*
          is expected, not the address itself.
        * Protocols must be one of the GCP supported [AH, ESP, ICMP, SCTP, TCP, UDP].
        Error conditions and behaviour:
        * Unknown or internal EIP, or unsupported protocol: the entire entry
          is rejected and the previous request is retained. Correct it and
          re-submit.
        * Correct but duplicate entries: the forwarding rule is created once.

        Only a 'Primary' or 'StandAlone' node acts. Cluster nodes, HA
        secondaries, and nodes whose state cannot be read just reset their
        bookkeeping. Without a ``vips`` attribute, every rule owned by this
        ADC is deleted.

        Args:
            result: metadata dict; ``result["instance"]["attributes"]`` is read.

        Errors are logged; nothing is raised.
        """
        try:
            nodestate = self._get_node_state()
            if nodestate in ('CCO', 'CL Node', 'Secondary', 'None'):
                self.logger.info("Muting operations on Cluster/Secondary.")
                self.forwarding_rules_list = []
                # prev is reset, so the next active run re-adds every entry;
                # the local list must restart from empty as well.
                self.forwarding_rules_list_local = []
                self.prev_vips_input = []
                return

            attributes = result["instance"]["attributes"]
            if "vips" not in attributes:
                # No metadata found: delete every rule owned by this ADC.
                # Refresh first, because a muted period clears
                # forwarding_rules_list.
                self._populate_adc_owned_fw_rules()
                self._delete_all_cloud_forwarding_rules()
                self.forwarding_rules_list_local = []
                self.prev_vips_input = []
                return

            missing_iam_perm = self._check_iam()
            if missing_iam_perm:
                self.logger.error("IAM permissions missing [%s]", missing_iam_perm)
            # Create the target-instance required for forwarding rules.
            self._create_target_instance()
            self._populate_adc_owned_fw_rules()
            self.cur_vips_input = []
            self.logger.debug("Forwarding_rules created by ADC: %s" % self.forwarding_rules_list)
            if not self._validate_input(attributes["vips"]):
                self.logger.info("Please correct the entry and submit again.")
                self.logger.info("Retaining the previous request.")
                self.cur_vips_input = self.prev_vips_input
                self.logger.debug("cur_inp [%s]; prev_inp [%s]" % (self.cur_vips_input, self.prev_vips_input))
                return

            if self.cur_vips_input == self.prev_vips_input:
                # Metadata unchanged: nothing to do.
                return
            self._update_buckets()
            if self.delete_bucket:
                # Requested deletion of rules: proceed to delete.
                self._delete_forwarding_rules()
            all_resolved = self._create_forwarding_rules() if self.add_bucket else True
            if not all_resolved:
                # Without every wanted rule name, the stale check could delete
                # a rule that is still requested. Keep prev_vips_input so these
                # entries are processed again on the next call.
                self.logger.error("Some requested EIPs could not be resolved; retaining the previous request.")
                return
            self._delete_stale_forwarding_rules()
            self.prev_vips_input = self.cur_vips_input
        except Exception as e:
            self.logger.error("Error: %s" % str(e))

    def _check_iam(self):
        """Return the set of REQUIRED_IAM_PERMS this instance's credentials lack.

        Uses the Cloud Resource Manager testIamPermissions API, which requires
        one of the following OAuth scopes:

            https://www.googleapis.com/auth/cloud-platform
            https://www.googleapis.com/auth/cloud-platform.read-only
            https://www.googleapis.com/auth/cloudplatformprojects
            https://www.googleapis.com/auth/cloudplatformprojects.readonly
            https://www.googleapis.com/auth/iam.test

        If the check itself fails, every required permission is reported as
        missing.
        """
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('cloudresourcemanager', 'v1', credentials=credentials, cache_discovery=False)
            request = service.projects().testIamPermissions(
                resource=self.project, body={"permissions": REQUIRED_IAM_PERMS})
            response = request.execute()
        except Exception as err:
            self.logger.error("No IAM permissions for testIamPermissions API")
            self.logger.error("Exception %s" % str(err))
            return set(REQUIRED_IAM_PERMS)
        # The response may omit "permissions" entirely when none are granted.
        return set(REQUIRED_IAM_PERMS) - set(response.get("permissions", []))

    def _update_buckets(self):
        """Split the change from prev_vips_input to cur_vips_input into buckets.

        delete_bucket: entries of prev_vips_input that are no longer requested.
        add_bucket: entries of cur_vips_input that were not requested before.
        """
        self.logger.debug("Updating the buckets.")
        self.delete_bucket = [i for i in self.prev_vips_input if i not in self.cur_vips_input]
        self.add_bucket = [i for i in self.cur_vips_input if i not in self.prev_vips_input]
        self.logger.debug("delete_bucket: %s; add_bucket: %s" % (self.delete_bucket, self.add_bucket))

    def _create_target_instance(self):
        """Create the single target instance all forwarding rules point at, if missing.

        IAM: compute.targetInstances.create, compute.instances.use.
        The insert request is not waited on. Errors are logged, not raised.
        """
        try:
            if self._check_target_instance(self.target_instance_name):
                self.logger.info("target-instance %s already exists" % self.target_instance_name)
                return
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            target_instance_body = {
                "name": self.target_instance_name,
                "description": "target instance created fot Citrix ADC",
                "zone": self.zone,
                "natPolicy": "NO_NAT",
                "instance": "zones/" + self.zone + "/instances/" + self.instance,
            }
            service.targetInstances().insert(
                project=self.project, zone=self.zone, body=target_instance_body).execute()
            self.logger.info("target-instance %s created" % self.target_instance_name)
        except Exception as e:
            self.logger.error("Error: %s" % str(e))

    def _check_target_instance(self, name):
        """Return True if target instance ``name`` pointing at this VM exists in the zone."""
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            request = service.targetInstances().list(project=self.project, zone=self.zone)
            while request is not None:
                response = request.execute()
                # "items" may be absent when the zone has no target instances.
                for target_instance in response.get('items', []):
                    if (target_instance['instance'].split('/')[-1] == self.instance
                            and target_instance['selfLink'].split('/')[-1] == name):
                        self.logger.info("ADC Target Instance exists")
                        return True
                request = service.targetInstances().list_next(previous_request=request, previous_response=response)
        except Exception as e:
            self.logger.error("ADC Target instance info could not be retrieved: %s", e)
        return False

    def _list_reserved_eips(self):
        """Return the names of EXTERNAL addresses in this region.

        On an API error the error is logged and the names collected so far
        are returned.
        """
        eips = []
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            request = service.addresses().list(project=self.project, region=self.region)
            while request is not None:
                response = request.execute()
                # "items" may be absent when the region has no addresses.
                for address in response.get('items', []):
                    if address.get("addressType") == "EXTERNAL":
                        eips.append(address["name"])
                request = service.addresses().list_next(previous_request=request, previous_response=response)
        except Exception as e:
            self.logger.error("Error: %s" % str(e))
        return eips

    def _is_adc_target(self, target):
        """Return True if the target URL ``target`` names this ADC's target instance.

        Targets that are not target instances (or a missing target) yield False.
        """
        return (target or "").split("targetInstances/")[-1] == self.target_instance_name

    def _populate_adc_owned_fw_rules(self):
        """Rebuild forwarding_rules_list from the cloud.

        A rule counts as ours when its name ends with "adcinternal" and it
        targets this ADC's target instance.
        IAM: compute.forwardingRules.list
        """
        self.forwarding_rules_list = []
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            request = service.forwardingRules().list(project=self.project, region=self.region)
            while request is not None:
                response = request.execute()
                # "items" may be absent when the region has no forwarding rules.
                for forwarding_rule in response.get('items', []):
                    # All our rules are named "<proto>-<address>-adcinternal".
                    if (forwarding_rule["name"].endswith("adcinternal")
                            and self._is_adc_target(forwarding_rule.get("target"))):
                        self.forwarding_rules_list.append(forwarding_rule["name"])
                request = service.forwardingRules().list_next(previous_request=request, previous_response=response)
        except Exception as e:
            self.logger.error("Warning(%s): No forwarding rules found." % str(e))

    @staticmethod
    def _append_unique(rules, name):
        """Append ``name`` to the list ``rules`` unless it is already there."""
        if name not in rules:
            rules.append(name)

    def _forget_rule(self, name):
        """Drop every occurrence of ``name`` from both tracking lists.

        The name may legitimately be missing from either list. For example,
        stale rules are never in forwarding_rules_list_local.
        """
        self.forwarding_rules_list = [r for r in self.forwarding_rules_list if r != name]
        self.forwarding_rules_list_local = [r for r in self.forwarding_rules_list_local if r != name]

    def _create_forwarding_rule(self, eip, protocol, name):
        """Create forwarding rule ``name`` for IP address ``eip`` and ``protocol``.

        The rule targets this ADC's target instance. On success, ``name`` is
        recorded in both tracking lists. Errors are logged, not raised.
        IAM: compute.forwardingRules.create, compute.subnetworks.use,
             compute.targetInstances.use
        """
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            target = ("https://www.googleapis.com/compute/v1/projects/" + self.project
                      + "/zones/" + self.zone + "/targetInstances/" + self.target_instance_name)
            forwarding_rule_body = {
                "name": name,
                "region": self.region,
                "IPAddress": eip,
                "IPProtocol": protocol,
                "target": target,
            }
            service.forwardingRules().insert(
                project=self.project, region=self.region, body=forwarding_rule_body).execute()
            self._append_unique(self.forwarding_rules_list, name)
            self._append_unique(self.forwarding_rules_list_local, name)
        except Exception as e:
            self.logger.error("Error: %s" % str(e))

    def _create_forwarding_rules(self):
        """Create, or adopt if already present, a rule for every add_bucket entry.

        Returns:
            True if every entry's EIP could be resolved to a rule name; False
            if at least one entry was skipped because resolution failed.
        """
        all_resolved = True
        for eip_d in self.add_bucket:
            for eip, proto in eip_d.items():
                name, address = self._get_name(eip, proto)
                if not name:
                    self.logger.error("Skipping [%s: %s]: EIP could not be resolved." % (eip, proto))
                    all_resolved = False
                    continue
                if name in self.forwarding_rules_list:
                    # Already exists in the cloud: keep it as a wanted rule.
                    self._append_unique(self.forwarding_rules_list_local, name)
                else:
                    self.logger.info("Creating forwarding rule %s" % name)
                    self._create_forwarding_rule(address, proto, name)
        return all_resolved

    def _delete_forwarding_rule(self, name):
        """Delete forwarding rule ``name`` and stop tracking it.

        Errors are logged, not raised.
        IAM: compute.forwardingRules.delete
        """
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            service.forwardingRules().delete(
                project=self.project, region=self.region, forwardingRule=name).execute()
            self._forget_rule(name)
        except Exception as e:
            self.logger.error("Error: %s" % str(e))

    def _delete_forwarding_rules(self):
        """Delete the rule for every delete_bucket entry that still targets this ADC."""
        for eip_d in self.delete_bucket:
            for eip, proto in eip_d.items():
                name, _address = self._get_name(eip, proto)
                if not name:
                    self.logger.error("Skipping [%s: %s]: EIP could not be resolved." % (eip, proto))
                    continue
                if self._check_forwarding_rule(name):
                    self.logger.info("Deleting forwarding rule %s" % name)
                    self._delete_forwarding_rule(name)

    def _delete_all_cloud_forwarding_rules(self):
        """Delete every rule currently listed in forwarding_rules_list."""
        if self.forwarding_rules_list:
            self.logger.debug("Deleting forwarding rules created by ADC.")
            # Iterate over a snapshot: deletion rebinds/shrinks the list.
            for fw_rule in list(self.forwarding_rules_list):
                self._delete_forwarding_rule(fw_rule)

    def _delete_stale_forwarding_rules(self):
        """Delete owned rules that the current request no longer wants."""
        stale_rules_to_be_deleted = set(self.forwarding_rules_list) - set(self.forwarding_rules_list_local)
        if stale_rules_to_be_deleted:
            self.logger.info("deleting stale forwarding rules: %s" % stale_rules_to_be_deleted)
            for fw_rule in stale_rules_to_be_deleted:
                self._delete_forwarding_rule(fw_rule)

    def _check_forwarding_rule(self, name):
        """Return True if rule ``name`` exists and targets this ADC's target instance."""
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            request = service.forwardingRules().get(project=self.project, region=self.region, forwardingRule=name)
            response = request.execute()
            return self._is_adc_target(response.get("target"))
        except Exception as e:
            self.logger.error("Error: %s" % str(e))
            return False

    def _get_name(self, eip, proto):
        """Return (rule name, address) for the address named ``eip`` and ``proto``.

        The rule name is "<proto>-<address with '.' replaced by '-'>-adcinternal".
        Returns (None, None) if the lookup fails, so callers can always unpack
        two values.
        IAM: compute.addresses.get
        """
        try:
            credentials = GoogleCredentials.get_application_default()
            service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
            request = service.addresses().get(project=self.project, region=self.region, address=eip)
            address = request.execute()["address"]
            return (proto.lower() + "-" + address.replace(".", "-") + "-adcinternal"), address
        except Exception as e:
            self.logger.error("Error: %s" % str(e))
            return None, None

    def _validate_input(self, user_input):
        """Parse and validate the ``vips`` JSON, filling cur_vips_input on success.

        Each (eip, protocol) pair is appended to cur_vips_input as
        {eip_name: lower-cased protocol}.

        Returns:
            True if the JSON is an object whose EIPs and protocols are all
            valid; False otherwise (errors are logged).
        """
        try:
            requested_vips = json.loads(user_input)
            if not isinstance(requested_vips, dict):
                self.logger.error("Expected a JSON object of the form {\"<eip-name>\": [\"<protocol>\", ...]}")
                return False
            reserved_ips = self._list_reserved_eips()
            if not self._validate_eip_proto(requested_vips, reserved_ips):
                return False
            for key, value in requested_vips.items():
                for proto in value:
                    self.cur_vips_input.append({key: proto.lower()})
            return True
        except Exception as e:
            self.logger.error("JSON load error [%s], please check " % str(e))
            return False

    def _validate_eip_proto(self, requested_vips, reserved_ips):
        """Check each EIP is in ``reserved_ips`` and each protocol is supported.

        Every problem found is logged, so the user can fix them all at once.

        Returns:
            True only if no problem was found.
        """
        result = True
        for eip, value in requested_vips.items():
            if eip not in reserved_ips:
                self.logger.error("Cannot find EIP: %s" % eip)
                self.logger.error("Incorrect entry [%s: %s]" % (eip, value))
                result = False
                continue
            bad_protos = [proto for proto in value if proto.lower() not in gcp_supported_protocols]
            for proto in bad_protos:
                self.logger.error("Incorrect protocol: %s for eip [%s]" % (proto, eip))
            if bad_protos:
                result = False
        return result

    def _get_node_state(self):
        """Classify this node from NITRO nsconfig.

        Returns one of 'CCO', 'CL Node', 'Primary', 'Secondary', 'StandAlone',
        or 'None' when nsconfig could not be read.
        """
        CCO_NODE = 0x00000040
        MASTER_NODE = 0x00000004
        out = self.__nitro.get_nsconfig()
        if self.__parser.success(out) and u'nsconfig' in out:
            # NOTE(review): presumably nsconfig is reported as the bare string
            # 'NON-CCO' on non-CCO cluster nodes — confirm against nitro_cli.
            node_cco = str(out[u'nsconfig'])
            if node_cco == 'NON-CCO':
                return 'CL Node'
            ns_systemtype = str(out[u'nsconfig'][u'systemtype'])
            nc_flags = int(out[u'nsconfig'][u'flags'])
            if ns_systemtype == 'Cluster':
                return ('CCO' if (nc_flags & CCO_NODE) == CCO_NODE else 'CL Node')
            elif ns_systemtype == 'HA':
                return ('Primary' if (nc_flags & MASTER_NODE) == MASTER_NODE else 'Secondary')
            return 'StandAlone'
        return 'None'

