
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <time.h>
#include <errno.h>
#include <fcntl.h>

#define _XOPEN_SOURCE
#include <unistd.h>



#include "rsa.h"
#include "mkrsa.h"

#include "sml.h"





/* ==================================================================





   Socket Functions
   
   
   
   
   
   ================================================================== */
/*
 * Read exactly `length` bytes from `socket` into `data`.
 *
 * Returns `length` once every byte has arrived, 0 if the peer closed the
 * connection first (or `length` is not positive), or the negative recv()
 * result if the call failed; errno is left as recv() set it.
 */
int __TCP_ReadFully (int socket, char *data, int length)
{
	int received = 0;

	if (length <= 0)
		return 0;

	do {
		ssize_t got = recv(socket, data + received,
		                   (size_t)(length - received), MSG_WAITALL);

		/*
		 * recv() signals failure (EWOULDBLOCK included) with a negative
		 * return and the reason in errno.  Comparing the byte count with
		 * an errno constant would mistake a valid partial read of that
		 * many bytes for an error.
		 */
		if (got < 0)
			return (int)got;
		if (got == 0)	/* closed by peer */
			return 0;

		received += (int)got;
	} while (received < length);

	return length;
}








/* ==================================================================





   SML API Implementation
   
   
   
   
   
   ================================================================== */

/*
 * Put a session structure into its initial, unconnected state: no sockets,
 * no bound port, sequence number 0, every key slot initialised through
 * CRYPTO_InitKey(), and no signature, RPS key or packet buffers attached.
 */
void SML_InitSession(SML_SESSION* session)
{
	int i;

	session->servSd = INVALID_SOCKET;
	session->socket = INVALID_SOCKET;
	session->last_bind_port = -1;

	// starting sequence number
	session->seq = 0;
	session->anonymous_mode = 0;

	// keys are undefined until they are initialized later
	for (i = 0; i < SESSION_NUM_KEY; i++) {
		CRYPTO_InitKey(&session->key[i]);
	}
	session->__signature = NULL;
	session->__client_signature = NULL;
	session->__signature_len = 0;
	session->__client_signature_len = 0;

	session->__reenc_key = NULL;
	session->__dec_key = NULL;
	session->__dec_key2 = NULL;
	session->__reenc_key_len = 0;
	session->__dec_key_len = 0;
	session->__dec_key2_len = 0;

	// initialize encryption operation variables
	// (this must be an assignment; a comparison here leaves the field unset)
	session->__plaintext_pkt_size = -1;
	session->__cipher_buffer = NULL;
	session->__plaintext_buffer = NULL;
}

/*
 * Release everything a session owns: its key slots, both signature buffers,
 * the RPS key strings, the packet buffers and any open sockets.
 *
 * Every released pointer and descriptor is reset afterwards, so calling this
 * a second time on the same session does not free or close anything twice.
 * The session structure itself is not freed.
 */
void SML_DestroySession(SML_SESSION* session)
{
	int i;

	// release each key struct
	for (i = 0; i < SESSION_NUM_KEY; i++) {
		CRYPTO_DestroyKey(&session->key[i]);
	}

	// free(NULL) is a no-op, so no guards are needed
	free(session->__signature);
	free(session->__client_signature);
	free(session->__reenc_key);
	free(session->__dec_key);
	free(session->__dec_key2);
	free(session->__cipher_buffer);
	free(session->__plaintext_buffer);

	// close sockets
	if (session->socket != INVALID_SOCKET)
		close(session->socket);
	if (session->servSd != INVALID_SOCKET)
		close(session->servSd);

	// invalidation
	session->socket = INVALID_SOCKET;
	session->servSd = INVALID_SOCKET;
	session->last_bind_port = -1;

	session->__signature = NULL;
	session->__signature_len = 0;
	session->__client_signature = NULL;
	session->__client_signature_len = 0;

	session->__reenc_key = NULL;
	session->__dec_key = NULL;
	session->__dec_key2 = NULL;
	session->__cipher_buffer = NULL;
	session->__plaintext_buffer = NULL;
}


/*
 * Load a key pair from `keyfilename` (protected by `passwd`) into the
 * session's MY_PUBLIC_KEY / MY_PRIVATE_KEY slots and rebuild the session
 * signature: the public key string encrypted with the private key.
 *
 * Returns SML_SUCCESS, SML_ERROR_KEY_TYPE_UNSUPPORT for an unsupported
 * `key_type`, or SML_ERROR_FILE_ERROR when the key pair cannot be loaded.
 * SML_ERROR_FILE_ERROR is also returned when the signature buffer cannot be
 * built, since no more specific code is visible in this file.
 */
int SML_LoadKeyPair(SML_SESSION* session, char keyfilename[], char passwd[], unsigned int key_type)
{
	CRYPTO_KEY_PAIR key_pair;
	char *str = NULL;
	char *signature;
	int len;

	// ensure key type is supported
	if (key_type == CRYPTO_KEY_UNDEFINED || key_type > CRYPTO_KEY_MAX_VALID)
		return SML_ERROR_KEY_TYPE_UNSUPPORT;

	CRYPTO_InitKeyPair(&key_pair);

	// load key pair
	if (CRYPTO_LoadKeyPair(&key_pair, keyfilename, passwd, key_type) == FALSE) {
		// release whatever the partially loaded pair holds
		CRYPTO_DestroyKeyPair(&key_pair);
		return SML_ERROR_FILE_ERROR;
	}

	CRYPTO_CopyPublicKey(&session->key[MY_PUBLIC_KEY], key_pair);
	CRYPTO_CopyPrivateKey(&session->key[MY_PRIVATE_KEY], key_pair);

	CRYPTO_DestroyKeyPair(&key_pair);

	// create the eCert (private encrypted public key)
	len = CRYPTO_KeyToStringAlloc(&str, session->key[MY_PUBLIC_KEY]);
	if (str == NULL || len <= 0) {
		free(str);
		return SML_ERROR_FILE_ERROR;
	}

	// the output buffer is sized at twice the key string length
	signature = (char*)realloc(session->__signature, sizeof(char) * ((size_t)len * 2));
	if (signature == NULL) {
		// the previous signature buffer is still owned by the session
		free(str);
		return SML_ERROR_FILE_ERROR;
	}
	session->__signature = signature;
	session->__signature_len = len * 2;

	CRYPTO_Encrypt(session->__signature, &session->__signature_len, str, len, session->key[MY_PRIVATE_KEY]);

	free(str);

	return SML_SUCCESS;
}

/*
 * Write the session's own public/private key pair to `keyfilename`,
 * protected by `passwd`.
 *
 * Returns SML_SUCCESS, SML_ERROR_CRYPTO_KEY_UNDEFINED when either key slot
 * holds no supported key, or SML_ERROR_FILE_ERROR when saving fails.
 */
int SML_SaveKeyPair(SML_SESSION* session, char keyfilename[], char passwd[])
{
	CRYPTO_KEY_PAIR key_pair;
	int status = SML_SUCCESS;

	// ensure key type is supported
	if (session->key[MY_PUBLIC_KEY].type == CRYPTO_KEY_UNDEFINED || session->key[MY_PUBLIC_KEY].type > CRYPTO_KEY_MAX_VALID ||
	    session->key[MY_PRIVATE_KEY].type == CRYPTO_KEY_UNDEFINED || session->key[MY_PRIVATE_KEY].type > CRYPTO_KEY_MAX_VALID)
		return SML_ERROR_CRYPTO_KEY_UNDEFINED;

	CRYPTO_InitKeyPair(&key_pair);

	// combine the keys to a key pair struct
	CRYPTO_CopyKey(&key_pair.public_key, session->key[MY_PUBLIC_KEY]);
	CRYPTO_CopyKey(&key_pair.private_key, session->key[MY_PRIVATE_KEY]);

	// save key pair to file
	if (CRYPTO_SaveKeyPair(key_pair, keyfilename, passwd) == FALSE)
		status = SML_ERROR_FILE_ERROR;

	// the temporary copy is released on success and on failure alike
	CRYPTO_DestroyKeyPair(&key_pair);

	return status;
}

/*
 * Generate a fresh key pair of `bit_length` bits with public exponent
 * `pub_exp`, store it in the session's MY_PUBLIC_KEY / MY_PRIVATE_KEY slots
 * and rebuild the session signature (the public key string encrypted with
 * the private key).
 *
 * Returns SML_SUCCESS, SML_ERROR_KEY_TYPE_UNSUPPORT for an unsupported
 * `key_type`, or SML_ERROR_NEW_KEY_ERROR when the key pair or its signature
 * cannot be created.
 */
