# Copyright (C) 2018-2023. Cloud Software Group, Inc. All Rights Reserved. Confidential & Proprietary.

"""Authentication/Authorization operations"""
from .__meta__ import __version__
from ccauth.crypto import decode_privatekey, verify, sign, randbytes
from ccauth.config import ConfigurationOptions, LoggingLevel
from ccauth.security import SigningAlgorithm, IdentityValidationContext, IdentityValidationResult, AccessPolicy, AccessQueryContext
from ccauth.net import RequestInfo, RequestMessage, default_request_handler
from ccauth.util import isempty, asstring, asbytes
from .error import ensure_success_code
import ccauth.interop as interop
from ctypes import *
import rsa

class CCAuthHandle:
    """Wrapper around the native ccauth_handle object.

    Owns a native handle created by ``interop.ccauth_init`` and registers Python
    callbacks (sign, verify, request, rand, log) on it. Release it with
    :py:meth:`dispose`, or use the instance as a context manager.
    """

    def __init__(self, options=None):
        """Creates and configures a native handle.

        :param self:        the current py:class:`ccauth.CCAuthHandle` instance
        :param options:     the (optional) configuration, see py:class:`ccauth.config.ConfigurationOptions`;
                            when omitted a default-constructed ``ConfigurationOptions`` is used
        :raises ValueError: if 'options' is neither None nor a ``ConfigurationOptions``
        """
        self._handle = None
        self._private_key = None
        # id(content ctypes callback) -> (content ctypes callback, remaining response bytes);
        # filled by _request_callback_impl and drained by _response_content_callback_impl.
        self._callback_map = {}

        if options is not None and not isinstance(options, ConfigurationOptions):
            raise ValueError("'options' must be of type 'ConfigurationOptions'.")

        if options is None:
            # _configure reads attributes of 'options' unconditionally, so a None
            # value used to fail with AttributeError; fall back to the defaults.
            options = ConfigurationOptions()

        handle = c_void_p()
        err = interop.ccauth_init(byref(handle))
        ensure_success_code(err, 'could not initialize native handle.')

        self._handle = handle
        try:
            self._configure(options)
        except BaseException:
            # Do not leave a half-configured native handle around until __del__ runs.
            self.dispose()
            raise

    def _configure(self, options):
        """Pushes callbacks and every non-None option of 'options' into the native handle.

        The ctypes callback objects are stored on ``self`` because ctypes requires
        Python to keep them referenced for as long as native code may call them.
        """
        # ccauth_sign_callback
        self._sign_callback = interop.sign_callback_functype(
            lambda h, msg, sz, buf, bufsz: _sign_callback_impl(self, h, msg, sz, buf, bufsz))
        err = interop.ccauth_config(self._handle, interop.ConfigurationOption.SIGN_CALLBACK, self._sign_callback)
        ensure_success_code(err, 'could not set the sign callback.')

        # ccauth_verify_callback
        self._verify_callback = interop.verify_callback_functype(_verify_callback_impl)
        err = interop.ccauth_config(self._handle, interop.ConfigurationOption.VERIFY_CALLBACK, self._verify_callback)
        ensure_success_code(err, 'could not set the verify callback.')

        # ccauth_request_callback
        self._request_handler_func = options.request_handler_func if options.request_handler_func is not None else default_request_handler
        self._request_callback = interop.request_callback_functype(
            lambda h, req, res: _request_callback_impl(self, h, req, res))
        err = interop.ccauth_config(self._handle, interop.ConfigurationOption.REQUEST_CALLBACK, self._request_callback)
        ensure_success_code(err, 'could not set the request callback.')

        # ccauth_rand_callback
        self._rand_callback = interop.rand_callback_functype(_rand_callback_impl)
        err = interop.ccauth_config(self._handle, interop.ConfigurationOption.RAND_CALLBACK, self._rand_callback)
        ensure_success_code(err, 'could not set the rand callback.')

        # ccauth_log_callback: logging is OFF unless both a function and a level are supplied
        self._logging_func = options.logging_func
        self._logging_level = options.logging_level if self._logging_func is not None and options.logging_level is not None else LoggingLevel.OFF
        self._log_callback = interop.log_callback_functype(
            lambda h, lvl, msg, sz, udata: _log_callback_impl(self, h, lvl, msg, sz, udata))
        err = interop.ccauth_config(self._handle, interop.ConfigurationOption.LOG_CALLBACK, self._log_callback,
                                    self._logging_level, None)
        ensure_success_code(err, 'could not set the logging callback.')

        # host version
        hostversion = asbytes('ccauth-python-' + __version__)
        err = interop.ccauth_config(self._handle, interop.ConfigurationOption.HOST_VERSION, hostversion)
        ensure_success_code(err, 'could not set the host version.')

        if not isempty(options.route_template):
            route = asbytes(options.route_template)
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.ROUTE_TEMPLATE, route)
            ensure_success_code(err, 'could not set the route template.')

        if not isempty(options.service_name):
            service = asbytes(options.service_name)
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.SERVICE_NAME, service)
            ensure_success_code(err, 'could not set the service name.')

        if not isempty(options.service_instance):
            instance = asbytes(options.service_instance)
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.SERVICE_INSTANCE, instance)
            ensure_success_code(err, 'could not set the service instance id.')

        if options.signing_algorithm is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.SIGNING_ALGORITHM,
                                        options.signing_algorithm)
            ensure_success_code(err, 'could not set the signing algorithm.')

        # The private key is kept on the Python side only; _sign_callback_impl uses it.
        if options.private_key is not None:
            if isinstance(options.private_key, rsa.PrivateKey):
                self._private_key = options.private_key
            elif not isempty(options.private_key):
                self._private_key = decode_privatekey(options.private_key)

        if options.capabilities_cache_ttl is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.CAPABILITIES_CACHE_TTL,
                                        options.capabilities_cache_ttl)
            ensure_success_code(err, 'could not set the cache ttl for service key capabilities.')

        if options.publickey_cache_ttl is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.PUBLICKEYS_CACHE_TTL,
                                        options.publickey_cache_ttl)
            ensure_success_code(err, 'could not set the cache ttl for public keys.')

        if options.refresh_publickeys_onfail is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.REFRESH_PUBLIC_KEYS_ONFAIL,
                                        options.refresh_publickeys_onfail)
            ensure_success_code(err, 'could not set whether to refresh public keys.')

        if options.access_cache_ttl is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.ACCESS_CACHE_TTL,
                                        options.access_cache_ttl)
            ensure_success_code(err, 'could not set the cache ttl for access policies.')

        if options.token_blacklist_cache_ttl is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.TOKENBLACKLIST_CACHE_TTL,
                                        options.token_blacklist_cache_ttl)
            ensure_success_code(err, 'could not set the cache ttl for blacklisted tokens.')

        if options.circuitbreaker_requirement is not None:
            requirement = interop.CCAuthCircuitBreakerRequirement._marshal(options.circuitbreaker_requirement)
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.CIRCUITBREAKER_REQUIREMENT,
                                        byref(requirement))
            ensure_success_code(err, 'could not set the circuit breaker requirements.')

        if options.token_requirement is not None:
            requirement = interop.CCAuthTokenRequirement._marshal(options.token_requirement)
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.BEARERTOKEN_REQUIREMENT,
                                        byref(requirement))
            ensure_success_code(err, 'could not set the bearer token requirements.')

        if options.servicekey_requirement is not None:
            requirement = interop.CCAuthServiceKeyRequirement._marshal(options.servicekey_requirement)
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.SERVICEKEY_REQUIREMENT,
                                        byref(requirement))
            ensure_success_code(err, 'could not set the service key requirements.')

        if options.allow_custom_admins is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.ALLOW_CUSTOM_ADMINS,
                                        options.allow_custom_admins)
            ensure_success_code(err, 'could not set whether to allow custom administrators.')

        if options.cache_grace_period is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.CACHE_GRACE_PERIOD,
                                        options.cache_grace_period)
            ensure_success_code(err, 'could not set the cache grace period.')

        if options.max_content_length is not None:
            err = interop.ccauth_config(self._handle, interop.ConfigurationOption.MAX_CONTENT_LENGTH,
                                        options.max_content_length)
            ensure_success_code(err, 'could not set the maximum request/response content length.')

    def dispose(self):
        """Disposes the native handle from libccauth.

        Safe to call more than once: the handle is cleared after the first call.

        :param self:    the current py:class:`ccauth.CCAuthHandle` instance
        """
        if self._handle is not None:
            interop.ccauth_deinit(self._handle)
            self._handle = None

        self._callback_map.clear()

    def create_servicekey(self, request_info, signing_algorithm=SigningAlgorithm.DEFAULT):
        """Creates a service key.

        :param self:                the current py:class:`ccauth.CCAuthHandle` instance
        :param request_info:        the target request info, see py:class:`ccauth.net.RequestInfo`
        :param signing_algorithm:   the (optional) signing algorithm to use, see py:class:`ccauth.security.SigningAlgorithm`
        :returns:                   a CC encoded service key
        :raises ValueError:         on a missing/invalid 'request_info' or an invalid 'signing_algorithm'
        """
        if request_info is None or not isinstance(request_info, RequestInfo):
            raise ValueError("'request_info' must be of type 'RequestInfo'.")

        if request_info.method is None:
            raise ValueError("missing 'request_info.method'.")

        if request_info.uri is None:
            raise ValueError("missing 'request_info.uri'.")

        if not SigningAlgorithm.isvalid(signing_algorithm):
            # The invalid value is the 'signing_algorithm' argument, not a field of request_info.
            raise ValueError("invalid 'signing_algorithm'.")

        request = interop.CCAuthRequestInfo._marshal(request_info)
        svckey_ptr = c_char_p()

        err = interop.ccauth_create_servicekey(
            self._handle,
            signing_algorithm,
            byref(request),
            byref(svckey_ptr))

        ensure_success_code(err, 'could not create native service key.')

        # The native string is copied into Python before being released with ccauth_free.
        try:
            return asstring(string_at(svckey_ptr))
        finally:
            interop.ccauth_free(svckey_ptr)

    def validate_identity(self, header, validation_context, request_info=None):
        """Validates a CWSAuth authorization header.

        :param self:                the current py:class:`ccauth.CCAuthHandle` instance
        :param header:              the CWSAuth header to validate
        :param validation_context:  the validation context to use, see py:class:`ccauth.security.IdentityValidationContext`
        :param request_info:        the target request info (for service keys), see py:class:`ccauth.net.RequestInfo`
        :returns:                   an instance of py:class:`ccauth.security.IdentityValidationResult`
        :raises ValueError:         if 'validation_context' or 'request_info' has the wrong type
        """
        if validation_context is None or not isinstance(validation_context, IdentityValidationContext):
            raise ValueError("'validation_context' must be of type 'IdentityValidationContext'.")

        if request_info is not None and not isinstance(request_info, RequestInfo):
            raise ValueError("'request_info' must be of type 'RequestInfo'.")

        context = interop.CCAuthValidationContext._marshal(validation_context)
        request = interop.CCAuthRequestInfo._marshal(request_info) if request_info is not None else None
        idresult_ptr = c_void_p()

        err = interop.ccauth_validate_identity(
            self._handle,
            asbytes(header),
            byref(request) if request is not None else None,
            byref(context),
            byref(idresult_ptr))

        ensure_success_code(err, 'could not get native identity result.')

        # Unmarshal into Python objects, then release the native result regardless of outcome.
        try:
            return IdentityValidationResult._unmarshal(idresult_ptr)
        finally:
            interop.ccauth_free_identity(idresult_ptr)

    def check_access(self, required_policies, query_context):
        """Performs access control.

        :param self:                the current py:class:`ccauth.CCAuthHandle` instance
        :param required_policies:   the required policies to match: a single ``AccessPolicy``
                                    or a list/tuple of them
        :param query_context:       the query context to use, see py:class:`ccauth.security.AccessQueryContext`
        :returns:                   a boolean value indicating whether access should be granted
                                    (True when the native result is 0)
        :raises ValueError:         if the policies or the query context have the wrong type
        """
        # Accept a single policy for convenience by wrapping it in a list.
        if required_policies is not None and not isinstance(required_policies, (list, tuple)):
            required_policies = [required_policies]

        if required_policies is None or not all(isinstance(p, AccessPolicy) for p in required_policies):
            raise ValueError("'required_policies' must be a list of 'AccessPolicy'.")

        if query_context is None or not isinstance(query_context, AccessQueryContext):
            raise ValueError("'query_context' must be of type 'AccessQueryContext'.")

        native = [interop.CCAuthAccessPolicy._marshal(p) for p in required_policies]
        policies = (interop.CCAuthAccessPolicy * len(native))(*native)
        context = interop.CCAuthAccessContext._marshal(query_context)
        # Initialised to a non-zero (deny) value; the native call writes the real result.
        result = c_int(1)

        err = interop.ccauth_check_access(
            self._handle,
            policies,
            c_size_t(len(policies)),
            byref(context),
            byref(result))

        ensure_success_code(err, 'could not get native check access result.')

        return result.value == 0

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.dispose()

    def __del__(self):
        self.dispose()


