"""
Copyright (c) 2023-2024 Cloud Software Group, Inc. All Rights Reserved. Confidential & Proprietary.
"""

"""
Utility functions for Azure
"""
import logging
import os
import time
import concurrent.futures
import requests

from src.ns_ha import HaStatus
from src.ns_cloud_ha_daemon import ENABLE_DEBUG_FILE
from lib_ns_cloud.azure import imds
from lib_ns_cloud.azure import api

#########################################################################
### XXX FreeBSD 11.x: requests prefers IPv6 by default on 11.x.
###                   Force requests to use IPv4 instead.
### NOTE: this is a process-wide side effect performed at import time —
###       HAS_IPV6 is a module attribute of urllib3 (reached through the
###       requests.packages alias), so every requests-based caller in this
###       process is affected, not only this module.
requests.packages.urllib3.util.connection.HAS_IPV6 = False
##########################################################################

log = logging.getLogger(__name__)

# Evaluated once at import time: unless the debug marker file exists, raise
# the "azure" logger threshold to ERROR to suppress the SDK's verbose output.
if not os.path.exists(ENABLE_DEBUG_FILE):
    # Suppress verbose logging for Azure SDK
    logging.getLogger("azure").setLevel(logging.ERROR)

# Flag file written by AzureInstanceInfo.do_periodic() when the Azure API
# probe fails, and removed when the probe succeeds or probing is skipped.
API_NOT_WORKING_FILE = '/flash/nsconfig/.AZURE/api_not_working'