int SML_NewKeyPair(SML_SESSION* session, unsigned int bit_length, unsigned int pub_exp, unsigned int key_type)
{
	CRYPTO_KEY_PAIR key_pair;
	char *str = NULL;
	char *signature;
	int len;

	// ensure key type is supported
	if (key_type == CRYPTO_KEY_UNDEFINED || key_type > CRYPTO_KEY_MAX_VALID)
		return SML_ERROR_KEY_TYPE_UNSUPPORT;

	// create key according to key type
	CRYPTO_InitKeyPair(&key_pair);

	if (CRYPTO_NewKeyPair(&key_pair, bit_length, pub_exp, key_type) == FALSE) {
		CRYPTO_DestroyKeyPair(&key_pair);
		return SML_ERROR_NEW_KEY_ERROR;
	}

	CRYPTO_CopyPublicKey(&session->key[MY_PUBLIC_KEY], key_pair);
	CRYPTO_CopyPrivateKey(&session->key[MY_PRIVATE_KEY], key_pair);

	CRYPTO_DestroyKeyPair(&key_pair);

	// create the eCert (private encrypted public key)
	len = CRYPTO_KeyToStringAlloc(&str, session->key[MY_PUBLIC_KEY]);
	if (str == NULL || len <= 0) {
		free(str);
		return SML_ERROR_NEW_KEY_ERROR;
	}

	// the output buffer is sized at twice the key string length
	signature = (char*)realloc(session->__signature, sizeof(char) * ((size_t)len * 2));
	if (signature == NULL) {
		// the previous signature buffer is still owned by the session
		free(str);
		return SML_ERROR_NEW_KEY_ERROR;
	}
	session->__signature = signature;
	session->__signature_len = len * 2;

	CRYPTO_Encrypt(session->__signature, &session->__signature_len, str, len, session->key[MY_PRIVATE_KEY]);

	// the serialised key is only needed as encryption input
	free(str);

	return SML_SUCCESS;
}

/*
 * Initialise an empty signature: an initialised (undefined) public key and
 * no signed-public-key buffer.  The length is zeroed together with the
 * pointer so the pair is consistent before anything is stored in it.
 */
void SML_InitSignature(SML_SIGNATURE* signature)
{
	CRYPTO_InitKey(&signature->public_key);
	signature->signed_public_key = NULL;
	signature->signed_public_key_len = 0;
}

/*
 * Release the resources held by a signature: its public key and its
 * signed-public-key buffer.  The buffer pointer is left NULL afterwards.
 */
void SML_DestroySignature(SML_SIGNATURE* signature)
{
	CRYPTO_DestroyKey(&signature->public_key);

	/* free(NULL) does nothing, so the buffer can be released unconditionally */
	free(signature->signed_public_key);
	signature->signed_public_key = NULL;
}

/*
 * Copy the connected client's identity out of `session` into `signature`:
 * the client's public key (PUBLIC_KEY slot) and a private copy of the signed
 * public key received from the client.
 *
 * `signature` must have been initialised (its buffer pointer is handed to
 * realloc()).  When the session holds no client signature, or the copy
 * cannot be allocated, `signature` ends up with a NULL buffer and length 0.
 */
void SML_GetClientSignature(SML_SESSION* session, SML_SIGNATURE* signature)
{
	char *copy;

	// copy public key
	CRYPTO_CopyKey(&signature->public_key, session->key[PUBLIC_KEY]);

	// copy signed public key (public key encrypted by private key)
	signature->signed_public_key_len = session->__client_signature_len;

	fprintf(stderr, "sess key len: %d\n", session->__client_signature_len);
	fprintf(stderr, "sign key len: %d\n", signature->signed_public_key_len);

	// nothing to copy (e.g. an anonymous client): avoid realloc(p, 0) and
	// a memcpy() from a NULL source
	if (session->__client_signature == NULL || session->__client_signature_len <= 0) {
		free(signature->signed_public_key);
		signature->signed_public_key = NULL;
		signature->signed_public_key_len = 0;
		return;
	}

	copy = (char*)realloc(signature->signed_public_key, sizeof(char) * session->__client_signature_len);
	if (copy == NULL) {
		// realloc() left the old buffer untouched; drop it so pointer and
		// length stay consistent
		free(signature->signed_public_key);
		signature->signed_public_key = NULL;
		signature->signed_public_key_len = 0;
		return;
	}

	memcpy(copy, session->__client_signature, session->__client_signature_len);
	signature->signed_public_key = copy;
}

/*
 * Install `signature` as the identity this session presents when it
 * connects: its public key goes into the MY_PUBLIC_KEY slot and a private
 * copy of its signed public key becomes the session signature.  The session
 * is switched out of anonymous mode.
 *
 * When `signature` carries no signed key, or the copy cannot be allocated,
 * the session is left with a NULL signature buffer and length 0.
 */
void SML_SetClientSignature(SML_SESSION* session, SML_SIGNATURE signature)
{
	char *copy = NULL;

	// copy public key
	CRYPTO_CopyKey(&session->key[MY_PUBLIC_KEY], signature.public_key);

	// copy signed public key; skip realloc(p, 0) and memcpy() from NULL
	if (signature.signed_public_key != NULL && signature.signed_public_key_len > 0)
		copy = (char*)realloc(session->__signature, sizeof(char) * signature.signed_public_key_len);

	if (copy != NULL) {
		memcpy(copy, signature.signed_public_key, signature.signed_public_key_len);
		session->__signature = copy;
		session->__signature_len = signature.signed_public_key_len;
	}
	else {
		// either there was nothing to copy or realloc() failed and left
		// the old buffer in place; release it in both cases
		free(session->__signature);
		session->__signature = NULL;
		session->__signature_len = 0;
	}

	// use client's signature when connecting to server
	session->anonymous_mode = 0;
}

/*
 * Connect `session` to the SML server at `host`:`port` and run the client
 * side of the handshake:
 *   1. receive the server's public key (length-prefixed) into PUBLIC_KEY;
 *   2. send either the anonymous marker (length 9999) when the session has
 *      no own public key, or the own public key followed by the session
 *      signature;
 *   3. read the server's verdict.
 *
 * Returns the verdict sent by the server, or FALSE on any local failure.
 * Once the TCP connection is up, any socket previously held by the session
 * is closed and replaced.  If the handshake then fails, the new socket is
 * closed and session->socket is reset to INVALID_SOCKET.
 */
int SML_Connect(SML_SESSION* session, char host[], int port)
{
	unsigned int len, result;
	int got, key_len;
	char *str = NULL;
	struct hostent *ht;
	struct sockaddr_in servAddr;
	int sock;

	/////////////////////////////////////////////////////////////
	// initialize tcp socket
	sock = socket(PF_INET, SOCK_STREAM, 0);
	if (sock == INVALID_SOCKET) {
		goto sml_connect_error;
	}

	// fill up destination info in servAddr
	memset(&servAddr, 0, sizeof(servAddr));
	servAddr.sin_family = AF_INET;
	servAddr.sin_port = htons(port);
	ht = gethostbyname(host);
	if (ht == NULL || ht->h_addrtype != AF_INET ||
	    ht->h_addr_list[0] == NULL ||
	    ht->h_length != (int)sizeof(servAddr.sin_addr)) {
		goto sml_connect_error;
	}

	memcpy(&servAddr.sin_addr, ht->h_addr_list[0], sizeof(servAddr.sin_addr));

	// associate the opened socket with the destination's address
	if (connect(sock, (struct sockaddr*)&servAddr, sizeof(servAddr)) < 0) {
		goto sml_connect_error;
	}

	if (session->socket != INVALID_SOCKET)
		close(session->socket);
	session->socket = sock;

	/////////////////////////////////////////////////////////////
	// receive server public key
	got = __TCP_ReadFully(session->socket, (char*)&len, sizeof(unsigned int));
	if (got != (int)sizeof(unsigned int))
		goto sml_connect_error;

	// the length comes from the peer: reject zero and anything that does
	// not fit the int length __TCP_ReadFully() takes
	key_len = (int)len;
	if (key_len <= 0)
		goto sml_connect_error;

	// one extra byte so the key text is always NUL-terminated for parsing
	if ((str = (char*)malloc(sizeof(char) * ((size_t)key_len + 1))) == NULL)
		goto sml_connect_error;

	got = __TCP_ReadFully(session->socket, str, key_len);
	if (got != key_len)
		goto sml_connect_error;
	str[key_len] = '\0';

	if (CRYPTO_ParseKeyFromString(&session->key[PUBLIC_KEY], str) == FALSE)
		goto sml_connect_error;

	// the received text is no longer needed
	free(str);
	str = NULL;

	/////////////////////////////////////////////////////////////
	// send out my public key
	if (session->key[MY_PUBLIC_KEY].type == CRYPTO_KEY_UNDEFINED) { // anonymous mode

		fprintf(stderr, "send anonymous mode request\n");

		len = 9999;
		if (send(session->socket, &len, sizeof(unsigned int), 0) == SOCKET_ERROR)
			goto sml_connect_error;
	}
	else {
		len = CRYPTO_KeyToStringAlloc(&str, session->key[MY_PUBLIC_KEY]);
		if (str == NULL)
			goto sml_connect_error;

		if (send(session->socket, &len, sizeof(unsigned int), 0) == SOCKET_ERROR)
			goto sml_connect_error;

		if (send(session->socket, str, len, 0) == SOCKET_ERROR)
			goto sml_connect_error;

		//////////////////////////////////////////////////////////////
		// send my private-encrypted public key
		if (send (session->socket, &session->__signature_len, sizeof(int), 0) == SOCKET_ERROR)
			goto sml_connect_error;
		if (send (session->socket, session->__signature, session->__signature_len, 0) == SOCKET_ERROR)
			goto sml_connect_error;
	}

	//////////////////////////////////////////////////////////////
	// read server respond
	got = __TCP_ReadFully(session->socket, (char*)&result, sizeof(unsigned int));
	if (got != (int)sizeof(unsigned int))
		goto sml_connect_error;

	/////////////////////////////////////////////////////////////
	// release temporary resource
	free(str);

	return result;

sml_connect_error:

	free(str);

	if (sock != INVALID_SOCKET) {
		close(sock);
		// do not leave the session pointing at a closed descriptor
		if (session->socket == sock)
			session->socket = INVALID_SOCKET;
	}

	return FALSE;
}


