
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "rsa.h"
#include "mkrsa.h"



/*
 * MKRSA_NewKeyPair()
 *
 * Generate a new key pair with a specific (or random) public exponent.
 *
 *   key_pair    - key pair to (re)populate. On success any parameters it
 *                 already holds are released and replaced.
 *   bit_length  - size handed to __RSA_GenerateRandomParameters() and used to
 *                 size the GMP integers; values below 64 are raised to 64.
 *   pub_exp_int - starting value for the public exponent, or
 *                 RSA_RANDOM_PUBLIC_EXPONENT to start from a random value of
 *                 bit_length/2 bits. The start value is made odd, raised to
 *                 at least 3 (an exponent of 1 would make encryption the
 *                 identity), then increased by 2 until it is coprime with phi.
 *
 * The public key stores 0 in its MKRSA_PARAM_PHI slot so phi is not
 * disclosed; the private key stores the real phi.
 *
 * Returns TRUE on success. Returns FALSE if the parameter arrays cannot be
 * allocated, in which case key_pair is left untouched.
 */
BOOL MKRSA_NewKeyPair(MKRSA_KEY_PAIR *key_pair, unsigned int bit_length, unsigned long pub_exp_int)
{
	int i;
	BOOL result = FALSE;
	mpz_t temp, pub_exp, pri_exp, mod, phi;
	mpz_t *pub_param, *pri_param;

	// don't allow too small bit length
	if (bit_length < 64)
		bit_length = 64;

	mpz_init2(temp,    bit_length);
	mpz_init2(pub_exp, bit_length);
	mpz_init2(pri_exp, bit_length);
	mpz_init2(mod,     bit_length);
	mpz_init2(phi,     bit_length);

	// generate random parameters mod & phi
	__RSA_GenerateRandomParameters(bit_length, mod, phi);

	// set the public exponent, must be an odd number
	if (pub_exp_int == RSA_RANDOM_PUBLIC_EXPONENT) {
		mpz_urandomb(pub_exp, rand_state, bit_length / 2);
		if (mpz_even_p(pub_exp))
			mpz_add_ui(pub_exp, pub_exp, 1);
	}
	else {
		if (pub_exp_int % 2 == 0)
			pub_exp_int++;
		mpz_set_ui(pub_exp, pub_exp_int);
	}

	// e = 1 makes encryption the identity map; 3 is the smallest usable odd exponent
	if (mpz_cmp_ui(pub_exp, 3) < 0)
		mpz_set_ui(pub_exp, 3);

	// increase public exponent by 2 until pub_exp and phi are relatively prime
	mpz_gcd(temp, pub_exp, phi);
	while (mpz_cmp_ui(temp, 1) != 0) {
		mpz_add_ui(pub_exp, pub_exp, 2);
		mpz_gcd(temp, pub_exp, phi);
	}

	// generate the private exponent from pub_exp and phi
	__RSA_GeneratePrivateExp(bit_length, pri_exp, pub_exp, phi);


	//////////////////////////////////////////////////////////////////////////
	// create the keys with generated values
	//////////////////////////////////////////////////////////////////////////

	// Allocate the new parameter arrays before touching key_pair, so an
	// allocation failure leaves the caller's existing keys intact.
	pub_param = malloc(MKRSA_NUM_PARAM * sizeof *pub_param);
	pri_param = malloc(MKRSA_NUM_PARAM * sizeof *pri_param);
	if (pub_param == NULL || pri_param == NULL) {
		free(pub_param);
		free(pri_param);
		goto cleanup;
	}

	for (i = 0; i < MKRSA_NUM_PARAM; i++) {
		mpz_init2(pub_param[i], bit_length);
		mpz_init2(pri_param[i], bit_length);
	}

	mpz_set(pub_param[MKRSA_PARAM_EXPONENT], pub_exp);
	mpz_set(pub_param[MKRSA_PARAM_MODULUS], mod);
	mpz_set_ui(pub_param[MKRSA_PARAM_PHI], 0); // don't disclose phi

	mpz_set(pri_param[MKRSA_PARAM_EXPONENT], pri_exp);
	mpz_set(pri_param[MKRSA_PARAM_MODULUS], mod);
	mpz_set(pri_param[MKRSA_PARAM_PHI], phi);

	// release whatever the key pair held before, then install the new keys
	if (key_pair->public_key.param) {
		for (i = 0; i < key_pair->public_key.num_param; i++)
			mpz_clear(key_pair->public_key.param[i]);
		free(key_pair->public_key.param);
	}
	if (key_pair->private_key.param) {
		for (i = 0; i < key_pair->private_key.num_param; i++)
			mpz_clear(key_pair->private_key.param[i]);
		free(key_pair->private_key.param);
	}

	key_pair->public_key.type = CRYPTO_KEY_RSA;
	key_pair->public_key.num_param = MKRSA_NUM_PARAM;
	key_pair->public_key.param = pub_param;

	key_pair->private_key.type = CRYPTO_KEY_RSA;
	key_pair->private_key.num_param = MKRSA_NUM_PARAM;
	key_pair->private_key.param = pri_param;

	key_pair->type = CRYPTO_KEY_RSA;
	result = TRUE;

cleanup:
	// release temporary resources
	mpz_clear(temp);
	mpz_clear(pub_exp);
	mpz_clear(pri_exp);
	mpz_clear(mod);
	mpz_clear(phi);

	return result;
}