def _sign_callback_impl(handle, _, plain, size, buffer, buffer_size):
    """Native sign callback: signs ``size`` bytes at ``plain`` with the handle's private key.

    :param handle:      the owning py:class:`CCAuthHandle` (supplies ``_private_key``)
    :param _:           the native handle (unused)
    :param plain:       the message buffer passed by the native library
    :param size:        number of message bytes at ``plain``
    :param buffer:      destination buffer for the signature
    :param buffer_size: capacity of ``buffer`` in bytes
    :returns:           the signature length, or 0 when signing failed or the
                        signature does not fit in ``buffer``
    """
    try:
        plainbytes = bytes(POINTER(c_ubyte * size).from_buffer(plain)[0])
        signed = sign(plainbytes, handle._private_key)
    except BaseException:
        # An exception cannot propagate out of a ctypes callback (ctypes only reports it),
        # so convert any failure (e.g. no private key configured) into the same 0
        # "no signature" result used below for an undersized buffer.
        return 0

    length = len(signed)
    if length > buffer_size:
        return 0

    memmove(buffer, signed, length)
    return length


def _verify_callback_impl(_, signature, sigsize, message, msgsize, publickey):
    """Native verify callback: checks a signature over a message with a public key.

    :param _:           the native handle (unused)
    :param signature:   the signature buffer passed by the native library
    :param sigsize:     number of signature bytes
    :param message:     the message buffer passed by the native library
    :param msgsize:     number of message bytes
    :param publickey:   the public key, forwarded as-is to ``verify``
    :returns:           1 if the signature verifies, 0 otherwise (including on any error)
    """
    try:
        # Buffer extraction is inside the try so that a malformed argument also fails closed.
        sig = bytes(POINTER(c_ubyte * sigsize).from_buffer(signature)[0])
        msg = bytes(POINTER(c_ubyte * msgsize).from_buffer(message)[0])
        return 1 if verify(msg, sig, publickey) else 0
    except BaseException:
        # Deliberately broad: exceptions cannot propagate out of a ctypes callback,
        # and a verification callback must never report success on failure.
        return 0