/*
 * Wait for one client on `port` and run the server side of the handshake.
 *
 * `session` is the listening session: it owns the listening socket (created
 * or re-created whenever `port` differs from the last bound port) and
 * supplies the server's own key pair and signature.
 *
 * Handshake: send the server public key, then read the client's key length.
 * A length of 9999 is an anonymous request and is accepted as is.
 * Otherwise the client's public key and signed public key are read and
 * checked with CRYPTO_CheckSignedPublicKey().  When `list_filename` is not
 * NULL the key is also looked up in that list, and a client not found there
 * is flagged anonymous.  The integrity verdict is sent back to the client.
 *
 * Returns a newly malloc()ed session for the accepted connection with
 * default RPS settings applied; the caller owns it.  Returns NULL on any
 * failure or when the client fails the integrity check, and the client
 * socket is closed in that case.
 */
SML_SESSION* SML_Accept(SML_SESSION* session, int port, char list_filename[])
{
	char user_name[1024];
	unsigned int len, integrity, sig_len_raw;
	int got, key_len, sig_len;
	char *str = NULL;
	char *sig;
	struct sockaddr_in servAddr, cliAddr;
	int cliSd = INVALID_SOCKET;
	socklen_t cliAddrLen;	// accept() takes a socklen_t *, not a size_t *
	int one = 1;
	SML_SESSION *new_session = NULL;

	//////////////////////////////////////////////////////////////
	// bind again if port changed
	if (port != session->last_bind_port) {
		if (session->servSd != INVALID_SOCKET)
			close(session->servSd);

		// no port is bound until bind() and listen() both succeed
		session->last_bind_port = -1;

		session->servSd = socket(PF_INET, SOCK_STREAM, 0);
		if (session->servSd == INVALID_SOCKET) {
			perror("Can't create socket\n");
			goto sml_accept_error;
		}

		// bind local server port
		setsockopt(session->servSd, SOL_SOCKET, SO_REUSEADDR, (char*)&one, sizeof(one));
		memset(&servAddr, 0, sizeof(servAddr));
		servAddr.sin_family = AF_INET;
		servAddr.sin_port = htons(port);
		servAddr.sin_addr.s_addr = htonl(INADDR_ANY);
		if (bind(session->servSd, (struct sockaddr *)&servAddr, sizeof(servAddr)) < 0) {
			perror("Can't bind\n");
			close(session->servSd);
			session->servSd = INVALID_SOCKET;
			goto sml_accept_error;
		}

		if (listen(session->servSd, 10) < 0) {
			perror("Can't listen\n");
			close(session->servSd);
			session->servSd = INVALID_SOCKET;
			goto sml_accept_error;
		}

		session->last_bind_port = port;
	}

	// wait for client
	fprintf(stderr, "Listen on port %d......\n", port);
	cliAddrLen = sizeof(cliAddr);
	cliSd = accept(session->servSd, (struct sockaddr *)&cliAddr, &cliAddrLen);
	if (cliSd == INVALID_SOCKET)
		goto sml_accept_error;
	fprintf(stderr, "Client connected......\n");

	// the anonymous flag describes the client being accepted now; do not
	// inherit it from a previous connection
	session->anonymous_mode = 0;

	//////////////////////////////////////////////////////////////
	// send my public key
	len = CRYPTO_KeyToStringAlloc(&str, session->key[MY_PUBLIC_KEY]);
	if (str == NULL)
		goto sml_accept_error;
	if (send (cliSd, &len, sizeof(int), 0) == SOCKET_ERROR)
		goto sml_accept_error;
	if (send (cliSd, str, len, 0) == SOCKET_ERROR)
		goto sml_accept_error;

	// done with the serialised server key; str is reused below
	free(str);
	str = NULL;

	//////////////////////////////////////////////////////////////
	// receive client's public key
	got = __TCP_ReadFully(cliSd, (char*)&len, sizeof(unsigned int));
	if (got != (int)sizeof(unsigned int))
		goto sml_accept_error;

	if (len == 9999) {  // anonymous user mode
		fprintf(stderr, "accept anonymous request!\n");
		integrity = 1;
		session->anonymous_mode = 1;

		// drop any signature left over from an earlier client
		free(session->__client_signature);
		session->__client_signature_len = 0;
		session->__client_signature = NULL;

		if (send (cliSd, &integrity, sizeof(int), 0) == SOCKET_ERROR)
			goto sml_accept_error;
	}
	else {
		// the length comes from the peer: reject zero and anything that
		// does not fit the int length __TCP_ReadFully() takes
		key_len = (int)len;
		if (key_len <= 0)
			goto sml_accept_error;

		// one extra byte so the key text is NUL-terminated before it is
		// printed and parsed
		if ((str = (char*)malloc(sizeof(char) * ((size_t)key_len + 1))) == NULL)
			goto sml_accept_error;

		got = __TCP_ReadFully(cliSd, str, key_len);
		if (got != key_len)
			goto sml_accept_error;
		str[key_len] = '\0';

		fprintf(stderr, "client's public key: \n%s\n\n", str);

		if (CRYPTO_ParseKeyFromString(&session->key[PUBLIC_KEY], str) == FALSE)
			goto sml_accept_error;

		/////////////////////////////////////////////////////////////
		// receive client's encrypted public key and check integrity
		got = __TCP_ReadFully(cliSd, (char*)&sig_len_raw, sizeof(unsigned int));
		if (got != (int)sizeof(unsigned int))
			goto sml_accept_error;

		sig_len = (int)sig_len_raw;
		if (sig_len <= 0)
			goto sml_accept_error;

		sig = (char*)realloc(session->__client_signature, sizeof(char) * (size_t)sig_len);
		if (sig == NULL)
			goto sml_accept_error;
		session->__client_signature = sig;
		session->__client_signature_len = sig_len;

		got = __TCP_ReadFully(cliSd, session->__client_signature, sig_len);
		if (got != sig_len)
			goto sml_accept_error;

		//////////////////////////////////////////////////////////////
		// perform client identity authentication
		integrity = CRYPTO_CheckSignedPublicKey(session->__client_signature, session->__client_signature_len, session->key[PUBLIC_KEY]);
		if (integrity == FALSE) {
			fprintf(stderr, "integrity checking failed!\n");
			if (send (cliSd, &integrity, sizeof(int), 0) == SOCKET_ERROR)
				goto sml_accept_error;
		}
		else if (list_filename == NULL) {
			fprintf(stderr, "user unknown authorized.\n");
			if (send (cliSd, &integrity, sizeof(int), 0) == SOCKET_ERROR)
				goto sml_accept_error;
		}
		else {
			unsigned int valid_user = CRYPTO_QueryUserFromListFile(user_name, 1024, session->key[PUBLIC_KEY], list_filename);

			if (send (cliSd, &integrity, sizeof(int), 0) == SOCKET_ERROR)
				goto sml_accept_error;

			if (valid_user)
				fprintf(stderr, "user [%s] authorized.\n", user_name);
			else {
				fprintf(stderr, "anonymous user!\n");
				session->anonymous_mode = 1;
			}
		}
	}

	//////////////////////////////////////////////////////////////
	// release memory
	free(str);
	str = NULL;

	// a client that failed the integrity check gets no session
	if (!integrity)
		goto sml_accept_error;

	//////////////////////////////////////////////////////////////
	// all ok: build a new SML_SESSION for the new connection
	new_session = (SML_SESSION *)malloc(sizeof(SML_SESSION));
	if (new_session == NULL)
		goto sml_accept_error;
	SML_InitSession(new_session);

	// from here on the new session owns the client socket
	new_session->socket = cliSd;
	cliSd = INVALID_SOCKET;

	new_session->anonymous_mode = session->anonymous_mode;
	CRYPTO_CopyKey(&new_session->key[MY_PUBLIC_KEY], session->key[MY_PUBLIC_KEY]);
	CRYPTO_CopyKey(&new_session->key[MY_PRIVATE_KEY], session->key[MY_PRIVATE_KEY]);
	CRYPTO_CopyKey(&new_session->key[PUBLIC_KEY], session->key[PUBLIC_KEY]);

	if (session->__signature != NULL && session->__signature_len > 0) {
		new_session->__signature = (char*)malloc(sizeof(char) * session->__signature_len);
		if (new_session->__signature == NULL)
			goto sml_accept_session_error;
		memcpy(new_session->__signature, session->__signature, session->__signature_len);
		new_session->__signature_len = session->__signature_len;
	}

	// the client signature is carried over for non-anonymous clients only;
	// otherwise pointer and length both stay empty
	if (new_session->anonymous_mode == 0 &&
	    session->__client_signature != NULL && session->__client_signature_len > 0) {
		new_session->__client_signature = (char*)malloc(sizeof(char) * session->__client_signature_len);
		if (new_session->__client_signature == NULL)
			goto sml_accept_session_error;
		memcpy(new_session->__client_signature, session->__client_signature, session->__client_signature_len);
		new_session->__client_signature_len = session->__client_signature_len;
	}

	// default ECP and RPS setting
	SML_InitRps(
		new_session,
		SML_DEFAULT_ECP_I,
		SML_DEFAULT_ECP_P,
		SML_DEFAULT_ECP_B,
		SML_DEFAULT_ECP_PKT_SIZE,
		SML_RPS_MULTI_KEY_RSA,
		1024
	);

	return new_session;

sml_accept_session_error:

	// releases the keys and buffers and closes the client socket it owns
	SML_DestroySession(new_session);
	free(new_session);

sml_accept_error:

	if (cliSd != INVALID_SOCKET)
		close(cliSd);

	free(str);

	return NULL;
}



