"""
Copyright 2000-2023 Citrix Systems, Inc. All rights reserved.
This software and documentation contain valuable trade secrets
and proprietary property belonging to Citrix Systems, Inc.
None of this software and documentation may be copied,
duplicated or disclosed without the express written permission
of Citrix Systems, Inc.
"""

import os
import sys
import time
import logging
from logging.handlers import RotatingFileHandler
import traceback
import threading

from rainman_core.common.constants import DEFAULT_LOG_LEVEL, UNDEF_LOG_FILE_NAME


# While this file exists, RainLogger's polling thread keeps the debug log on.
DEBUG_FLAG_FILE = "/var/tmp/.debug-rainman"
# Appended to the primary log file name (without its extension) to name the debug log.
DEBUG_LOG_FILE_SUFFIX = "-debug.log"
# Seconds the polling thread sleeps between checks of the flag/conf files.
DEBUG_FILE_CHECK_FREQ = 1
# Primary log rotates once it reaches this many bytes (100 KiB).
MAX_BYTES = 100 * 1024
# How many rotated primary log files are kept.
MAX_LOG_BACKUP_COUNT = 25
# Package loggers forced to DEBUG and attached to the primary log file handler.
LOGS_OF_INTEREST = [
    "libnitrocli",
    "rainman_core",
    "ns_pyutils",
]