/*
 * MKRSA_NewReEncryptionKeyPair()
 *
 * Generate a re-encryption/decryption key pair derived from an existing key.
 *
 * A random odd exponent of bit_length/2 bits (bit_length = limb count of the
 * base modulus times bits per limb) is increased by 2 until it is coprime
 * with base_key's phi; it becomes the new private exponent. The new public
 * exponent is produced by __RSA_GeneratePrivateExp() from the product of
 * base_key's exponent and the new private exponent. Both new keys reuse
 * base_key's modulus. The private key also keeps base_key's phi, and the
 * public key stores 0 in its PHI slot.
 *
 *   key_pair - key pair to (re)populate. On success any parameters it
 *              already holds are released and replaced.
 *   base_key - key providing MODULUS, EXPONENT and PHI. It must carry a real
 *              (positive) phi; presumably callers pass a private key, since
 *              public keys store 0 there — TODO confirm.
 *
 * base_key may share its parameter array with key_pair (e.g. passing
 * key_pair->private_key): every read from base_key happens before key_pair's
 * old parameters are released.
 *
 * Returns TRUE on success. Returns FALSE if base_key's phi is not positive
 * or the parameter arrays cannot be allocated; key_pair is then unchanged.
 */
BOOL MKRSA_NewReEncryptionKeyPair(MKRSA_KEY_PAIR *key_pair, MKRSA_KEY base_key)
{
	int i, bit_length;
	BOOL result = FALSE;
	mpz_t temp, pub_exp, pri_exp;
	mpz_t *pub_param, *pri_param;

	// gcd(x, 0) == x, so without phi the coprimality search below never ends
	if (mpz_cmp_ui(base_key.param[MKRSA_PARAM_PHI], 0) <= 0)
		return FALSE;

	mpz_init(temp);
	mpz_init(pub_exp);
	mpz_init(pri_exp);

	bit_length = base_key.param[MKRSA_PARAM_MODULUS]->_mp_size * mp_bits_per_limb;

	// generate random odd private exponent
	mpz_urandomb(pri_exp, rand_state, bit_length / 2);
	if (mpz_even_p(pri_exp))
		mpz_add_ui(pri_exp, pri_exp, 1);

	// increase private exponent by 2 until pri_exp and phi are relatively prime
	mpz_gcd(temp, pri_exp, base_key.param[MKRSA_PARAM_PHI]);
	while (mpz_cmp_ui(temp, 1) != 0) {
		mpz_add_ui(pri_exp, pri_exp, 2);
		mpz_gcd(temp, pri_exp, base_key.param[MKRSA_PARAM_PHI]);
	}

	// derive the public exponent from the product of the base exponent & new pri. exp.
	mpz_mul(temp, pri_exp, base_key.param[MKRSA_PARAM_EXPONENT]);
	__RSA_GeneratePrivateExp(bit_length, pub_exp, temp, base_key.param[MKRSA_PARAM_PHI]);


	//////////////////////////////////////////////////////////////////////////
	// create the keys with generated values
	//////////////////////////////////////////////////////////////////////////

	// Allocate the new parameter arrays before touching key_pair, so an
	// allocation failure leaves the caller's existing keys intact.
	pub_param = malloc(MKRSA_NUM_PARAM * sizeof *pub_param);
	pri_param = malloc(MKRSA_NUM_PARAM * sizeof *pri_param);
	if (pub_param == NULL || pri_param == NULL) {
		free(pub_param);
		free(pri_param);
		goto cleanup;
	}

	for (i = 0; i < MKRSA_NUM_PARAM; i++) {
		mpz_init2(pub_param[i], bit_length);
		mpz_init2(pri_param[i], bit_length);
	}

	// copy everything needed from base_key now: it may alias key_pair's old params
	mpz_set(pub_param[MKRSA_PARAM_EXPONENT], pub_exp);
	mpz_set(pub_param[MKRSA_PARAM_MODULUS], base_key.param[MKRSA_PARAM_MODULUS]);
	mpz_set_ui(pub_param[MKRSA_PARAM_PHI], 0); // don't disclose phi

	mpz_set(pri_param[MKRSA_PARAM_EXPONENT], pri_exp);
	mpz_set(pri_param[MKRSA_PARAM_MODULUS], base_key.param[MKRSA_PARAM_MODULUS]);
	mpz_set(pri_param[MKRSA_PARAM_PHI], base_key.param[MKRSA_PARAM_PHI]);

	// release whatever the key pair held before, then install the new keys
	if (key_pair->public_key.param) {
		for (i = 0; i < key_pair->public_key.num_param; i++)
			mpz_clear(key_pair->public_key.param[i]);
		free(key_pair->public_key.param);
	}
	if (key_pair->private_key.param) {
		for (i = 0; i < key_pair->private_key.num_param; i++)
			mpz_clear(key_pair->private_key.param[i]);
		free(key_pair->private_key.param);
	}

	key_pair->public_key.type = CRYPTO_KEY_RSA;
	key_pair->public_key.num_param = MKRSA_NUM_PARAM;
	key_pair->public_key.param = pub_param;

	key_pair->private_key.type = CRYPTO_KEY_RSA;
	key_pair->private_key.num_param = MKRSA_NUM_PARAM;
	key_pair->private_key.param = pri_param;

	key_pair->type = CRYPTO_KEY_RSA;
	result = TRUE;

cleanup:
	// release temporary resources
	mpz_clear(temp);
	mpz_clear(pub_exp);
	mpz_clear(pri_exp);

	return result;
}