/*
 * Set up the RPS/ECP encryption state of `session`.
 *
 * Key material: generates an initial key pair of `bit_len` bits and a
 * re-encryption key pair derived from its private key.  The initial private
 * key becomes the RPS_KEY slot and the re-encryption private key is stored
 * as a string in __reenc_key.  Both public keys are encrypted with the
 * peer's public key (PUBLIC_KEY slot) into __dec_key and __dec_key2.
 *
 * Packet layout: stores I, P, B and pkt_size, then derives how a plaintext
 * packet of `pkt_size` bytes is split into B blocks, each with an
 * unencrypted part followed by whole encryption units, such that roughly
 * P percent of the packet is encrypted.  Finally (re)allocates the cipher
 * and plaintext packet buffers.
 *
 * `algorithm` currently selects multi-key RSA for every value.  B is
 * clamped to the range [1, number of encryption units] so the per-block
 * divisions are always defined.
 */
void SML_InitRps(SML_SESSION* session, int I, int P, int B, int pkt_size, int algorithm, int bit_len)
{
	CRYPTO_KEY_PAIR pair, pair2;
	char *temp = NULL;
	int user_encrypt_size, actual_encrypt_size;
	int i, j;

	CRYPTO_InitKeyPair(&pair);
	CRYPTO_InitKeyPair(&pair2);

	// create the initial RPS encryption key and a re-encryption/decryption key pair
	switch (algorithm) {
	case SML_RPS_MULTI_KEY_RSA:
	default:
		MKRSA_NewKeyPair(&pair, bit_len, MKRSA_RANDOM_PUBLIC_EXPONENT);
		MKRSA_NewReEncryptionKeyPair(&pair2, pair.private_key);
		break;
	}

	CRYPTO_CopyKey(&session->key[RPS_KEY], pair.private_key);

	session->__reenc_key_len = CRYPTO_KeyToStringAlloc(&session->__reenc_key, pair2.private_key);

	// encrypt the two decryption keys with the client's public key; each
	// output buffer is sized at twice the plaintext key string
	session->__dec_key_len = CRYPTO_KeyToStringAlloc(&temp, pair.public_key) * 2;
	session->__dec_key = (char*)realloc(session->__dec_key, sizeof(char) * session->__dec_key_len);
	CRYPTO_Encrypt(
		session->__dec_key,
		&session->__dec_key_len,
		temp,
		session->__dec_key_len/2,
		session->key[PUBLIC_KEY]);

	// DEBUG
	fprintf(stderr, "[DEBUG] RPS Initial Decryption Key: \n%s\n\n", temp);
	// DEBUG - END

	session->__dec_key2_len = CRYPTO_KeyToStringAlloc(&temp, pair2.public_key) * 2;
	session->__dec_key2 = (char*)realloc(session->__dec_key2, sizeof(char) * session->__dec_key2_len);
	CRYPTO_Encrypt(
		session->__dec_key2,
		&session->__dec_key2_len,
		temp,
		session->__dec_key2_len/2,
		session->key[PUBLIC_KEY]);

	// DEBUG
	fprintf(stderr, "[DEBUG] RPS Reencryption's Decryption Key: \n%s\n\n", temp);
	// DEBUG - END

	CRYPTO_DestroyKeyPair(&pair);
	CRYPTO_DestroyKeyPair(&pair2);
	free(temp);

	session->I = I;
	session->B = B;
	session->P = P;
	session->pkt_size = pkt_size;

	/////////////////////////////////////////////////////////////////////////////
	// calculate encryption related parameters
	session->__plaintext_pkt_size = pkt_size;

	// unit size is the data size for each encryption operation
	session->__plaintext_unit_size = CRYPTO_GetPlaintextUnitSize(session->key[RPS_KEY]);
	session->__plaintext_unit_size_mp = session->__plaintext_unit_size / (mp_bits_per_limb / 8);
	session->__cipher_unit_size = CRYPTO_GetCipherUnitSize(session->key[RPS_KEY]);
	session->__cipher_unit_size_mp = session->__cipher_unit_size / (mp_bits_per_limb / 8);

	// every computation below divides by the plaintext unit size; without
	// a usable RPS key there is no valid layout to compute
	if (session->__plaintext_unit_size <= 0) {
		fprintf(stderr, "[SML_InitRps] invalid plaintext unit size, RPS layout not initialized\n");
		return;
	}

	// encrypt_size is the actual size inside a packet to be encrypted,
	// rounded down to whole units (but at least one unit)
	user_encrypt_size = pkt_size * session->P / 100;
	actual_encrypt_size = user_encrypt_size - user_encrypt_size % session->__plaintext_unit_size;
	if (actual_encrypt_size == 0)
		actual_encrypt_size = session->__plaintext_unit_size;

	// encrypt_unit is the no. of unit data to be encrypted
	session->__encrypt_unit_total = actual_encrypt_size / session->__plaintext_unit_size;
	if (session->B > session->__encrypt_unit_total)
		session->B = session->__encrypt_unit_total;
	// B is a divisor below; a zero or negative block count is not usable
	if (session->B < 1)
		session->B = 1;
	session->__encrypt_unit_per_block =
		(session->__encrypt_unit_total + session->B -1) / session->B;
	session->__encrypt_unit_last_block =
		session->__encrypt_unit_total -
		session->__encrypt_unit_per_block * (session->B - 1);

	// compute the size of the unencrypt part in each ECP block
	session->__unencrypt_size_per_block = (pkt_size / session->B) -
		session->__encrypt_unit_per_block * session->__plaintext_unit_size;

	// compute the cipher pkt size
	session->__cipher_pkt_size =
		session->__plaintext_pkt_size +
		session->__encrypt_unit_total * (session->__cipher_unit_size - session->__plaintext_unit_size) +
		4; // additional 4 bytes for ECP seq. no.

	// compute the non-encrypt portion on last block
	session->__unencrypt_size_last_block = pkt_size -
		session->__unencrypt_size_per_block * (session->B - 1) -
		session->__plaintext_unit_size * session->__encrypt_unit_per_block * (session->B-1) -
		session->__plaintext_unit_size * session->__encrypt_unit_last_block;

	// pre-malloc the buffer for the cipher packet
	session->__cipher_buffer =
		(char*)realloc(session->__cipher_buffer, sizeof(char) * session->__cipher_pkt_size);

	// pre-malloc the buffer for the plaintext packet
	session->__plaintext_buffer =
		(char*)realloc(session->__plaintext_buffer, sizeof(char) * session->__plaintext_pkt_size);

	// DEBUG
	fprintf(stderr, "[DEBUG] encryption related parameters: I (%d) P (%d) B (%d)\n\n", session->I, session->P, session->B);
	fprintf(stderr, "\tinput packet size: %d\n", session->__plaintext_pkt_size);
	fprintf(stderr, "\toutput packet size: %d\n", session->__cipher_pkt_size);
	fprintf(stderr, "\tencryption blocks: %d\n", session->B);
	fprintf(stderr, "\tencryption units total: %d\n", session->__encrypt_unit_total);
	fprintf(stderr, "\tencryption units per block: %d\n", session->__encrypt_unit_per_block);
	fprintf(stderr, "\tencryption units on last block: %d\n", session->__encrypt_unit_last_block);
	fprintf(stderr, "\tunencrypt size per block: %d\n", session->__unencrypt_size_per_block);
	fprintf(stderr, "\n");

	for (i=0; i<session->B-1; i++) {
		fprintf(stderr, "\tnon-encryption: %d\n", session->__unencrypt_size_per_block);
		fprintf(stderr, "\tencrypt: ");
		for (j=0; j<session->__encrypt_unit_per_block; j++) {
			fprintf(stderr, "%d ", session->__plaintext_unit_size);
		}
		fprintf(stderr, "\n");
	}
	fprintf(stderr, "\tnon-encryption: %d\n", session->__unencrypt_size_last_block);
	fprintf(stderr, "\tencrypt: ");
	for (j=0; j<session->__encrypt_unit_last_block; j++) {
		fprintf(stderr, "%d ", session->__plaintext_unit_size);
	}
	fprintf(stderr, "\n\n");
	// DEBUG - END
}

/*
 * Reset an RPS settings record to the library defaults: the default ECP
 * parameters and packet size, with no key material attached.
 */
void RPS_Init(RPS* rps)
{
	/* key buffers start out empty */
	rps->__reenc_key = NULL;
	rps->__reenc_key_len = 0;

	rps->__dec_key = NULL;
	rps->__dec_key_len = 0;

	rps->__dec_key2 = NULL;
	rps->__dec_key2_len = 0;

	/* default ECP parameters */
	rps->I = SML_DEFAULT_ECP_I;
	rps->P = SML_DEFAULT_ECP_P;
	rps->B = SML_DEFAULT_ECP_B;
	rps->pkt_size = SML_DEFAULT_ECP_PKT_SIZE;
}

/*
 * Release the key buffers held by an RPS settings record and mark them
 * empty.  The ECP parameters (I, P, B, pkt_size) are left untouched.
 */
void RPS_Destroy(RPS* rps)
{
	/* free(NULL) does nothing, so each buffer is released unconditionally */
	free(rps->__reenc_key);
	rps->__reenc_key = NULL;
	rps->__reenc_key_len = 0;

	free(rps->__dec_key);
	rps->__dec_key = NULL;
	rps->__dec_key_len = 0;

	free(rps->__dec_key2);
	rps->__dec_key2 = NULL;
	rps->__dec_key2_len = 0;
}