def _request_callback_impl(handle, _, request, response):
    """Native request callback: hands an outgoing request to the Python request handler.

    The handler's reply (if any) is written into the native response: its status
    always, and, for a non-empty body, a content callback that streams the body
    back through _response_content_callback_impl.

    :param handle:      the owning py:class:`CCAuthHandle` (supplies ``_request_handler_func``)
    :param _:           the native handle (unused)
    :param request:     the native request, unmarshalled into a ``RequestMessage``
    :param response:    pointer to the native response to fill in
    """
    reply = handle._request_handler_func(RequestMessage._unmarshal(request))
    if reply is None:
        return

    native_response = response[0]
    native_response.status = reply.status

    body = reply.content
    if body is None or len(body) == 0:
        return

    def _feed(buf, bufsize, userdata):
        return _response_content_callback_impl(handle, buf, bufsize, userdata)

    feeder = interop.content_callback_functype(_feed)
    key = id(feeder)
    # The map keeps the ctypes callback alive while native code may still call it;
    # its id doubles as the userdata token used to look the body up again.
    handle._callback_map[key] = (feeder, bytes(body))
    native_response.callback = feeder
    native_response.userdata = cast(key, c_void_p)


def _response_content_callback_impl(handle, buffer, buffer_size, userdata):
    """Native content callback: copies the next chunk of a pending response body into ``buffer``.

    The body was registered in ``handle._callback_map`` under the id of its ctypes
    content callback, and that id is passed back here as ``userdata``.

    :param handle:      the owning py:class:`CCAuthHandle` (supplies ``_callback_map``)
    :param buffer:      destination buffer provided by the native library
    :param buffer_size: capacity of ``buffer`` in bytes
    :param userdata:    the registration key (id of the content callback)
    :returns:           number of bytes copied; 0 once the body is exhausted (the
                        entry is then removed); -1 on error (missing/unknown/mismatched
                        userdata, or a zero-size buffer while body bytes remain)
    """
    # Raising here would not reach the native caller (ctypes only reports exceptions
    # raised inside callbacks), so errors are signalled with -1 instead.
    if userdata is None:
        return -1

    cbid = int(userdata)
    callback_info = handle._callback_map.get(cbid)
    if callback_info is None or id(callback_info[0]) != cbid:
        return -1

    content = callback_info[1]
    content_size = len(content)
    # Clamp at 0 so a negative size can never be passed to memmove.
    copy_size = max(0, min(content_size, buffer_size))

    memmove(buffer, content, copy_size)

    if copy_size > 0:
        # Keep the unread remainder for the next call.
        handle._callback_map[cbid] = (callback_info[0], content[copy_size:])
        return copy_size

    # Nothing copied: the transfer ends here, so release the registration.
    del handle._callback_map[cbid]

    if buffer_size <= 0 and content_size > 0:
        # Body bytes remain but cannot be delivered.
        return -1

    return 0


def _rand_callback_impl(_, buffer, size):
    """Native random-bytes callback: fills ``buffer`` with ``size`` bytes from ``randbytes``.

    :param _:       the native handle (unused)
    :param buffer:  destination buffer provided by the native library
    :param size:    number of random bytes requested
    :returns:       ``size`` (assumes ``randbytes(size)`` yields at least ``size`` bytes — TODO confirm)
    """
    memmove(buffer, randbytes(size), size)
    return size


def _log_callback_impl(handle, _, level, message, size, __):
    """Native log callback: forwards a message to the configured Python logging function.

    Messages are dropped when no logging function is set or when ``level`` is
    below the handle's configured logging level.

    :param handle:  the owning py:class:`CCAuthHandle` (supplies ``_logging_func``/``_logging_level``)
    :param _:       the native handle (unused)
    :param level:   the message's logging level
    :param message: the message buffer passed by the native library
    :param size:    number of message bytes
    :param __:      native userdata (unused)
    """
    sink = handle._logging_func
    if sink is None or not level >= handle._logging_level:
        return
    sink(level, asstring(string_at(message, size)))