/*
 * MKRSA_Encrypt()
 *
 * Data encryption.
 *
 * The plaintext is processed in blocks one limb shorter than the key's
 * modulus, so every block value is below the modulus; a short last block is
 * zero-padded. Block bytes are copied verbatim into the limbs of a GMP
 * integer (host limb layout), raised to the key's exponent modulo its
 * modulus, and the result limbs are written to a modulus-sized output block
 * (one limb longer than the input block, zero-filled above the result).
 *
 *   cipher  - output buffer
 *   out_len - in: capacity of cipher in bytes; out: bytes required/written
 *   data    - plaintext bytes
 *   in_len  - number of plaintext bytes
 *   key     - key providing MKRSA_PARAM_EXPONENT and MKRSA_PARAM_MODULUS
 *
 * Returns TRUE on success. Returns FALSE with *out_len set to the required
 * size when cipher is too small. Returns FALSE with *out_len unchanged when
 * in_len is negative, the modulus is shorter than two limbs (no room for a
 * plaintext block), or the required size does not fit in an int.
 */
BOOL MKRSA_Encrypt(char cipher[], int *out_len, char data[], int in_len, MKRSA_KEY key)
{
	int in_offs, out_offs, chunk_len, num_limbs, num_blocks;
	int block_len_mp, block_len_byte, out_block_len, min_out_len;
	int limb_bytes = mp_bits_per_limb / 8;
	mpz_t mpz_data;

	// plaintext blocks are one limb shorter than the modulus
	block_len_mp = key.param[MKRSA_PARAM_MODULUS]->_mp_size - 1;
	if (block_len_mp < 1 || in_len < 0)
		return FALSE;
	block_len_byte = block_len_mp * limb_bytes;
	out_block_len = block_len_byte + limb_bytes;

	// compute the encrypted data size without overflowing int
	num_blocks = in_len / block_len_byte + (in_len % block_len_byte != 0);
	if (num_blocks > INT_MAX / out_block_len)
		return FALSE;
	min_out_len = num_blocks * out_block_len;

	// report the minimum required output buffer size if not enough
	if (*out_len < min_out_len) {
		*out_len = min_out_len;
		return FALSE;
	}
	*out_len = min_out_len;

	// perform encryption on block level
	mpz_init2(mpz_data, key.param[MKRSA_PARAM_MODULUS]->_mp_size * mp_bits_per_limb);
	memset(cipher, 0, *out_len);

	for (in_offs = out_offs = 0; in_offs < in_len;
	     in_offs += block_len_byte, out_offs += out_block_len) {
		// load the block into the limbs; a short last block is zero-padded
		chunk_len = in_len - in_offs;
		if (chunk_len > block_len_byte)
			chunk_len = block_len_byte;
		memset(mpz_data->_mp_d, 0, block_len_byte);
		memcpy(mpz_data->_mp_d, data + in_offs, chunk_len);

		// GMP requires the most significant limb of a non-zero mpz to be non-zero
		num_limbs = block_len_mp;
		while (num_limbs > 0 && mpz_data->_mp_d[num_limbs - 1] == 0)
			num_limbs--;
		mpz_data->_mp_size = num_limbs;

		mpz_powm(mpz_data, mpz_data, key.param[MKRSA_PARAM_EXPONENT], key.param[MKRSA_PARAM_MODULUS]);

		// the result is below the modulus, so it fits in out_block_len bytes
		memcpy(cipher + out_offs, mpz_data->_mp_d, mpz_data->_mp_size * limb_bytes);
	}

	mpz_clear(mpz_data);

	return TRUE;
}