/*
 * Copy the session's RPS settings into `rps`: the ECP parameters and private
 * copies of the re-encryption key and both encrypted decryption keys.
 *
 * `rps` must have been initialised (its buffer pointers are handed to
 * realloc()).  A key that the session does not hold, or whose copy cannot
 * be allocated, ends up as a NULL buffer with length 0 in `rps`.
 */
void SML_GetRpsSetting(SML_SESSION* session, RPS* rps)
{
	void *copy;

	rps->I        = session->I;
	rps->P        = session->P;
	rps->B        = session->B;
	rps->pkt_size = session->pkt_size;

	/* re-encryption key */
	copy = NULL;
	if (session->__reenc_key != NULL && session->__reenc_key_len > 0)
		copy = realloc(rps->__reenc_key, session->__reenc_key_len);
	if (copy != NULL) {
		memcpy(copy, session->__reenc_key, session->__reenc_key_len);
		rps->__reenc_key = copy;
		rps->__reenc_key_len = session->__reenc_key_len;
	}
	else {
		/* nothing to copy, or realloc() failed and left the old buffer */
		free(rps->__reenc_key);
		rps->__reenc_key = NULL;
		rps->__reenc_key_len = 0;
	}

	/* initial decryption key */
	copy = NULL;
	if (session->__dec_key != NULL && session->__dec_key_len > 0)
		copy = realloc(rps->__dec_key, session->__dec_key_len);
	if (copy != NULL) {
		memcpy(copy, session->__dec_key, session->__dec_key_len);
		rps->__dec_key = copy;
		rps->__dec_key_len = session->__dec_key_len;
	}
	else {
		free(rps->__dec_key);
		rps->__dec_key = NULL;
		rps->__dec_key_len = 0;
	}

	/* decryption key for re-encrypted packets */
	copy = NULL;
	if (session->__dec_key2 != NULL && session->__dec_key2_len > 0)
		copy = realloc(rps->__dec_key2, session->__dec_key2_len);
	if (copy != NULL) {
		memcpy(copy, session->__dec_key2, session->__dec_key2_len);
		rps->__dec_key2 = copy;
		rps->__dec_key2_len = session->__dec_key2_len;
	}
	else {
		free(rps->__dec_key2);
		rps->__dec_key2 = NULL;
		rps->__dec_key2_len = 0;
	}
}

/*
 * Apply the RPS settings in `rps` to `session`: the ECP parameters, a copy
 * of the re-encryption key, and the SECOND decryption key of `rps` installed
 * as the session's FIRST decryption key (the client needs that key to
 * decrypt re-encrypted packets).  The session's own second decryption key
 * is discarded.
 *
 * A key that `rps` does not hold, or whose copy cannot be allocated, ends up
 * as a NULL buffer with length 0 in the session.
 */
void SML_SetRpsSetting(SML_SESSION* session, RPS rps)
{
	void *copy;

	session->I        = rps.I;
	session->P        = rps.P;
	session->B        = rps.B;
	session->pkt_size = rps.pkt_size;

	/* re-encryption key */
	copy = NULL;
	if (rps.__reenc_key != NULL && rps.__reenc_key_len > 0)
		copy = realloc(session->__reenc_key, rps.__reenc_key_len);
	if (copy != NULL) {
		memcpy(copy, rps.__reenc_key, rps.__reenc_key_len);
		session->__reenc_key = copy;
		session->__reenc_key_len = rps.__reenc_key_len;
	}
	else {
		/* nothing to copy, or realloc() failed and left the old buffer */
		free(session->__reenc_key);
		session->__reenc_key = NULL;
		session->__reenc_key_len = 0;
	}

	/* the second decryption key becomes the first decryption key */
	copy = NULL;
	if (rps.__dec_key2 != NULL && rps.__dec_key2_len > 0)
		copy = realloc(session->__dec_key, rps.__dec_key2_len);
	if (copy != NULL) {
		memcpy(copy, rps.__dec_key2, rps.__dec_key2_len);
		session->__dec_key = copy;
		session->__dec_key_len = rps.__dec_key2_len;
	}
	else {
		free(session->__dec_key);
		session->__dec_key = NULL;
		session->__dec_key_len = 0;
	}

	/* the session no longer has a second decryption key */
	session->__dec_key2_len = 0;
	free(session->__dec_key2);
	session->__dec_key2 = NULL;
}

/*
 * Derive a fresh re-encryption key pair from the session's RPS key.
 *
 * The new private key replaces __reenc_key (as a string).  The new public
 * key is encrypted with the peer's public key (PUBLIC_KEY slot) and replaces
 * __dec_key2.  If the public key cannot be serialised or the output buffer
 * cannot be allocated, __dec_key2 is left empty (NULL, length 0).
 */
void SML_NewRpsReEncryptionKey(SML_SESSION *session)
{
	CRYPTO_KEY_PAIR pair;
	char *temp = NULL;
	char *encrypted = NULL;
	int plain_len;

	CRYPTO_InitKeyPair(&pair);

	// create a new re-encryption/decryption key pair
	switch (session->key[RPS_KEY].type) {
	case CRYPTO_KEY_RSA:
	default:
		MKRSA_NewReEncryptionKeyPair(&pair, session->key[RPS_KEY]);
		break;
	}

	session->__reenc_key_len = CRYPTO_KeyToStringAlloc(&session->__reenc_key, pair.private_key);

	// encrypt the decryption key with the client's public key; the output
	// buffer is sized at twice the plaintext key string
	plain_len = CRYPTO_KeyToStringAlloc(&temp, pair.public_key);
	if (temp != NULL && plain_len > 0)
		encrypted = (char*)realloc(session->__dec_key2, sizeof(char) * ((size_t)plain_len * 2));

	if (encrypted != NULL) {
		session->__dec_key2 = encrypted;
		session->__dec_key2_len = plain_len * 2;
		CRYPTO_Encrypt(
			session->__dec_key2,
			&session->__dec_key2_len,
			temp,
			plain_len,
			session->key[PUBLIC_KEY]);

		// DEBUG
		fprintf(stderr, "[DEBUG] New RPS Decryption Key: \n%s\n\n", temp);
		// DEBUG - END
	}
	else {
		// no key text, or realloc() failed and left the old buffer in place
		free(session->__dec_key2);
		session->__dec_key2 = NULL;
		session->__dec_key2_len = 0;
	}

	// the plaintext key string was only needed as encryption input
	free(temp);

	CRYPTO_DestroyKeyPair(&pair);
}

/*
 * Write the session's RPS state to `filename` as a flat binary record:
 * the RPS key string (int length + bytes), the encrypted initial decryption
 * key (length + bytes), then the ECP parameters and the derived packet
 * layout values, each written as one unsigned-int-sized field.
 *
 * Nothing is reported to the caller; the function returns silently if the
 * file cannot be opened or the key cannot be serialised.  `passwd` is not
 * used by this function.
 */
void SML_SaveRpsSetting(SML_SESSION* session, char filename[], char passwd[])
{
	/* fixed-size fields, in the order they appear in the file */
	const void *fields[] = {
		&session->I,                            // I
		&session->P,                            // P
		&session->B,                            // B
		&session->pkt_size,                     // pkt_size
		&session->__cipher_unit_size_mp,        // cipher unit size in mp
		&session->__cipher_unit_size,           // cipher unit size in byte
		&session->__plaintext_unit_size_mp,     // plaintext unit size in mp
		&session->__plaintext_unit_size,        // plaintext unit size in byte
		&session->__encrypt_unit_total,         // total units to encrypt
		&session->__encrypt_unit_per_block,     // encrypt units per block
		&session->__encrypt_unit_last_block,    // encrypt units on last block
		&session->__plaintext_pkt_size,         // plaintext packet size
		&session->__cipher_pkt_size,            // cipher packet size
		&session->__unencrypt_size_last_block,  // non-encrypt size on last block
		&session->__unencrypt_size_per_block    // non-encrypt size on other blocks
	};
	char *key_text = NULL;
	int key_len;
	size_t k;
	FILE *out = fopen(filename, "wb");

	if (out == NULL)
		return; // error

	// initial encryption key
	key_len = CRYPTO_KeyToStringAlloc(&key_text, session->key[RPS_KEY]);
	if (key_text == NULL) { // error
		fclose(out);
		return;
	}
	fwrite(&key_len, sizeof(int), 1, out);
	fwrite(key_text, key_len, 1, out);
	free(key_text);

	// initial decryption key
	fwrite(&session->__dec_key_len, sizeof(int), 1, out);
	fwrite(session->__dec_key, session->__dec_key_len, 1, out);

	// encryption parameters
	for (k = 0; k < sizeof(fields) / sizeof(fields[0]); k++)
		fwrite(fields[k], sizeof(unsigned int), 1, out);

	fclose(out);
}

/*
 * Load RPS state from `filename` into `session`.  The layout read is: the
 * RPS key string (int length + bytes), the encrypted initial decryption key
 * (int length + bytes), then fifteen unsigned-int-sized fields holding the
 * ECP parameters and the derived packet layout values.
 *
 * After a complete read, a fresh re-encryption key pair is derived from the
 * loaded RPS key: its private key replaces __reenc_key and its public key,
 * encrypted with the peer's public key (PUBLIC_KEY slot), replaces
 * __dec_key2.
 *
 * Nothing is reported to the caller.  If the file cannot be opened, is
 * truncated, or holds an invalid length or key, the function returns
 * without deriving the re-encryption keys; fields read before the failure
 * keep the values read.  `passwd` is not used by this function.
 */