class RainLogger():
    """Shared logger wrapper with a rotating primary log and an on-demand debug log.

    The first instance whose logger is initialised with a real log file name
    (anything other than ``UNDEF_LOG_FILE_NAME``) becomes the singleton. Later
    instances reuse its logger and debug handler. The singleton also starts a
    daemon thread that checks ``DEBUG_FLAG_FILE`` every
    ``DEBUG_FILE_CHECK_FREQ`` seconds. While the flag file exists, records are
    also written to a separate debug log. Per-logger levels can then be tuned
    through a companion conf file,
    ``<DEBUG_FLAG_FILE>-<log file basename>.conf``, made of ``name:LEVEL``
    lines. Lines starting with ``#`` are ignored.
    """
    __SingeltonInstance = None

    def __init__(
            self,
            logfilename=UNDEF_LOG_FILE_NAME,
            level=DEFAULT_LOG_LEVEL):
        """Create a wrapper. Handlers are attached lazily by ``init_logger``.

        :param logfilename: path of the primary (rotating) log file.
        :param level: level set on the primary file handler.
        """
        self.logfilename = logfilename
        self.level = level

        # Share the singleton's logger/debug handler so every instance writes
        # to the same destinations.
        if self.__SingeltonInstance:
            self._logger = self.__SingeltonInstance._logger
            self._debug_handler = self.__SingeltonInstance._debug_handler
        else:
            self._logger = None
            self._debug_handler = None

        # State used by the debug polling thread.
        self.mtime_debug_conf_file = 0
        self.default_logger_levels = {}   # logger name -> level applied while debugging
        self.custom_logger_levels = {}    # logger name -> level name read from the conf file
        self.force_debug_conf_update = False  # True when the conf file must be rewritten
        self.debug_conf_file = f"{DEBUG_FLAG_FILE}-{self.logfilename.split('/')[-1]}.conf"

    @property
    def logger(self):
        """Return the underlying ``logging.Logger``, initialising it on first use."""
        if not self._logger:
            if self.__SingeltonInstance:
                self._logger = self.__SingeltonInstance.logger
            else:
                self.init_logger()
        return self._logger

    @classmethod
    def getLogger(cls):
        """Return the singleton's logger, or the (possibly unconfigured)
        ``logging`` logger named after ``sys.argv[0]`` if there is no singleton."""
        if cls.__SingeltonInstance:
            return cls.__SingeltonInstance.logger
        return logging.getLogger(sys.argv[0])

    @classmethod
    def getRainLogger(cls):
        """Return the singleton ``RainLogger`` instance, or None if not yet set."""
        return cls.__SingeltonInstance

    def init_logger(self):
        """Configure the primary logger and, for a real log file, become the singleton.

        Does nothing if this instance already has a logger or a singleton
        exists. Otherwise attaches a rotating file handler for
        ``self.logfilename``, unless one is already attached. When the file
        name is not ``UNDEF_LOG_FILE_NAME``, it also registers this instance
        as the singleton and starts the debug polling thread.
        """
        if self._logger or self.__SingeltonInstance:
            return
        self._logger = logging.getLogger(sys.argv[0])
        self._logger.setLevel(logging.DEBUG)

        # logging.getLogger() hands every instance the same Logger object, so
        # adding a handler unconditionally would duplicate each record (and
        # let two RotatingFileHandlers rotate the same file independently).
        if not self._has_primary_handler():
            formatter = logging.Formatter(
                '%(asctime)s %(levelname)s %(message)s',
                '%y-%m-%d %H:%M:%S')
            file_handler = RotatingFileHandler(
                self.logfilename,
                maxBytes=MAX_BYTES,
                backupCount=MAX_LOG_BACKUP_COUNT,
                delay=True)
            file_handler.setLevel(self.level)
            file_handler.setFormatter(formatter)
            self._logger.addHandler(file_handler)
            self._add_logs_of_interest(file_handler)

        if self.logfilename != UNDEF_LOG_FILE_NAME:
            RainLogger.__SingeltonInstance = self
            threading.Thread(
                target=self.looped_debug_operation,
                name="debugThread",
                daemon=True).start()

    def _has_primary_handler(self):
        """Return True if ``self._logger`` already has a rotating handler for ``self.logfilename``."""
        # RotatingFileHandler stores the absolute path in ``baseFilename``.
        target = os.path.abspath(self.logfilename)
        return any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == target
            for handler in self._logger.handlers)

    def _add_logs_of_interest(self, file_handler):
        """Set each logger in LOGS_OF_INTEREST to DEBUG and attach ``file_handler`` to it."""
        for logname in LOGS_OF_INTEREST:
            logger = logging.getLogger(logname)
            logger.setLevel(logging.DEBUG)
            logger.addHandler(file_handler)

    def configure_debug_handler(self):
        """Turn the debug log on or off according to the presence of DEBUG_FLAG_FILE.

        While debugging is on, logger levels are re-applied whenever the conf
        file has been (re)loaded.
        """
        if os.path.isfile(DEBUG_FLAG_FILE):
            self._add_debug_handler()
            if self._configure_debug_conf_file():
                self._add_3rd_party_debugging()
                self._apply_default_debug_conf()
                self._apply_custom_debug_conf()
        else:
            self._remove_debug_handler()

    def looped_debug_operation(self):
        """Body of the debug polling thread: check the debug configuration forever.

        Failures (e.g. the conf file vanishing between checks, or being
        unwritable) are logged instead of ending the thread. Otherwise a
        single error would disable debug toggling for the rest of the
        process. A failure is logged only when it differs from the previous
        one, so a persistent error does not flood the rotating log every
        cycle.
        """
        last_failure = None
        while True:
            try:
                self.configure_debug_handler()
                last_failure = None
            except Exception as err:  # keep the polling thread alive
                failure = repr(err)
                if failure != last_failure:
                    self.error_trace("debug configuration check failed")
                    last_failure = failure
            time.sleep(DEBUG_FILE_CHECK_FREQ)

    def _apply_default_debug_conf(self):
        """Set every logger in ``default_logger_levels`` to its recorded level."""
        try:
            for logname, loglevel in self.default_logger_levels.items():
                logging.getLogger(logname).setLevel(loglevel)
        except BaseException:
            self.error_trace(
                f"reseting logging levels failed with {self.default_logger_levels!r}")

    def _apply_custom_debug_conf(self):
        """Set every logger in ``custom_logger_levels`` to the level named in the conf file."""
        try:
            for logname, loglevel in self.custom_logger_levels.items():
                logging.getLogger(logname).setLevel(loglevel)
        except BaseException:
            self.error_trace(
                f"Custom debug logging failed with {self.custom_logger_levels!r}")

    def _configure_debug_conf_file(self):
        """Create, reload or rewrite the debug conf file.

        If the file is missing, write it with every known logger commented out
        at its default level. If its mtime advanced, re-read the ``name:LEVEL``
        lines into ``custom_logger_levels``. If ``force_debug_conf_update`` is
        set, rewrite it so newly discovered loggers are listed.

        :return: True if the conf file was (re)loaded into custom_logger_levels.
        """
        loaded_conf = False

        if not os.path.isfile(self.debug_conf_file):
            self.logger.debug("write new debug conf file")
            loggers = []
            for name, level in self.default_logger_levels.items():
                loggers.append(f"#{name}:{logging.getLevelName(level)}")
            with open(self.debug_conf_file, 'w+', encoding='UTF-8') as d_f:
                d_f.write("\n".join(loggers))
            self.force_debug_conf_update = False
            self.mtime_debug_conf_file = os.path.getmtime(self.debug_conf_file)
            self.custom_logger_levels = {}
            return loaded_conf

        mtime = os.path.getmtime(self.debug_conf_file)
        if self.mtime_debug_conf_file < mtime:
            self.logger.debug("load debug conf file changes")
            self.mtime_debug_conf_file = mtime
            self.custom_logger_levels = {}
            with open(self.debug_conf_file, 'r', encoding='UTF-8') as d_f:
                debug_conf_data = d_f.read()
            for raw_line in debug_conf_data.split("\n"):
                line = raw_line.strip()
                # Skip blank lines and commented-out (default) entries.
                if not line or line.startswith("#"):
                    continue
                split_words = line.split(":")
                if len(split_words) != 2:
                    continue
                logger_name, logger_level = (word.strip() for word in split_words)
                # An empty name would select the root logger; treat it as a typo.
                if not logger_name:
                    continue
                # getLevelName() returns the int value for a known level name
                # and a "Level ..." string otherwise.
                if isinstance(logging.getLevelName(logger_level), int):
                    self.custom_logger_levels[logger_name] = logger_level
            loaded_conf = True

        if self.force_debug_conf_update:
            loggers = []
            for name, level in self.default_logger_levels.items():
                if name in self.custom_logger_levels:
                    continue
                loggers.append(f"#{name}:{logging.getLevelName(level)}")
            for name, level_str in self.custom_logger_levels.items():
                loggers.append(f"{name}:{level_str}")
            with open(self.debug_conf_file, 'w', encoding='UTF-8') as d_f:
                d_f.write("\n".join(loggers))
            self.force_debug_conf_update = False
            self.logger.debug("debug conf file updated")

        return loaded_conf

    def _add_debug_handler(self):
        """Create the debug file handler (once) and attach it to all known loggers.

        The debug log is named after the primary log file with its extension
        replaced by DEBUG_LOG_FILE_SUFFIX (``foo.log`` -> ``foo-debug.log``).
        """
        if not self._debug_handler:
            base_name = os.path.splitext(self.logfilename)[0]
            debug_logfilename = f"{base_name}{DEBUG_LOG_FILE_SUFFIX}"
            debug_formatter_str = '[%(asctime)s %(name)s] '
            debug_formatter_str += '[%(threadName)s %(module)s %(funcName)s %(lineno)d] '
            debug_formatter_str += '%(levelname)s: %(message)s'
            debug_formatter = logging.Formatter(debug_formatter_str)
            self._debug_handler = logging.FileHandler(
                debug_logfilename, delay=True)
            self._debug_handler.setLevel(logging.DEBUG)
            self._debug_handler.setFormatter(debug_formatter)
            self._logger.addHandler(self._debug_handler)
            self._add_3rd_party_debugging()
            self._apply_default_debug_conf()
            self._apply_custom_debug_conf()

    def _remove_debug_handler(self):
        """Detach the debug handler from every logger and close its file."""
        if self._debug_handler:
            self._remove_3rd_party_debugging()
            self._logger.removeHandler(self._debug_handler)
            # Release the debug log's file descriptor; otherwise each
            # off/on toggle of the flag file leaks one.
            self._debug_handler.close()
            self._debug_handler = None
            self.default_logger_levels = {}

    def _add_3rd_party_debugging(self):
        """Attach the debug handler to every logger not yet tracked, recording DEBUG
        as its default level, and request a conf file rewrite if any were found."""
        try:
            if self._debug_handler:
                new_loggers = set(self.logger.manager.loggerDict) - \
                    set(self.default_logger_levels)
                for name in new_loggers:
                    logger = logging.getLogger(name)
                    self.default_logger_levels[name] = logging.DEBUG
                    self.logger.debug("enable debug logging for logger '%s'", name)
                    logger.addHandler(self._debug_handler)
                if new_loggers:
                    self.force_debug_conf_update = True
        except BaseException:
            self.error_trace("failed to add 3rd party debugging.")

    def _remove_3rd_party_debugging(self):
        """Detach the debug handler from every tracked logger except the primary one."""
        try:
            if self._debug_handler:
                for name in self.default_logger_levels:
                    logger = logging.getLogger(name)
                    if logger is self.logger:
                        continue
                    self.logger.debug("disable logging in debug file for logger '%s'", name)
                    # NOTE(review): this leaves the logger at CRITICAL rather than
                    # restoring its pre-debug level, which also silences
                    # LOGS_OF_INTEREST loggers in the primary log after debugging
                    # is turned off. Confirm this is intended.
                    logger.setLevel(logging.CRITICAL)
                    logger.removeHandler(self._debug_handler)
        except BaseException:
            self.error_trace("failed to remove 3rd party debugging.")

    def debug_trace(self, msg=None, *args):
        """Log ``msg`` (if given) and the current exception traceback at DEBUG."""
        if msg:
            self.logger.debug(msg, *args)
        self.logger.debug(traceback.format_exc())

    def error_trace(self, msg=None, *args):
        """Log ``msg`` (if given) and the current exception traceback at ERROR."""
        if msg:
            self.logger.error(msg, *args)
        self.logger.error(traceback.format_exc())