/*
 * MKRSA_ReEncrypt()
 *
 * Re-encrypt block-encrypted data under another exponent of the same modulus.
 *
 * Input and output blocks are both exactly the size of the key's modulus.
 * Each block's bytes are copied verbatim into the limbs of a GMP integer
 * (host limb layout), raised to the key's exponent modulo its modulus, and
 * written back at the same offset (zero-filled above the result). A short
 * last block is zero-padded and still produces a full output block, so the
 * output size is in_len rounded up to whole blocks (equal to in_len for
 * well-formed ciphertext).
 *
 *   cipher  - output buffer
 *   out_len - in: capacity of cipher in bytes; out: bytes required/written
 *   data    - input ciphertext
 *   in_len  - number of input bytes
 *   key     - key providing MKRSA_PARAM_EXPONENT and MKRSA_PARAM_MODULUS
 *
 * Returns TRUE on success. Returns FALSE with *out_len set to the required
 * size when cipher is too small. Returns FALSE with *out_len unchanged when
 * in_len is negative, the modulus is empty, or the required size does not
 * fit in an int.
 */
BOOL MKRSA_ReEncrypt(char cipher[], int *out_len, char data[], int in_len, MKRSA_KEY key)
{
	int offs, chunk_len, num_limbs, num_blocks;
	int block_len_mp, block_len_byte, min_out_len;
	int limb_bytes = mp_bits_per_limb / 8;
	mpz_t mpz_data;

	// input and output blocks are both the size of the modulus
	block_len_mp = key.param[MKRSA_PARAM_MODULUS]->_mp_size;
	if (block_len_mp < 1 || in_len < 0)
		return FALSE;
	block_len_byte = block_len_mp * limb_bytes;

	// a short last block still produces a full output block
	num_blocks = in_len / block_len_byte + (in_len % block_len_byte != 0);
	if (num_blocks > INT_MAX / block_len_byte)
		return FALSE;
	min_out_len = num_blocks * block_len_byte;

	// report the minimum required output buffer size if not enough
	if (*out_len < min_out_len) {
		*out_len = min_out_len;
		return FALSE;
	}
	*out_len = min_out_len;

	// perform encryption on block level
	mpz_init2(mpz_data, key.param[MKRSA_PARAM_MODULUS]->_mp_size * mp_bits_per_limb);
	memset(cipher, 0, *out_len);

	// input and output blocks have the same size, so they share one offset
	for (offs = 0; offs < in_len; offs += block_len_byte) {
		// load the block into the limbs; a short last block is zero-padded
		chunk_len = in_len - offs;
		if (chunk_len > block_len_byte)
			chunk_len = block_len_byte;
		memset(mpz_data->_mp_d, 0, block_len_byte);
		memcpy(mpz_data->_mp_d, data + offs, chunk_len);

		// GMP requires the most significant limb of a non-zero mpz to be non-zero
		num_limbs = block_len_mp;
		while (num_limbs > 0 && mpz_data->_mp_d[num_limbs - 1] == 0)
			num_limbs--;
		mpz_data->_mp_size = num_limbs;

		mpz_powm(mpz_data, mpz_data, key.param[MKRSA_PARAM_EXPONENT], key.param[MKRSA_PARAM_MODULUS]);

		// the result is below the modulus, so it fits in one block
		memcpy(cipher + offs, mpz_data->_mp_d, mpz_data->_mp_size * limb_bytes);
	}

	mpz_clear(mpz_data);

	return TRUE;
}