void SML_LoadRpsSetting(SML_SESSION* session, char filename[], char passwd[])
{
	/* fixed-size fields, in the order they appear in the file */
	void *fields[] = {
		&session->I,                            // I
		&session->P,                            // P
		&session->B,                            // B
		&session->pkt_size,                     // pkt_size
		&session->__cipher_unit_size_mp,        // cipher unit size in mp
		&session->__cipher_unit_size,           // cipher unit size in byte
		&session->__plaintext_unit_size_mp,     // plaintext unit size in mp
		&session->__plaintext_unit_size,        // plaintext unit size in byte
		&session->__encrypt_unit_total,         // total units to encrypt
		&session->__encrypt_unit_per_block,     // encrypt units per block
		&session->__encrypt_unit_last_block,    // encrypt units on last block
		&session->__plaintext_pkt_size,         // plaintext packet size
		&session->__cipher_pkt_size,            // cipher packet size
		&session->__unencrypt_size_last_block,  // non-encrypt size on last block
		&session->__unencrypt_size_per_block    // non-encrypt size on other blocks
	};
	char *str = NULL, *temp = NULL, *buf;
	int len = 0, dec_len = 0, plain_len;
	int loaded = 0;
	size_t k;
	FILE *fp;
	CRYPTO_KEY_PAIR pair;

	// open input file
	if ((fp = fopen(filename, "rb")) == NULL) {
		return; // error
	}

	// initial encryption key; one extra byte keeps the text NUL-terminated
	// for the parser whether or not the stored length includes a terminator
	if (fread(&len, sizeof(int), 1, fp) != 1 || len <= 0)
		goto sml_load_rps_done;
	str = (char*)malloc((size_t)len + 1);
	if (str == NULL)
		goto sml_load_rps_done;
	if (fread(str, len, 1, fp) != 1)
		goto sml_load_rps_done;
	str[len] = '\0';
	if (CRYPTO_ParseKeyFromString(&session->key[RPS_KEY], str) == FALSE)
		goto sml_load_rps_done;

	// initial decryption key (may legitimately be empty)
	if (fread(&dec_len, sizeof(int), 1, fp) != 1 || dec_len < 0)
		goto sml_load_rps_done;
	if (dec_len == 0) {
		free(session->__dec_key);
		session->__dec_key = NULL;
		session->__dec_key_len = 0;
	}
	else {
		buf = (char*)realloc(session->__dec_key, sizeof(char) * (size_t)dec_len);
		if (buf == NULL)
			goto sml_load_rps_done;
		session->__dec_key = buf;
		session->__dec_key_len = dec_len;
		if (fread(session->__dec_key, dec_len, 1, fp) != 1)
			goto sml_load_rps_done;
	}

	// encryption parameters
	for (k = 0; k < sizeof(fields) / sizeof(fields[0]); k++) {
		if (fread(fields[k], sizeof(unsigned int), 1, fp) != 1)
			goto sml_load_rps_done;
	}

	loaded = 1;

sml_load_rps_done:

	fclose(fp);
	free(str);

	if (!loaded)
		return;

	// create the re-encryption/decryption key pair
	CRYPTO_InitKeyPair(&pair);

	switch (session->key[RPS_KEY].type) {
	case CRYPTO_KEY_RSA:
	default:
		MKRSA_NewReEncryptionKeyPair(&pair, session->key[RPS_KEY]);
		break;
	}

	session->__reenc_key_len = CRYPTO_KeyToStringAlloc(&session->__reenc_key, pair.private_key);

	// encrypt the decryption key with the client's public key; the output
	// buffer is sized at twice the plaintext key string
	buf = NULL;
	plain_len = CRYPTO_KeyToStringAlloc(&temp, pair.public_key);
	if (temp != NULL && plain_len > 0)
		buf = (char*)realloc(session->__dec_key2, sizeof(char) * ((size_t)plain_len * 2));

	if (buf != NULL) {
		session->__dec_key2 = buf;
		session->__dec_key2_len = plain_len * 2;

		CRYPTO_Encrypt(
			session->__dec_key2,
			&session->__dec_key2_len,
			temp,
			plain_len,
			session->key[PUBLIC_KEY]);

		// DEBUG
		fprintf(stderr, "[DEBUG] RPS Reencryption's Decryption Key: \n%s\n\n", temp);
		// DEBUG - END
	}
	else {
		// no key text, or realloc() failed and left the old buffer in place
		free(session->__dec_key2);
		session->__dec_key2 = NULL;
		session->__dec_key2_len = 0;
	}

	free(temp);
	CRYPTO_DestroyKeyPair(&pair);
}

/*
 * Transmit the session's RPS settings over session->socket.
 *
 * Wire order:
 *   1. initial decryption key          (length word, then key bytes)
 *   2. re-encryption key               (length word, then key bytes)
 *   3. decryption key for reencryption (length word, then key bytes)
 *   4. fifteen encryption parameters, sizeof(unsigned int) bytes each
 *
 * Values are sent exactly as they sit in memory; the results of send()
 * are not inspected.
 */
void SML_SendRpsSetting(SML_SESSION* session)
{
	const int sock = session->socket;

	// encryption parameters, listed in the order they go on the wire
	const void *params[] = {
		&session->I,                            // I
		&session->P,                            // P
		&session->B,                            // B
		&session->pkt_size,                     // pkt_size
		&session->__cipher_unit_size_mp,        // cipher unit size in mp
		&session->__cipher_unit_size,           // cipher unit size in byte
		&session->__plaintext_unit_size_mp,     // plaintext unit size in mp
		&session->__plaintext_unit_size,        // plaintext unit size in byte
		&session->__encrypt_unit_total,         // total units to encrypt
		&session->__encrypt_unit_per_block,     // encrypt units per block
		&session->__encrypt_unit_last_block,    // encrypt units on last block
		&session->__plaintext_pkt_size,         // plaintext packet size
		&session->__cipher_pkt_size,            // cipher packet size
		&session->__unencrypt_size_last_block,  // non-encrypt size on last block
		&session->__unencrypt_size_per_block    // non-encrypt size on other blocks
	};
	const size_t param_count = sizeof(params) / sizeof(params[0]);
	size_t idx;

	// 1. initial decryption key
	send(sock, &session->__dec_key_len, sizeof(int), 0);
	send(sock, session->__dec_key, session->__dec_key_len, 0);
	fprintf(stderr, "send dec key len: %d\n", session->__dec_key_len);

	// 2. re-encryption key
	send(sock, &session->__reenc_key_len, sizeof(int), 0);
	send(sock, session->__reenc_key, session->__reenc_key_len, 0);
	fprintf(stderr, "send reenc key len: %d\n", session->__reenc_key_len);

	// 3. decryption key for re-encryption
	send(sock, &session->__dec_key2_len, sizeof(int), 0);
	send(sock, session->__dec_key2, session->__dec_key2_len, 0);
	fprintf(stderr, "send dec key2 len: %d\n", session->__dec_key2_len);

	// 4. encryption parameters, one word at a time
	for (idx = 0; idx < param_count; idx++) {
		send(sock, params[idx], sizeof(unsigned int), 0);
	}
}

/*
 * Receive the RPS settings from session->socket, in the order that
 * SML_SendRpsSetting() writes them: three length-prefixed key blobs
 * (initial decryption key, re-encryption key, decryption key for
 * re-encryption) followed by fifteen sizeof(unsigned int)-byte encryption
 * parameters.  Afterwards the cipher and plaintext packet buffers are
 * (re)allocated to the received packet sizes and the settings are dumped
 * to stderr.
 *
 * Every read and allocation is checked.  On the first failure a diagnostic
 * is printed to stderr and the function returns early; fields received
 * before the failure keep their new values, and a buffer whose realloc
 * failed keeps its previous allocation.
 */
