# Copyright Citrix Systems, Inc. All rights reserved.

"""RSA key generation and signing.

Create new keys using the newkeypair function. It returns a tuple
(pub, priv) containing the encoded public and private keys.

Sign values and verify signatures using the sign/verify functions. Both
accept the keys either as key objects or in encoded form.
"""
from cwc.crypto.codec import encode_public, encode_private, decode_public, decode_private
from cwc.util import getbytes, frombytes
import base64
import rsa

_HASH_ALGORITH = 'SHA-256'

def newkeypair(nbits=2048):
    """Generates public and private keys, and returns them as (pub, priv).

    :param nbits:   the RSA modulus size in bits (default 2048); it is split
                    evenly between the two primes, so it must be a positive
                    even number
    :returns:       a tuple (pub, priv) containing the encoded keys
    :raises ValueError: if ``nbits`` is not a positive even number
    """
    if nbits <= 0 or nbits % 2:
        # _newkeys halves nbits for each prime; an odd value would silently
        # produce a key one bit shorter than requested.
        raise ValueError('nbits must be a positive even number, got %r' % (nbits,))
    (pub, priv) = _newkeys(nbits)
    return (encode_public(pub), encode_private(priv))

def sign(message, priv_key):
    """Signs a message with the given private key using RSA-SHA256.

    :param message:     the message to sign; ``None`` yields ``None``
    :param priv_key:    the private key to use for signing, either as a key
                        object or in encoded (string) form
    :returns:           the RSA-SHA256 signature encoded as base64, or
                        ``None`` if ``message`` is ``None``
    """
    if message is None:
        return None

    key = priv_key
    # type(u'') is `unicode` on Python 2 and `str` on Python 3, so encoded
    # keys supplied as unicode text are decoded too (verify() likewise
    # converts unicode keys to str before decoding them).
    if isinstance(priv_key, (str, type(u''))):
        key = decode_private(str(priv_key))

    msg = getbytes(message)
    sig = rsa.sign(msg, key, _HASH_ALGORITH)
    return frombytes(base64.b64encode(sig))

def verify(message, signature, pub_key):
    """Verifies that the given signature matches the message.

    Any failure while decoding the key or signature, or while checking the
    signature, is reported as ``False`` instead of being raised.

    :param message:     the signed message to verify
    :param signature:   the base64-encoded signature, as returned by
                        :py:func:`cwc.crypto.rsa.sign`
    :param pub_key:     the public key to use to verify the signature, either
                        as a key object or in encoded (string) form
    :returns:           ``True`` if the signature matches the message,
                        ``False`` otherwise, or ``None`` if ``message`` or
                        ``signature`` is ``None``
    """
    if message is None or signature is None:
        return None

    try:
        key = pub_key
        # type(u'') is `unicode` on Python 2 and `str` on Python 3; encoded
        # keys are normalized to a native str before decoding.
        if isinstance(pub_key, (str, type(u''))):
            key = decode_public(str(pub_key))

        msg = getbytes(message)
        sig = base64.b64decode(getbytes(signature))

        # rsa.verify raises on a mismatch; on success it returns True in
        # older python-rsa releases but the hash-method name in newer ones,
        # so its return value is not passed through.
        rsa.verify(msg, sig, key)
    except Exception:
        # Best-effort check: malformed keys/signatures count as a mismatch.
        # Unlike a bare `except:`, KeyboardInterrupt/SystemExit propagate.
        return False
    return True

def _newkeys(nbits):
    """Generates a new RSA key pair with an ``nbits``-bit modulus.

    ``rsa.key.calculate_keys`` raises ``ValueError`` when the default public
    exponent is not coprime with (p - 1)(q - 1). As in python-rsa's own key
    generation, fresh primes are drawn in that (rare) case instead of
    failing.

    :param nbits:   the modulus size in bits, split evenly between p and q
    :returns:       a tuple (rsa.PublicKey, rsa.PrivateKey)
    """
    while True:
        (p, q) = _find_p_q(nbits // 2)
        try:
            (e, d) = rsa.key.calculate_keys(p, q)
            break
        except ValueError:
            # e shares a factor with phi(n); retry with new primes.
            pass
    n = p * q
    return (rsa.PublicKey(n, e), rsa.PrivateKey(n, e, d, p, q))

def _find_p_q(nbits):
    """Picks two distinct primes of ``nbits`` bits each for an RSA modulus.

    Primes are redrawn one at a time (q first, then p, alternating) until
    they differ and their product is exactly ``2 * nbits`` bits long.

    :param nbits:   the bit length of each prime
    :returns:       a tuple (p, q) with p > q
    """
    modulus_bits = 2 * nbits

    # IMPORTANT (luisga): p and q must have the same bit length to ensure .NET compatibility.
    p = rsa.prime.getprime(nbits)
    q = rsa.prime.getprime(nbits)

    redraw_q = True
    while p == q or rsa.common.bit_size(p * q) != modulus_bits:
        # replace a single prime per attempt, alternating between q and p
        if redraw_q:
            q = rsa.prime.getprime(nbits)
        else:
            p = rsa.prime.getprime(nbits)
        redraw_q = not redraw_q

    # p > q as described in http://www.di-mgt.com.au/rsa_alg.html#crt
    return (p, q) if p > q else (q, p)
