"""
Copyright 2000-2025 Citrix Systems, Inc. All rights reserved.
This software and documentation contain valuable trade secrets
and proprietary property belonging to Citrix Systems, Inc.
None of this software and documentation may be copied,
duplicated or disclosed without the express written permission
of Citrix Systems, Inc.
"""

import socket
import json
import time
import sys
import os
## Added for the newer cryptography module (3.2.1), which does not allow
## OpenSSL 1.0.2 unless this variable is set explicitly.
os.environ["CRYPTOGRAPHY_OPENSSL_NO_LEGACY"] = "1"
from random import SystemRandom
from urllib.parse import urlparse
import queue

from google_compute_engine import metadata_watcher
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials
from google.cloud import pubsub_v1
import google.cloud.logging as gcp_logging

from rainman_core.common.logger import RainLogger
from rainman_core.common import rain
from rainman_core.common.base import base_cloud_driver
from rainman_core.common.rain import rainman_config, server, group, group_info, event
from rainman_core.common.exception import *
from rainman_core.common.stats import stats_config


# Rainman configuration file; gcp_config.get_event_queue_details() reads the
# "name" of every entry in its "groups" list.
RAINMAN_CONF_FILE = '/flash/nsconfig/rainman.conf'
# Permissions checked through the Cloud Resource Manager testIamPermissions
# API (gcp_config.check_iam); any that are missing are logged and recorded in
# RAINMAN_IAM_NOT_OK_FILE.
REQUIRED_IAM_PERMS = [
    "compute.instances.get",
    "compute.instanceGroupManagers.get",
    "compute.instanceGroupManagers.list",
    "compute.zones.list",
    "logging.sinks.create",
    "logging.sinks.delete",
    "logging.sinks.get",
    "logging.sinks.list",
    "logging.sinks.update",
    "pubsub.subscriptions.consume",
    "pubsub.subscriptions.create",
    "pubsub.subscriptions.delete",
    "pubsub.subscriptions.get",
    "pubsub.topics.attachSubscription",
    "pubsub.topics.create",
    "pubsub.topics.delete",
    "pubsub.topics.get",
    "pubsub.topics.getIamPolicy",
    "pubsub.topics.setIamPolicy",
]
# Marker file holding the missing permissions as JSON. While it exists,
# gcp_config.configure_events_for_groups() does not create Pub/Sub resources.
RAINMAN_IAM_NOT_OK_FILE = '/flash/nsconfig/.GCP/rainman_iam_not_ok'

# Module-level singletons, created at import time and shared by the driver:
# the local config service is used for the HA node-state check, and `log` is
# the common rainman logger.
config = rain.rainman_config()
local = config.get_local_config_service()
log = RainLogger.getLogger()

# Rainman Queue
# Unbounded (maxsize=0). The Pub/Sub subscriber callback puts scale
# notifications here; gcp_config.get_events_from_queue() drains it.
rain_event_queue = queue.Queue(maxsize=0)

#########################################################################
### XXX FreeBSD 11.x: socket.getaddrinfo() prefers IPV6 by default on
###                   11.x. Wrap socket.getaddrinfo() to use IPV4 instead
old_getaddrinfo = socket.getaddrinfo  # the real resolver, captured before patching


def new_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
    """Delegate to the original getaddrinfo(), forcing IPv4 for AF_UNSPEC.

    A caller that asks for a specific family (e.g. AF_INET6) gets it unchanged;
    only the "any family" request is narrowed to AF_INET.
    """
    chosen_family = socket.AF_INET if family == socket.AF_UNSPEC else family
    return old_getaddrinfo(host, port, chosen_family, type, proto, flags)


# Install the wrapper for the whole process.
socket.getaddrinfo = new_getaddrinfo
##########################################################################

def rand_hex_8():
    """Return 8 lowercase hex digits built from 32 bits of SystemRandom output."""
    bits = SystemRandom().getrandbits(32)
    return format(bits, "08x")

def rand_hex_32():
    """Return 32 lowercase hex digits: four independent rand_hex_8() chunks."""
    return "".join(rand_hex_8() for _ in range(4))

def json_loads_to_ascii(json_text):
    """Parse ``json_text`` and pass every string in the result through _to_ascii().

    NOTE: on Python 3, _to_ascii() UTF-8-encodes ``str`` to ``bytes``, so dict
    keys and string values in the returned structure are ``bytes`` objects.
    """
    # object_hook converts each decoded dict as it is built (bottom-up) ...
    parsed = json.loads(json_text, object_hook=_to_ascii)
    # ... so the final pass must skip dicts and only handle a top-level
    # string or list.
    return _to_ascii(parsed, ignore_dicts=True)