void SML_ReceiveRpsSetting(SML_SESSION* session)
{
	const int sock = session->socket;
	const int word = (int)sizeof(unsigned int);

	// encryption parameters, listed in the order they arrive on the wire
	void *params[] = {
		&session->I,                            // I
		&session->P,                            // P
		&session->B,                            // B
		&session->pkt_size,                     // pkt_size
		&session->__cipher_unit_size_mp,        // cipher unit size in mp
		&session->__cipher_unit_size,           // cipher unit size in byte
		&session->__plaintext_unit_size_mp,     // plaintext unit size in mp
		&session->__plaintext_unit_size,        // plaintext unit size in byte
		&session->__encrypt_unit_total,         // total units to encrypt
		&session->__encrypt_unit_per_block,     // encrypt units per block
		&session->__encrypt_unit_last_block,    // encrypt units on last block
		&session->__plaintext_pkt_size,         // plaintext packet size
		&session->__cipher_pkt_size,            // cipher packet size
		&session->__unencrypt_size_last_block,  // non-encrypt size on last block
		&session->__unencrypt_size_per_block    // non-encrypt size on other blocks
	};
	const size_t param_count = sizeof(params) / sizeof(params[0]);
	void *grown;   // realloc result, kept separate so the old pointer survives a failure
	size_t k;
	int i, j;

	// receive the decryption key for initial decryption
	if (__TCP_ReadFully(sock, (char*)&session->__dec_key_len, word) != word)
		goto fail;
	grown = realloc(session->__dec_key, sizeof(char) * session->__dec_key_len);
	if (grown == NULL && session->__dec_key_len != 0)
		goto fail;
	session->__dec_key = (char*)grown;
	if (__TCP_ReadFully(sock, session->__dec_key, session->__dec_key_len) != session->__dec_key_len)
		goto fail;

	fprintf(stderr, "Receive Rps Setting (dec key len): %d\n", session->__dec_key_len);

	// receive the reencryption key
	if (__TCP_ReadFully(sock, (char*)&session->__reenc_key_len, word) != word)
		goto fail;
	grown = realloc(session->__reenc_key, sizeof(char) * session->__reenc_key_len);
	if (grown == NULL && session->__reenc_key_len != 0)
		goto fail;
	session->__reenc_key = (char*)grown;
	if (__TCP_ReadFully(sock, session->__reenc_key, session->__reenc_key_len) != session->__reenc_key_len)
		goto fail;

	// receive the decryption key for reencryption
	if (__TCP_ReadFully(sock, (char*)&session->__dec_key2_len, word) != word)
		goto fail;
	grown = realloc(session->__dec_key2, sizeof(char) * session->__dec_key2_len);
	if (grown == NULL && session->__dec_key2_len != 0)
		goto fail;
	session->__dec_key2 = (char*)grown;
	if (__TCP_ReadFully(sock, session->__dec_key2, session->__dec_key2_len) != session->__dec_key2_len)
		goto fail;

	// receive the encryption parameters, one word each
	for (k = 0; k < param_count; k++) {
		if (__TCP_ReadFully(sock, (char*)params[k], word) != word)
			goto fail;
	}

	// pre-malloc the buffer for the cipher
	grown = realloc(session->__cipher_buffer, sizeof(char) * session->__cipher_pkt_size);
	if (grown == NULL && session->__cipher_pkt_size != 0)
		goto fail;
	session->__cipher_buffer = (char*)grown;

	// pre-malloc the buffer for the plaintext packet
	grown = realloc(session->__plaintext_buffer, sizeof(char) * session->__plaintext_pkt_size);
	if (grown == NULL && session->__plaintext_pkt_size != 0)
		goto fail;
	session->__plaintext_buffer = (char*)grown;

	// DEBUG
	fprintf(stderr, "[DEBUG] encryption related parameters: I (%d) P (%d) B (%d)\n\n", session->I, session->P, session->B);
	fprintf(stderr, "\tinput packet size: %d\n", session->__plaintext_pkt_size);
	fprintf(stderr, "\toutput packet size: %d\n", session->__cipher_pkt_size);
	fprintf(stderr, "\tencryption blocks: %d\n", session->B);
	fprintf(stderr, "\tencryption units total: %d\n", session->__encrypt_unit_total);
	fprintf(stderr, "\tencryption units per block: %d\n", session->__encrypt_unit_per_block);
	fprintf(stderr, "\tencryption units on last block: %d\n", session->__encrypt_unit_last_block);
	fprintf(stderr, "\tunencrypt size per block: %d\n", session->__unencrypt_size_per_block);
	fprintf(stderr, "\n");

	// layout of the first B-1 blocks
	for (i=0; i<session->B-1; i++) {
		fprintf(stderr, "\tnon-encryption: %d\n", session->__unencrypt_size_per_block);
		fprintf(stderr, "\tencrypt: ");
		for (j=0; j<session->__encrypt_unit_per_block; j++) {
			fprintf(stderr, "%d ", session->__plaintext_unit_size);
		}
		fprintf(stderr, "\n");
	}

	// layout of the last block
	fprintf(stderr, "\tnon-encryption: %d\n", session->__unencrypt_size_last_block);
	fprintf(stderr, "\tencrypt: ");
	for (j=0; j<session->__encrypt_unit_last_block; j++) {
		fprintf(stderr, "%d ", session->__plaintext_unit_size);
	}
	fprintf(stderr, "\n\n");
	// DEBUG - END

	return;

fail:
	fprintf(stderr, "[ERROR] SML_ReceiveRpsSetting: failed to receive RPS settings\n");
}


/*
 * Report the session's cipher packet size (__cipher_pkt_size).  This is
 * the number of bytes the RPS send routines put on the wire per packet,
 * and it includes the 4-byte sequence number at the front of the packet.
 */
int SML_GetRpsEncryptedPacketSize(SML_SESSION* session)
{
	int wire_size;

	wire_size = session->__cipher_pkt_size;
	return wire_size;
}



/*
 * Build one RPS packet from packet_data and send it on session->socket.
 *
 * The outgoing packet is always __cipher_pkt_size bytes: a 4-byte copy of
 * session->seq followed by the payload, zero padded.  When session->seq is
 * a multiple of session->I the payload is selectively encrypted: it is
 * treated as session->B blocks, each consisting of an unencrypted prefix
 * that is copied verbatim followed by a run of units that are each raised
 * to the RPS_KEY exponent modulo the RPS_KEY modulus.  Otherwise the
 * payload is copied as is.  session->seq is incremented in both cases.
 *
 * Preconditions (not checked here): session->I is non-zero, the session's
 * cipher and plaintext buffers are allocated, len fits in the plaintext
 * buffer (encrypted path) or in the cipher buffer after the 4-byte header
 * (plain path).
 *
 * Returns the result of send().
 */
int SML_TcpSendEncryptRps(SML_SESSION* session, char packet_data[], unsigned int len)
{
	int block, sub_block;
	char *cipher = session->__cipher_buffer, *data = session->__plaintext_buffer;
	mpz_t mpz_data;

	////////////////////////////////////////////////////////////
	// packet need encryption
	if (session->seq % session->I == 0) {

		// init buffers
		memset(data, 0, session->__plaintext_pkt_size); // padding packet with zeros
		memcpy(data, packet_data, len);
		memset(cipher, 0, session->__cipher_pkt_size);
		mpz_init2(mpz_data, 4096);

		// embed ECP seq. no. in first 4 bytes of the packet
		memcpy(cipher, (char*)&session->seq, 4);
		cipher = cipher + 4;

		// encrypt each ECP block (excluding the last block)
		for (block=0; block<session->B - 1; block++) {
			// copy non-encrypt portion
			memcpy(cipher, data, session->__unencrypt_size_per_block);
			cipher += session->__unencrypt_size_per_block;
			data += session->__unencrypt_size_per_block;

			// encrypt the encrypt portion
			for (sub_block=0; sub_block<session->__encrypt_unit_per_block; sub_block++) {
				// load the plaintext unit straight into the mpz limb array
				mpz_data->_mp_size = session->__plaintext_unit_size_mp;
				memcpy(mpz_data->_mp_d, data, session->__plaintext_unit_size);

				mpz_powm(mpz_data, mpz_data,
						session->key[RPS_KEY].param[RSA_PARAM_EXPONENT],
						session->key[RPS_KEY].param[RSA_PARAM_MODULUS]
						);

				// only the limbs in use are copied; the rest of the unit
				// stays zero from the memset above
				memcpy(cipher, mpz_data->_mp_d, mpz_data->_mp_size*(mp_bits_per_limb / 8));

				cipher += session->__cipher_unit_size;
				data += session->__plaintext_unit_size;
			}
		}

		// encrypt last ECP block; its unencrypted prefix has its own size,
		// so the copy length must match the amount the cursors advance by
		memcpy(cipher, data, session->__unencrypt_size_last_block);
		cipher += session->__unencrypt_size_last_block;
		data += session->__unencrypt_size_last_block;

		for (sub_block=0; sub_block<session->__encrypt_unit_last_block; sub_block++) {
			mpz_data->_mp_size = session->__plaintext_unit_size_mp;
			memcpy(mpz_data->_mp_d, data, session->__plaintext_unit_size);

			mpz_powm(mpz_data, mpz_data,
						session->key[RPS_KEY].param[RSA_PARAM_EXPONENT],
						session->key[RPS_KEY].param[RSA_PARAM_MODULUS]
						);

			memcpy(cipher, mpz_data->_mp_d, mpz_data->_mp_size*(mp_bits_per_limb / 8));

			cipher += session->__cipher_unit_size;
			data += session->__plaintext_unit_size;
		}

		mpz_clear(mpz_data);
	}
	/////////////////////////////////////////////////////////////////
	// packet without encryption
	else {

		// init buffers
		memset(cipher, 0, session->__cipher_pkt_size);

		// embed ECP seq. no. in first 4 bytes of the packet
		memcpy(cipher, (char*)&session->seq, 4);
		cipher = cipher + 4;

		// copy packet data
		memcpy(cipher, packet_data, len);
	}

	// send the packet through TCP socket
	session->seq++;

	return send(session->socket, (char*)session->__cipher_buffer, session->__cipher_pkt_size, 0);
}

/*
 * Re-encrypt an already formatted RPS packet and forward it on
 * session->socket.
 *
 * packet_data is expected to start with the 4-byte sequence number (so len
 * must be at least 4).  The packet is copied into the session's cipher
 * buffer; if the sequence number is a multiple of session->I, every
 * encrypted unit in the packet is raised in place to the RPS_KEY exponent
 * modulo the RPS_KEY modulus, while the unencrypted portions are left as
 * they are.  Packets with other sequence numbers are forwarded unchanged.
 * The outgoing packet is always __cipher_pkt_size bytes, zero padded.
 *
 * Preconditions (not checked here): session->I is non-zero, the cipher
 * buffer is allocated and len does not exceed __cipher_pkt_size.
 *
 * Returns the result of send().
 */