class AzureInstanceInfo:
    """Contains Cloud-specific API. Must implement
    update_ha_info, handle_failover, and __init__

    Azure implementation of the cloud HA hooks. When this node is PRIMARY,
    handle_failover() moves the secondary IP configurations of the peer VM's
    NICs onto this VM's NICs that sit in the same subnets.
    """

    def __init__(self, ns_obj):
        """Capture HA state from *ns_obj* and this VM's identity from IMDS.

        :param ns_obj: HA state holder exposing ``peer_ip``, ``inc_mode`` and
            ``state`` attributes.
        """
        self.peer_ip = ns_obj.peer_ip
        self.inc_mode = ns_obj.inc_mode
        self.state = ns_obj.state
        self.subscription_id = imds.get_subscription_id()
        self.rg_name = imds.get_resource_group_name()
        self.my_vm_name = imds.get_vm_name()
        # Updated by do_periodic() to reflect whether the last API probe worked.
        self.authenticated = False

    @staticmethod
    def _resource_name(resource_id):
        """Return the last path segment of an Azure resource ID (its name)."""
        return resource_id.split('/')[-1]

    def _get_peer_vm_name(self):
        """Return the name of the VM in this resource group that owns peer_ip.

        Every NIC of every VM in the resource group is fetched, and its IP
        configurations are compared against ``self.peer_ip``.

        :returns: The peer VM name, or ``None`` if the VM list could not be
            fetched or no VM has a NIC with that private IP.
        """
        vms = api.get_vm_list(self.rg_name)
        if vms is None:
            # do_periodic() treats a None VM list as an API failure; iterating
            # it here would raise TypeError.
            log.error("Could not list VMs in resource group %s", self.rg_name)
            return None

        for vm in vms:
            for nic_reference in vm.network_profile.network_interfaces:
                nic_name = self._resource_name(nic_reference.id)
                nic = api.get_network_interface(self.rg_name, nic_name)
                if any(ip_config.private_ip_address == self.peer_ip
                       for ip_config in nic.ip_configurations):
                    return vm.name

        log.warning("No VM found with private IP %s", self.peer_ip)
        return None

    def update_ha_info(self, ns_obj):
        """Refresh the cached HA state (state, peer_ip, inc_mode) from *ns_obj*."""
        self.state = ns_obj.state
        self.peer_ip = ns_obj.peer_ip
        self.inc_mode = ns_obj.inc_mode

    def _release_peer_nic(self, peer_nic_info):
        """Strip a peer NIC down to its primary IP configuration.

        :param peer_nic_info: ``{'nic_name': str, 'ip_configs': [...]}`` entry
            from get_subnet_ip_config_map().
        :returns: The secondary IP configurations that were released, ready to
            be attached to one of our NICs.
        """
        peer_nic_name = peer_nic_info['nic_name']
        peer_nic = api.get_network_interface(self.rg_name, peer_nic_name)
        peer_nic.ip_configurations = [peer_nic.ip_configurations[0]]
        api.create_or_update_network_interface(self.rg_name, peer_nic_name, peer_nic)
        return peer_nic_info['ip_configs']

    def process_nic(self, my_nic_name, subnet_ip_config_map):
        """Move the peer's secondary IP configurations onto one of our NICs.

        Every peer NIC listed under this NIC's subnet is released in parallel.
        The configurations of the peer NICs released successfully are then
        appended, in map order, to *my_nic_name*, and that NIC is updated once.
        If nothing was released for this subnet, no update call is made.

        :param my_nic_name: Name of a NIC attached to this VM.
        :param subnet_ip_config_map: Mapping built by get_subnet_ip_config_map().
        :returns: True if every peer NIC for this subnet was released,
            False if at least one failed (failures are logged).
        :raises: Whatever the api layer raises while fetching or updating
            *my_nic_name* (propagated to the caller).
        """
        my_nic = api.get_network_interface(self.rg_name, my_nic_name)
        if not my_nic.ip_configurations or not my_nic.ip_configurations[0].subnet:
            log.debug("NIC %s has no primary subnet, skipping", my_nic_name)
            return True

        my_subnet_id = my_nic.ip_configurations[0].subnet.id
        peer_nic_info_list = subnet_ip_config_map.get(my_subnet_id)
        if not peer_nic_info_list:
            # Nothing to move onto this NIC, so skip a pointless update call.
            return True

        all_released = True
        moved_any = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                (info['nic_name'], executor.submit(self._release_peer_nic, info))
                for info in peer_nic_info_list
            ]
            # Results are collected in this thread, in submission order, so the
            # shared ip_configurations list is never mutated concurrently.
            for peer_nic_name, future in futures:
                try:
                    released_configs = future.result()
                except Exception:  # broad on purpose: Azure SDK errors vary
                    log.exception("Failed to release secondary IPs from peer NIC %s",
                                  peer_nic_name)
                    all_released = False
                    continue
                my_nic.ip_configurations.extend(released_configs)
                moved_any = True

        # Still attach whatever was released successfully, so those IPs are
        # not left detached from both VMs.
        if moved_any:
            api.create_or_update_network_interface(self.rg_name, my_nic_name, my_nic)
        return all_released

    def migrate_pip_from_peer(self, subnet_ip_config_map, timeout=120):
        """Move the peer's secondary IPs onto this VM's NICs, one task per NIC.

        NOTE(review): if two of our NICs share a subnet, both tasks will try to
        claim the same peer NICs; confirm that topology is not used.

        :param subnet_ip_config_map: Mapping built by get_subnet_ip_config_map().
        :param timeout: Seconds after which a warning is logged if NIC updates
            are still running. The call still waits for them to finish.
        :returns: True if every NIC was processed without error, else False.
        """
        my_vm = api.get_vm(self.rg_name, self.my_vm_name)
        if my_vm is None:
            # Defensive: treat a missing VM object as a failed lookup.
            log.error("Could not fetch own VM %s", self.my_vm_name)
            return False

        succeeded = True
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {}
            for nic_reference in my_vm.network_profile.network_interfaces:
                my_nic_name = self._resource_name(nic_reference.id)
                future = executor.submit(self.process_nic, my_nic_name,
                                         subnet_ip_config_map)
                futures[future] = my_nic_name

            # wait() never raises on timeout; it reports the unfinished futures.
            _, not_done = concurrent.futures.wait(futures, timeout=timeout)
            if not_done:
                log.warning("Failover didn't complete within %s seconds", timeout)

            # result() blocks until each task finishes and re-raises its error.
            for future, my_nic_name in futures.items():
                try:
                    if not future.result():
                        succeeded = False
                except Exception:  # broad on purpose: Azure SDK errors vary
                    log.exception("Failed to move secondary IPs onto NIC %s", my_nic_name)
                    succeeded = False
        return succeeded

    def get_subnet_ip_config_map(self, vm_name):
        """Group the secondary IP configurations of *vm_name*'s NICs by subnet.

        A NIC's subnet is taken from its first (primary) IP configuration, and
        only the configurations after it are collected. NICs without a primary
        subnet are ignored.

        :param vm_name: Name of the VM whose NICs are inspected (the peer).
        :returns: ``{subnet_id: [{'nic_name': str, 'ip_configs': [...]}, ...]}``,
            or ``None`` when no NIC of the VM has secondary configurations.
        """
        vm = api.get_vm(self.rg_name, vm_name)
        subnet_ip_config_map = {}

        for nic_reference in vm.network_profile.network_interfaces:
            nic_name = self._resource_name(nic_reference.id)
            nic = api.get_network_interface(self.rg_name, nic_name)
            if not nic.ip_configurations or not nic.ip_configurations[0].subnet:
                continue

            subnet_id = nic.ip_configurations[0].subnet.id
            nic_entries = subnet_ip_config_map.setdefault(subnet_id, [])
            secondary_configs = nic.ip_configurations[1:]  # exclude the primary
            if secondary_configs:
                nic_entries.append({'nic_name': nic_name, 'ip_configs': secondary_configs})

        if not any(subnet_ip_config_map.values()):
            return None

        return subnet_ip_config_map

    def handle_failover(self):
        """Take over the peer's secondary IPs if this node is PRIMARY.

        Skipped when the node is not PRIMARY, or when inc_mode is anything
        other than ``"DISABLED"``.
        """
        if self.state != HaStatus.PRIMARY:
            log.debug("Not in PRIMARY state, skipping failover.")
            return

        if self.inc_mode != "DISABLED":
            log.info("INC mode is ENABLED, skipping failover.")
            return

        peer_vm_name = self._get_peer_vm_name()
        if peer_vm_name is None:
            # Without the peer VM there is nothing to migrate from; passing
            # None on to get_vm() is not meaningful.
            log.error("Cannot locate peer VM for IP %s, skipping failover", self.peer_ip)
            return

        subnet_ip_config_map = self.get_subnet_ip_config_map(peer_vm_name)
        if not subnet_ip_config_map:
            log.info("No Secondary IPs to be moved from peer %s", self.peer_ip)
            return

        log.info("Handling HA failover here, peer IP is %s", self.peer_ip)
        if self.migrate_pip_from_peer(subnet_ip_config_map):
            log.info("Failover completed")
        else:
            log.error("Failover completed with errors; some secondary IPs "
                      "from peer %s may not have moved", self.peer_ip)

    def do_sleep(self, seconds=600):
        """Sleep for *seconds* (600 by default)."""
        time.sleep(seconds)

    @staticmethod
    def _clear_api_not_working_flag():
        """Remove API_NOT_WORKING_FILE; a missing file is not an error."""
        try:
            os.remove(API_NOT_WORKING_FILE)
        except FileNotFoundError:
            pass

    def do_periodic(self):
        """Probe the Azure API and record the outcome.

        On success, ``self.authenticated`` is set and API_NOT_WORKING_FILE is
        removed. On failure, ``self.authenticated`` is cleared and the file is
        (re)written. Standalone nodes and nodes with INC enabled skip the probe
        and only clear the flag file.
        """
        # NOTE(review): this tests inc_mode == "ENABLED" while handle_failover
        # tests != "DISABLED"; confirm no other inc_mode values exist.
        if self.state == HaStatus.STANDALONE or self.inc_mode == "ENABLED":
            log.debug("Standalone System or INC Enabled, skipping Periodic Tasks")
            self._clear_api_not_working_flag()
            return

        if api.get_vm_list(self.rg_name) is None:
            self.authenticated = False
            log.warning("Periodic API Check Failed")
            with open(API_NOT_WORKING_FILE, 'w', encoding='utf-8') as api_file:
                api_file.write("API Call Failed")
        else:
            self.authenticated = True
            log.debug("Periodic API Check Passed")
            self._clear_api_not_working_flag()