def _to_ascii(data, ignore_dicts=False):
    """UTF-8-encode strings found in ``data`` (str -> bytes).

    Lists are converted element-wise. A dict has its own keys and values
    converted unless ``ignore_dicts`` is true; dicts nested inside it are left
    as-is, because json_loads_to_ascii() relies on object_hook to reach them.
    Any other value is returned unchanged.
    """
    if isinstance(data, dict):
        if ignore_dicts:
            return data
        return {
            _to_ascii(k, ignore_dicts=True): _to_ascii(v, ignore_dicts=True)
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [_to_ascii(elem, ignore_dicts=True) for elem in data]
    if isinstance(data, str):
        return data.encode('utf-8')
    return data

def ns_exception_handler(func):
    """Decorator: call ``func`` and log, instead of raise, any ``Exception``.

    On success the wrapped call's return value is passed through. On failure
    the error is logged through the module logger and ``None`` is returned.
    Only ``Exception`` subclasses are caught, so ``SystemExit`` and
    ``KeyboardInterrupt`` still propagate.
    """
    # Local import keeps the (order-sensitive) top-of-file import block intact.
    from functools import wraps

    # wraps() keeps __name__/__doc__/__wrapped__ of the decorated callable,
    # which were previously replaced by those of ``wrapper``.
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            log.error(f"{func.__name__} failed with error ->\n{e}")
            return None
    return wrapper

class gcp_config(base_cloud_driver):
    """GCP implementation of the rainman cloud driver.

    Managed instance groups (MIGs) are queried through the Compute Engine v1
    API. Scale notifications arrive via a Cloud Logging sink (filter
    "AutoscalerSizeChangeExplanation") that exports to a Pub/Sub topic and
    subscription named after this instance. The subscriber callback puts dicts
    on ``rain_event_queue``, and ``get_events_from_queue`` turns them into
    rainman ``event`` objects.
    """

    # Class-level defaults. NOTE(review): cur_configured_groups is a mutable
    # class attribute, shared by every instance unless it is reassigned.
    cur_configured_groups = []
    webhook_key = ""
    previous_webhook_key = ""
    webhook_expire_duration = 5 * 60  # presumably seconds - TODO confirm

    def __init__(self):
        self.log_folder = "/flash/nsconfig/.GCP/"
        # 0 makes the first check_iam_perm() call perform an IAM check.
        self.last_iam_check_time = 0
        # Must run before init_pubsub(): it sets self.project and
        # self.instance_name, from which the Pub/Sub resource names derive.
        self.get_self_info()
        self.ps_initialized = False
        self.ps_retry = 0
        self.init_pubsub()
        # Placeholder values carried over unchanged.
        self.SERVER_STANDBY_STR = "TODOTODO"
        self.SERVER_ACTIVE_STR = "TODOTODO"

    def is_authenticated(self):
        """Always True on GCP; no separate authentication step is done here."""
        return True

    def get_self_info(self):
        """Load project id, NICs and instance name from GCE instance metadata.

        Sets ``self.project``, ``self.nics`` and ``self.instance_name``.
        Abridged shape of ``metadata_watcher.MetadataWatcher().GetMetadata()``::

            {'instance': {'attributes': {}, 'cpuPlatform': ..., 'description': ...,
                          'disks': [{'deviceName': ..., 'index': ..., 'mode': ..., 'type': ...}],
                          'guestAttributes': ..., 'hostname': ..., 'id': ..., 'image': ...,
                          'machineType': ..., 'maintenanceEvent': ..., 'name': ...,
                          'networkInterfaces': [{'accessConfigs': [{'externalIp': ..., 'type': ...}],
                                                 'dnsServers': ['169.254.169.254'],
                                                 'forwardedIps': [], 'gateway': ..., 'ip': ...,
                                                 'ipAliases': [], 'mac': ..., 'mtu': ...,
                                                 'network': ..., 'subnetmask': ...,
                                                 'targetInstanceIps': []}],
                          ...},
             'project': {'projectId': ..., ...}}
        """
        data = metadata_watcher.MetadataWatcher().GetMetadata()
        self.project = data["project"]["projectId"]
        self.nics = data["instance"]["networkInterfaces"]
        self.instance_name = data["instance"]["name"]
        log.debug("Self info gathered")

    def init_pubsub(self):
        """(Re)create the Pub/Sub and Cloud Logging clients and derive resource paths.

        The topic and the subscription share the id
        "<instance_name>_netscaler_pubsub". The logging sink, created later in
        configure_events_for_groups(), exports to that topic. This resets
        ``ps_streaming_pull_future`` and ``ps_sink`` to None.
        """
        self.ps_publisher = pubsub_v1.PublisherClient()
        self.ps_subscriber = pubsub_v1.SubscriberClient()
        self.ps_logging_client = gcp_logging.Client()
        self.ps_streaming_pull_future = None
        self.ps_sink = None
        self.ps_topic = f"{self.instance_name}_netscaler_pubsub"
        self.ps_topic_path = self.ps_publisher.topic_path(self.project, self.ps_topic)
        self.ps_subscription_path = self.ps_subscriber.subscription_path(self.project, self.ps_topic)
        self.ps_sink_destination = f"pubsub.googleapis.com/{self.ps_topic_path}"

    def get_self_ips(self):
        """Return the metadata "ip" value of every NIC (empty list if there are none)."""
        if not self.nics:
            log.error(" No NICs configured" )
            return []
        return [nic["ip"] for nic in self.nics]

    def get_instance_freeips(self, mac=None):
        """Return all of this instance's NIC IPs; ``mac`` is accepted but ignored."""
        return self.get_self_ips()

    def get_instance_name(self, instance_url):
        """Return the last path segment of ``instance_url`` (the instance name).

        Returns None when ``instance_url`` is None or cannot be parsed.
        """
        if instance_url is None:
            return None
        try:
            path = urlparse(instance_url).path
        except Exception:
            # Previously execution fell through here to an unbound
            # ``components`` variable and raised UnboundLocalError.
            log.debug("parsing instance url failed")
            return None
        return path.split("/")[-1]

    def check_iam(self):
        """Return the set of REQUIRED_IAM_PERMS this instance does NOT hold.

        Uses the Cloud Resource Manager v1 ``projects.testIamPermissions`` API.
        If that call itself fails, every required permission is reported as
        missing.
        """
        credentials = GoogleCredentials.get_application_default()
        service = discovery.build('cloudresourcemanager', 'v1', credentials=credentials, cache_discovery=False)

        test_permissions_request = {"permissions": REQUIRED_IAM_PERMS}
        request = service.projects().testIamPermissions(resource=self.project,
                                                        body=test_permissions_request)
        try:
            response = request.execute()
        except Exception as err:
            log.error("No IAM permissions for testIamPermissions API")
            log.error("Exception %s", str(err))
            return set(REQUIRED_IAM_PERMS)
        log.debug(response)
        # The response may omit "permissions" entirely when none of the
        # requested permissions are granted; indexing it raised KeyError.
        granted = response.get("permissions", [])
        return set(REQUIRED_IAM_PERMS) - set(granted)

    def get_servers_in_group(self, asgroup):
        """Return ``server`` objects for the RUNNING instances of a zonal MIG.

        ``asgroup`` must provide ``.name`` and ``.zone`` (e.g. a ``group_info``).
        Each server gets the instance name, its id and the ``networkIP`` of its
        first network interface. API errors are logged, and the servers
        collected so far are returned.
        """
        result = []
        if asgroup is None:
            log.warn("Empty group name provided while looking for ig servers")
            return result

        credentials = GoogleCredentials.get_application_default()
        service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
        try:
            request = service.instanceGroupManagers().listManagedInstances(
                project=self.project, zone=asgroup.zone,
                instanceGroupManager=asgroup.name)
            response = request.execute()
            for instance in response.get("managedInstances") or []:
                if instance.get("instanceStatus", "NOT_FOUND") != "RUNNING":
                    # Skip NON RUNNING instances
                    continue
                serv_name = self.get_instance_name(instance["instance"])
                if not serv_name:
                    # Unparsable URL: skip it instead of failing the whole listing.
                    continue
                request = service.instances().get(project=self.project, zone=asgroup.zone, instance=serv_name)
                instance_details = request.execute()
                ipblock = instance_details["networkInterfaces"]
                this_server = server()
                this_server.ip = ipblock[0]["networkIP"]
                this_server.name = serv_name
                this_server.id = instance_details["id"]
                result.append(this_server)
        except Exception as e:
            log.error(f'Unable to get ig servers, Exception seen : {e!r}')
        return result

    def get_group_info(self, group_name):
        """Return ``group_info`` objects (name and zone) for the matching MIGs.

        Matching follows get_autoscaling_groups(). Errors are logged and re-raised.
        """
        result = []
        try:
            for asgroup in self.get_autoscaling_groups(group_name):
                this_group = group_info()
                # "zone" is presumably a URL ending in ".../zones/<zone>"; keep <zone>.
                this_group.zone = asgroup["zone"].split("zones/")[1]
                this_group.name = asgroup["name"]
                result.append(this_group)
        except Exception as e:
            log.error('Unable to query for autoscale groups :  %s' % (str(e)))
            raise
        return result

    def get_zone_list(self):
        """Return the names of all Compute Engine zones in this project (all pages)."""
        zones = []
        credentials = GoogleCredentials.get_application_default()
        service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)
        request = service.zones().list(project=self.project)
        while request is not None:
            response = request.execute()
            # Tolerate a page without "items" instead of raising KeyError.
            zones.extend(zone["name"] for zone in response.get('items', []))
            request = service.zones().list_next(previous_request=request, previous_response=response)
        return zones

    def get_autoscaling_groups(self, group_name):
        """Return raw instanceGroupManager resources across all zones.

        With a truthy ``group_name``, return a one-element list holding the
        first MIG with that name (or an empty list). Otherwise return every MIG.
        Errors while listing are logged and the groups found so far are returned.
        """
        autoscalegroups = []
        credentials = GoogleCredentials.get_application_default()
        service = discovery.build('compute', 'v1', credentials=credentials, cache_discovery=False)

        zones = self.get_zone_list()
        try:
            for zone in zones:
                # Original assumption: "GCP doesn't allow same named instance
                # group across zones", so the first match is returned.
                request = service.instanceGroupManagers().list(project=self.project, zone=zone)
                while request is not None:
                    response = request.execute()
                    for asgroup in response.get('items', []):
                        if not group_name:
                            autoscalegroups.append(asgroup)
                        elif asgroup["name"] == group_name:
                            log.debug("zone found for %s is %s" % (group_name, zone))
                            autoscalegroups.append(asgroup)
                            return autoscalegroups
                    request = service.instanceGroupManagers().list_next(
                        previous_request=request, previous_response=response)
        except Exception as err:
            log.error('Unable to query for autoscale groups')
            log.error("Exception : %s", str(err))
        return autoscalegroups

    def cleanup_on_group_remove(self, group_name):
        """No-op on GCP."""
        return

    def remove_notification_config_from_group(self, queue_conf, group_name):
        """No-op on GCP: one shared sink/topic serves all groups."""
        return

    def remove_event_queue(self, queue_conf):
        """No-op on GCP; Pub/Sub resources are removed by stop_pub_sub()."""
        return

    def get_cloud_platform(self):
        """Return the platform identifier "GCP"."""
        return "GCP"

    def validate_intf_count(self, intf_count):
        """Return ``intf_count < 1``.

        NOTE(review): the original comment said "Allow always", but this returns
        True only for counts below 1. The meaning of the return value is defined
        by base_cloud_driver; confirm which one is intended.
        """
        return (intf_count < 1)

    def get_ftu_filename(self):
        """Return the path of the GCP "ftumode" marker file."""
        return '/flash/nsconfig/.GCP/ftumode'

    def get_event_queue_details(self):
        """Return the "name" of every group listed in RAINMAN_CONF_FILE.

        Returns an empty list (after logging) when the file cannot be read.
        Malformed JSON or a missing "groups" key still raise.
        """
        try:
            with open(RAINMAN_CONF_FILE) as f:
                conf = json.load(f)
        except OSError:
            # The former ``finally: f.close()`` raised UnboundLocalError on this
            # path, because ``f`` was never bound when open() failed.
            log.error("File not accessible")
            return []
        return [grp.get("name") for grp in conf["groups"]]

    def add_event_queue(self, queue_conf):
        """No-op on GCP; Pub/Sub is configured by configure_events_for_groups()."""
        return

    def check_event_queue(self):
        """
        Returns True if cloud supports configuring event queues (PubSub for GCP).
        """
        return True

    def configure_events_for_group(self, queue_conf, group):
        """Configure Pub/Sub for ``group`` (delegates to configure_events_for_groups)."""
        return self.configure_events_for_groups(queue_conf, [group.name])

    def get_daemon_pid_file(self):
        """Return the path of the rain_scale daemon PID file."""
        return '/flash/nsconfig/.GCP/rain_scale.pid'

    @ns_exception_handler
    def ps_validate(self):
        """Mark Pub/Sub as initialized once the topic, subscription and sink all exist.

        Each call counts as one attempt in ``ps_retry``; success resets the
        counter. Errors from the lookups are logged by the decorator.
        """
        self.ps_retry = self.ps_retry + 1
        topic = self.ps_publisher.get_topic(request={"topic": self.ps_topic_path})
        subscription = self.ps_subscriber.get_subscription(request={"subscription": self.ps_subscription_path})
        if self.ps_sink.exists() and topic is not None and subscription is not None:
            log.info("Pub/Sub initialized")
            self.ps_initialized = True
            self.ps_retry = 0

    @ns_exception_handler
    def configure_events_for_groups(self, queue_conf=None, group_names=None):
        """Ensure the Pub/Sub topic, subscription and logging sink exist.

        Always runs check_iam_perm() first, which exits the process on a
        non-primary node. Resources are (re)created only while Pub/Sub is not
        initialized, fewer than 3 attempts have been made, and the IAM-not-OK
        marker file is absent. ``queue_conf`` and ``group_names`` are unused,
        because one sink/topic serves every group.
        """
        self.check_iam_perm()
        if (not self.ps_initialized and self.ps_retry < 3
                and not os.path.exists(RAINMAN_IAM_NOT_OK_FILE)):
            self.init_pubsub()
            self.ps_sink = self.ps_logging_client.sink(
                self.ps_topic, filter_="AutoscalerSizeChangeExplanation",
                destination=self.ps_sink_destination)
            self.ps_create_topic(self.ps_topic_path)
            self.ps_create_subscription(self.ps_topic_path, self.ps_subscription_path)
            self.ps_create_sink(self.ps_sink, self.ps_topic_path)
            self.ps_validate()

    def update_rainman_iam_file(self):
        """Run check_iam() and keep RAINMAN_IAM_NOT_OK_FILE in sync with the result.

        If permissions are missing, write them as {"instance_iam": [...]} to the
        marker file. If none are missing, delete a stale marker file and reset
        ``ps_retry`` so Pub/Sub initialization can be retried. Returns the set
        of missing permissions.
        """
        missing_perm = self.check_iam()
        if missing_perm:
            log.error("IAM check failed. Please provide these permissions and check Cloud Resource Manager API is enabled for your project: %s", missing_perm)
            try:
                with open(RAINMAN_IAM_NOT_OK_FILE, 'w') as iam_file:
                    iam_file.write(json.dumps({"instance_iam": list(missing_perm)}))
            except Exception as e:
                log.error("Error in creating file: %s" % str(e))
        elif os.path.exists(RAINMAN_IAM_NOT_OK_FILE):
            # IAM permissions are proper, so we can retry Pub/Sub Initialization
            self.ps_retry = 0
            try:
                os.remove(RAINMAN_IAM_NOT_OK_FILE)
            except Exception as e:
                log.error("Error in removing file: %s" % str(e))
        return missing_perm

    def get_min_servers_in_group(self, group_name):
        """Always 0 on GCP."""
        return 0

    def remove_server_from_group(self, server, group):
        """No-op on GCP."""
        return

    def message_to_event(self, message):
        """Convert a queued scale notification into a rainman ``event``.

        ``message`` is the dict built by the Pub/Sub callback, with keys
        "mig_group", "zone", "new_size" and "old_size". Before deciding, this
        polls the group (up to MAX_POLLS sleeps of POLL_INTERVAL seconds) until
        its RUNNING-instance count equals "new_size". It proceeds even if the
        count never matches.

        Returns event('ALARM', group, None, 'drain') when the group shrank,
        event('LAUNCH', group, None, 'sync') when it grew, and None otherwise.
        """
        MAX_POLLS = 15
        POLL_INTERVAL = 10  # secs

        _group_info = group_info()
        _group_info.zone = message["zone"]
        _group_info.name = message["mig_group"]

        cloud_servers = self.get_servers_in_group(_group_info)
        for _ in range(MAX_POLLS):
            # Wait for the group to reach its new size (result only gates timing).
            if message["new_size"] == len(cloud_servers):
                break
            time.sleep(POLL_INTERVAL)
            cloud_servers = self.get_servers_in_group(_group_info)
            log.debug(f"cloud_servers {' '.join([x.name for x in cloud_servers])}")

        if message["old_size"] > message["new_size"]:
            # Scale down
            return event('ALARM', message["mig_group"], None, 'drain')
        if message["old_size"] < message["new_size"]:
            # Scale up
            return event('LAUNCH', message["mig_group"], None, 'sync')
        return None

    def get_events_from_queue(self, queue_conf):
        """Drain rain_event_queue and convert each message into an event.

        Messages whose conversion fails, or yields a falsy result, are dropped;
        failures are logged at debug level. ``queue_conf`` is unused.
        """
        messages = []
        while True:
            try:
                message = rain_event_queue.get(False)
            except queue.Empty:
                log.debug("Event queue is empty")
                break
            log.debug(f"Found msg {message}")
            messages.append(message)

        events = []
        for message in messages:
            log.debug(f"Creating event for message: {message}")
            try:
                new_event = self.message_to_event(message)
            except Exception as e:
                # Was a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit and discarded the cause.
                log.debug(f"Not able to get event: {e!r}")
                continue
            if new_event:
                events.append(new_event)
        return events

    def start_pub_sub(self, queue_conf):
        """Ensure the Pub/Sub resources exist, then start the streaming pull."""
        log.info("Start Streaming...")
        self.configure_events_for_groups()
        # Fixed pause before subscribing; presumably lets newly created
        # resources become usable - TODO confirm.
        time.sleep(3)
        self.ps_start_streaming_pull_future(queue_conf)

    def stop_pub_sub(self, queue_conf):
        """Stop the streaming pull, delete the Pub/Sub resources and close the subscriber.

        No-op when no streaming pull was started. ``queue_conf`` is unused.
        """
        if not self.ps_streaming_pull_future:
            return
        self.ps_streaming_pull_future.cancel()
        # Block until the streaming pull has actually shut down.
        self.ps_streaming_pull_future.result()
        self.ps_cleanup()
        self.ps_subscriber.close()

    def check_iam_perm(self):
        """Exit on non-primary nodes, then re-check IAM at most every 10 minutes.

        Returns the set of missing permissions when a check ran, otherwise None.
        Other exceptions are logged and swallowed. The SystemExit raised by
        handle_node_state() is not an ``Exception`` and so still propagates.
        """
        try:
            self.handle_node_state()
            if self.last_iam_check_time < time.time() - 10 * 60:
                self.last_iam_check_time = time.time()
                missing_perm = self.update_rainman_iam_file()
                log.info("IAM check performed")
                return missing_perm
        except Exception as e:
            log.error(f'IAM check failed with exception {e!r}')
        return None

    def handle_node_state(self):
        """Exit the process with status 9 unless the node is Primary, CCO or StandAlone."""
        nodestate = local.get_node_config()
        log.debug("hanode: %s" % nodestate)
        if nodestate not in ('Primary', 'CCO', 'StandAlone'):
            log.info("not primary, Exiting the process...")
            # sys.exit() instead of the site-provided exit() builtin, which is
            # not guaranteed to exist (e.g. under ``python -S``).
            sys.exit(9)

    @ns_exception_handler
    def ps_create_topic(self, ps_topic_path):
        """Create the Pub/Sub topic; any failure (e.g. it already exists) is logged at debug."""
        log.debug(f"Creating topic: {ps_topic_path}")
        try:
            self.ps_publisher.create_topic(request={"name": ps_topic_path})
            log.info(f"Topic created: {ps_topic_path}")
        except Exception as e:
            log.debug(f"Error when creating PubSub topic: {e}")

    @ns_exception_handler
    def ps_create_subscription(self, ps_topic_path, ps_subscription_path):
        """Create the subscription on the topic; failures are logged by the decorator."""
        log.debug(f"Creating subscription: {ps_subscription_path}")
        self.ps_subscriber.create_subscription(
            request={"name": ps_subscription_path, "topic": ps_topic_path}
        )
        log.info(f"Subscription created: {ps_subscription_path}")

    @ns_exception_handler
    def ps_create_sink(self, ps_sink, ps_topic_path):
        """Create ``ps_sink`` if it does not exist, and let it publish to the topic.

        After creation, the sink's writer identity is granted
        roles/pubsub.publisher on ``ps_topic_path``.
        """
        if ps_sink.exists():
            log.debug(f"Sink {ps_sink.name} already exists.")
            return

        log.info(f"Creating sink: {ps_sink.name}")
        ps_sink.create()
        self.ps_set_topic_policy(ps_topic_path, "roles/pubsub.publisher", [ps_sink.writer_identity])
        ps_sink.reload()
        # Use the parameter, not self.ps_sink (which may differ or be None).
        log.info(f"Sink created: {ps_sink.name}")

    @ns_exception_handler
    def ps_set_topic_policy(self, ps_topic_path, role, members):
        """Add a (role, members) binding to the topic's IAM policy and save it."""
        policy = self.ps_publisher.get_iam_policy(request={"resource": ps_topic_path})
        policy.bindings.add(role=role, members=members)
        policy = self.ps_publisher.set_iam_policy(request={"resource": ps_topic_path, "policy": policy})
        log.debug(f"IAM policy for topic {ps_topic_path} set: {policy}")

    @ns_exception_handler
    def ps_start_streaming_pull_future(self, queue_conf):
        """Start an asynchronous streaming pull on this instance's subscription.

        When a received message names a MIG listed in RAINMAN_CONF_FILE, the
        callback puts {"mig_group", "zone", "new_size", "old_size"} on
        rain_event_queue. Every message is acked, including ones that fail to
        parse. ``queue_conf`` is unused.
        """

        @ns_exception_handler
        def callback(message: pubsub_v1.subscriber.message.Message) -> None:
            log.debug(f"Received {message}")
            try:
                # Decode and parse once; presumably an exported Cloud Logging entry.
                log_entry = json.loads(message.data.decode('utf-8'))
                msg_data_proto = log_entry["protoPayload"]
                msg_data_resource = log_entry["resource"]
                mig_group = msg_data_resource["labels"]["instance_group_manager_name"]
                configured_groups = self.get_event_queue_details()
                if mig_group in configured_groups:
                    mig_group_zone = msg_data_resource["labels"]["location"]
                    new_size = int(msg_data_proto["metadata"]["newSize"])
                    old_size = int(msg_data_proto["metadata"]["oldSize"])
                    queued_event = {"mig_group": mig_group, "zone": mig_group_zone,
                                    "new_size": new_size, "old_size": old_size}
                    log.debug(f"Putting event in queue-> {queued_event}")
                    rain_event_queue.put(queued_event)
                else:
                    log.info("Ignoring as MIG is not configured as ASG on ADC")
            except Exception as e:
                log.error(f"Skipping message. {e}")
            message.ack()
            log.info(f"Acknowledged {message.message_id}.")

        self.ps_streaming_pull_future = self.ps_subscriber.subscribe(self.ps_subscription_path, callback=callback)
        log.info(f"Listening for messages on {self.ps_subscription_path}...")

    def ps_cleanup(self):
        """Delete the topic, subscription and sink; each step logs its own failures."""
        log.info("Starting GCP PubSub cleanup.")
        self.ps_initialized = False
        self.ps_delete_topic(self.ps_topic_path)
        self.ps_delete_subscription(self.ps_subscription_path)
        self.ps_delete_sink(self.ps_sink)
        log.info("GCP PubSub cleanup is complete.")

    @ns_exception_handler
    def ps_delete_topic(self, ps_topic_path):
        """Delete the Pub/Sub topic."""
        log.info(f"Cleaning topic {ps_topic_path}")
        self.ps_publisher.delete_topic(request={"topic": ps_topic_path})
        log.info(f"Deleted topic {ps_topic_path}")

    @ns_exception_handler
    def ps_delete_subscription(self, ps_subscription_path):
        """Delete the Pub/Sub subscription."""
        log.info(f"Cleaning subscription {ps_subscription_path}")
        self.ps_subscriber.delete_subscription(request={"subscription": ps_subscription_path})
        log.info(f"Deleted subscription {ps_subscription_path}")

    @ns_exception_handler
    def ps_delete_sink(self, ps_sink):
        """Delete ``ps_sink`` if it exists; a None sink (never created) is ignored."""
        if ps_sink is None:
            return
        if ps_sink.exists():
            log.info(f"Cleaning sink {ps_sink.name}")
            ps_sink.delete()
            log.info(f"Deleted sink {ps_sink.name}")

    def check_privileges(self, feature):
        """No-op on GCP."""
        pass