int SML_TcpSendReEncryptRps(SML_SESSION* session, char packet_data[], unsigned int len)
{
	int block, sub_block;
	char *cipher = session->__cipher_buffer;
	mpz_t mpz_data;
	int seq;

	// extract the ECP seq. no.
	memcpy((char*)&seq, packet_data, 4);

	////////////////////////////////////////////////////////////
	// packet need encryption
	if (seq % session->I == 0) {

		// init buffers and embed ECP seq. no. in first 4 bytes of the packet
		memset(cipher, 0, session->__cipher_pkt_size);
		memcpy(cipher, packet_data, len);
		cipher = cipher + 4; // skip 4 bytes ECP seq. no.
		mpz_init2(mpz_data, 4096);

		// NOTE(review): the units processed below are already ciphertext and
		// the cursor advances by __cipher_unit_size, yet only
		// __plaintext_unit_size bytes / __plaintext_unit_size_mp limbs are
		// loaded from each unit.  Confirm this is intended and not a
		// truncation of the input.

		// encrypt each ECP block (excluding the last block)
		for (block=0; block<session->B - 1; block++) {
			// skip non-encrypt portion
			cipher += session->__unencrypt_size_per_block;

			// encrypt the encrypt portion
			for (sub_block=0; sub_block<session->__encrypt_unit_per_block; sub_block++) {
				mpz_data->_mp_size = session->__plaintext_unit_size_mp;
				memcpy(mpz_data->_mp_d, cipher, session->__plaintext_unit_size);

				mpz_powm(mpz_data, mpz_data,
						session->key[RPS_KEY].param[RSA_PARAM_EXPONENT],
						session->key[RPS_KEY].param[RSA_PARAM_MODULUS]
						);

				memcpy(cipher, mpz_data->_mp_d, mpz_data->_mp_size*(mp_bits_per_limb / 8));

				cipher += session->__cipher_unit_size;
			}
		}

		// encrypt last ECP block
		cipher += session->__unencrypt_size_last_block;

		for (sub_block=0; sub_block<session->__encrypt_unit_last_block; sub_block++) {
			mpz_data->_mp_size = session->__plaintext_unit_size_mp;
			memcpy(mpz_data->_mp_d, cipher, session->__plaintext_unit_size);

			mpz_powm(mpz_data, mpz_data,
						session->key[RPS_KEY].param[RSA_PARAM_EXPONENT],
						session->key[RPS_KEY].param[RSA_PARAM_MODULUS]
						);

			memcpy(cipher, mpz_data->_mp_d, mpz_data->_mp_size*(mp_bits_per_limb / 8));

			// step to the next unit, as the per-block loop above does;
			// without this every iteration reprocesses the same unit
			cipher += session->__cipher_unit_size;
		}

		mpz_clear(mpz_data);
	}
	/////////////////////////////////////////////////////////////////
	// packet without encryption
	else {

		// init buffers
		memset(cipher, 0, session->__cipher_pkt_size);

		// copy packet data
		memcpy(cipher, packet_data, len);
	}

	return send(session->socket, (char*)session->__cipher_buffer, session->__cipher_pkt_size, 0);
}

/*
 * Receive one RPS packet from session->socket and deliver its payload,
 * decrypted where necessary, into packet_data.
 *
 * On the first call after the settings were received (session->__dec_key
 * is non-NULL) the RPS decryption key is first recovered: __dec_key is
 * decrypted with the session's MY_PRIVATE_KEY, parsed into
 * session->key[RPS_KEY], and then released.
 *
 * A packet is __cipher_pkt_size bytes: a 4-byte sequence number followed
 * by the payload.  If the sequence number is a multiple of session->I the
 * payload is decoded block by block (unencrypted prefix copied verbatim,
 * each encrypted unit raised to the RPS_KEY exponent modulo the RPS_KEY
 * modulus) into the session's plaintext buffer, and len bytes of that are
 * copied out.  Otherwise len bytes following the sequence number are
 * copied out directly.
 *
 * Preconditions (not checked here): session->I is non-zero, the cipher and
 * plaintext buffers are allocated, and len does not exceed the respective
 * buffer contents.
 *
 * Returns the value of __TCP_ReadFully() for the packet read: the packet
 * size on success, or its error/closed result, in which case packet_data
 * is left untouched.  Returns -1 if the temporary key buffer cannot be
 * allocated.
 */
int SML_TcpReceiveDecryptRps(SML_SESSION* session, char packet_data[], unsigned int len)
{
	char *temp;
	int block, sub_block, seq, received;
	char *cipher = session->__cipher_buffer, *data = session->__plaintext_buffer;
	mpz_t mpz_data;

	////////////////////////////////////////////////////////////
	// retrieve the decryption key on first decryption
	if (session->__dec_key != NULL) {

		temp = (char*) malloc(sizeof(char) * session->__dec_key_len);
		if (temp == NULL)
			return -1;

		fprintf(stderr, "len: %d\n", session->__dec_key_len);

		CRYPTO_Decrypt(
			temp,
			&session->__dec_key_len,
			session->__dec_key,
			session->__dec_key_len,
			session->key[MY_PRIVATE_KEY]);

		CRYPTO_ParseKeyFromString(&session->key[RPS_KEY], temp);

		// DEBUG
		// NOTE(review): this writes the decrypted key string to stderr
		fprintf(stderr, "[DEBUG] RPS Decryption Key: \n%s\n\n", temp);
		// DEBUG - END

		free(temp);
		free(session->__dec_key);
		session->__dec_key = NULL;   // marks the key as already retrieved
	}

	// receive the packet
	memset(cipher, 0, session->__cipher_pkt_size);
	received = __TCP_ReadFully(session->socket, (char*)cipher, session->__cipher_pkt_size);
	if (received != session->__cipher_pkt_size)
		return received;   // error or connection closed: nothing to decode

	memcpy((char*)&seq, cipher, 4);
	cipher = cipher + 4;

	//////////////////////////////////////////////////////////////
	// packets those encrypted
	if (seq % session->I == 0) {

		// init buffers
		memset(data, 0, session->__plaintext_pkt_size);
		mpz_init2(mpz_data, 4096);

		// decrypt each ECP block (excluding the last block)
		for (block=0; block<session->B - 1; block++) {
			// copy non-encrypt portion
			memcpy(data, cipher, session->__unencrypt_size_per_block);
			cipher += session->__unencrypt_size_per_block;
			data += session->__unencrypt_size_per_block;

			// decrypt the encrypt portion
			for (sub_block=0; sub_block<session->__encrypt_unit_per_block; sub_block++) {
				// load the cipher unit straight into the mpz limb array
				mpz_data->_mp_size = session->__cipher_unit_size_mp;
				memcpy(mpz_data->_mp_d, cipher, session->__cipher_unit_size);

				mpz_powm(mpz_data, mpz_data,
						session->key[RPS_KEY].param[RSA_PARAM_EXPONENT],
						session->key[RPS_KEY].param[RSA_PARAM_MODULUS]
						);

				memcpy(data, mpz_data->_mp_d, mpz_data->_mp_size*(mp_bits_per_limb / 8));

				cipher += session->__cipher_unit_size;
				data += session->__plaintext_unit_size;
			}
		}

		// decrypt last ECP block; its unencrypted prefix has its own size,
		// so the copy length must match the amount the cursors advance by
		memcpy(data, cipher, session->__unencrypt_size_last_block);
		cipher += session->__unencrypt_size_last_block;
		data += session->__unencrypt_size_last_block;

		for (sub_block=0; sub_block<session->__encrypt_unit_last_block; sub_block++) {
			// the unit to decrypt lives in the received packet (cipher),
			// not in the zero-filled plaintext output buffer
			mpz_data->_mp_size = session->__cipher_unit_size_mp;
			memcpy(mpz_data->_mp_d, cipher, session->__cipher_unit_size);

			mpz_powm(mpz_data, mpz_data,
						session->key[RPS_KEY].param[RSA_PARAM_EXPONENT],
						session->key[RPS_KEY].param[RSA_PARAM_MODULUS]
						);

			memcpy(data, mpz_data->_mp_d, mpz_data->_mp_size*(mp_bits_per_limb / 8));

			cipher += session->__cipher_unit_size;
			data += session->__plaintext_unit_size;
		}

		mpz_clear(mpz_data);

		// copy plaintext to user packet buffer
		memcpy(packet_data, session->__plaintext_buffer, len);
	}
	//////////////////////////////////////////////////////////////
	// packet without encryption
	else {
		memcpy(packet_data, cipher, len);
	}

	return received;
}


/*
 * Receive one raw RPS packet (__cipher_pkt_size bytes) from
 * session->socket into the session's cipher buffer and copy the first len
 * bytes of it, starting at the 4-byte sequence number, into packet_data.
 * No decryption is performed.
 *
 * Preconditions (not checked here): the cipher buffer is allocated and len
 * does not exceed __cipher_pkt_size.
 *
 * Returns the value of __TCP_ReadFully(): the packet size on success, or
 * its error/closed result, in which case packet_data is left untouched.
 */
int SML_TcpReceiveRps(SML_SESSION* session, char packet_data[], unsigned int len)
{
	int received;
	char *cipher = session->__cipher_buffer;

	// receive the packet
	memset(cipher, 0, session->__cipher_pkt_size);
	received = __TCP_ReadFully(session->socket, (char*)cipher, session->__cipher_pkt_size);
	if (received != session->__cipher_pkt_size)
		return received;   // error or connection closed: nothing to hand out

	// hand the packet over as received, sequence number included
	memcpy(packet_data, cipher, len);

	return received;
}