/*
 * MKRSA_Decrypt()
 *
 * Data decryption, reversing the block layout produced by MKRSA_Encrypt().
 *
 * Ciphertext is processed in modulus-sized blocks (a short last block is
 * zero-padded). Each block's bytes are copied verbatim into the limbs of a
 * GMP integer (host limb layout), raised to the key's exponent modulo its
 * modulus, and written to an output block one limb shorter than the modulus
 * (zero-filled above the result). A decrypted value too wide for that block
 * (wrong key or corrupted ciphertext) is truncated to the block instead of
 * overrunning the output buffer. Zero padding that the encryptor added to
 * the last plaintext block is not removed.
 *
 *   data    - output buffer
 *   out_len - in: capacity of data in bytes; out: bytes required/written
 *   cipher  - input ciphertext
 *   in_len  - number of ciphertext bytes
 *   key     - key providing MKRSA_PARAM_EXPONENT and MKRSA_PARAM_MODULUS
 *
 * Returns TRUE on success. Returns FALSE with *out_len set to the required
 * size when data is too small. Returns FALSE with *out_len unchanged when
 * in_len is negative, the modulus is shorter than two limbs, or the sizes
 * involved do not fit in an int.
 */
BOOL MKRSA_Decrypt(char data[], int *out_len, char cipher[], int in_len, MKRSA_KEY key)
{
	int in_offs, out_offs, chunk_len, num_limbs, num_blocks;
	int block_len_mp, block_len_byte, out_block_len, min_out_len;
	int limb_bytes = mp_bits_per_limb / 8;
	mpz_t mpz_data;

	// ciphertext blocks are modulus-sized, plaintext blocks one limb shorter
	block_len_mp = key.param[MKRSA_PARAM_MODULUS]->_mp_size;
	if (block_len_mp < 2 || in_len < 0)
		return FALSE;
	block_len_byte = block_len_mp * limb_bytes;
	out_block_len = block_len_byte - limb_bytes;

	// compute the decrypted data size: one output block per (padded) input block
	num_blocks = in_len / block_len_byte + (in_len % block_len_byte != 0);
	if (num_blocks > INT_MAX / block_len_byte)
		return FALSE;
	min_out_len = num_blocks * out_block_len;

	// report the minimum required output buffer size if not enough
	if (*out_len < min_out_len) {
		*out_len = min_out_len;
		return FALSE;
	}
	*out_len = min_out_len;

	// perform decryption on block level
	mpz_init2(mpz_data, key.param[MKRSA_PARAM_MODULUS]->_mp_size * mp_bits_per_limb);
	memset(data, 0, *out_len);

	for (in_offs = out_offs = 0; in_offs < in_len;
	     in_offs += block_len_byte, out_offs += out_block_len) {
		// load the block into the limbs; a short last block is zero-padded
		chunk_len = in_len - in_offs;
		if (chunk_len > block_len_byte)
			chunk_len = block_len_byte;
		memset(mpz_data->_mp_d, 0, block_len_byte);
		memcpy(mpz_data->_mp_d, cipher + in_offs, chunk_len);

		// GMP requires the most significant limb of a non-zero mpz to be non-zero
		num_limbs = block_len_mp;
		while (num_limbs > 0 && mpz_data->_mp_d[num_limbs - 1] == 0)
			num_limbs--;
		mpz_data->_mp_size = num_limbs;

		mpz_powm(mpz_data, mpz_data, key.param[MKRSA_PARAM_EXPONENT], key.param[MKRSA_PARAM_MODULUS]);

		// a valid plaintext block fits in block_len_mp - 1 limbs; clamp so a
		// wrong key or corrupted ciphertext cannot write past the output block
		num_limbs = mpz_data->_mp_size;
		if (num_limbs > block_len_mp - 1)
			num_limbs = block_len_mp - 1;
		memcpy(data + out_offs, mpz_data->_mp_d, num_limbs * limb_bytes);
	}

	mpz_clear(mpz_data);
	return TRUE;
}
